mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-21 16:19:56 +00:00
Compare commits
72 Commits
task/align
...
embedded-v
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
412193c602 | ||
|
|
5e67febf57 | ||
|
|
ee348ba007 | ||
|
|
3d3055dc7f | ||
|
|
2f4ddf0796 | ||
|
|
98d533c8e8 | ||
|
|
ef4ea2e311 | ||
|
|
b41d11bbbe | ||
|
|
f37e228cc2 | ||
|
|
640a267556 | ||
|
|
17359cdc1e | ||
|
|
7e5846a1ee | ||
|
|
517bea0daf | ||
|
|
896530fd82 | ||
|
|
354fd004c7 | ||
|
|
c28e41e82b | ||
|
|
02b9fe704b | ||
|
|
5e200fa571 | ||
|
|
7d61975f6c | ||
|
|
62b36112ea | ||
|
|
df9a6fb020 | ||
|
|
b1b04f9ec6 | ||
|
|
fe15688f20 | ||
|
|
2285db2b62 | ||
|
|
b3f0f53a23 | ||
|
|
5eec9962ba | ||
|
|
393c102f45 | ||
|
|
b41fbad5e1 | ||
|
|
24a5f2252c | ||
|
|
9d189bb3e8 | ||
|
|
8e2505b59c | ||
|
|
97bc1eebde | ||
|
|
32a5a061b8 | ||
|
|
d927ef468a | ||
|
|
d3f3e08035 | ||
|
|
6bb66e0fad | ||
|
|
bc407527f4 | ||
|
|
5543404188 | ||
|
|
c2fdf62f1f | ||
|
|
b9f5264e36 | ||
|
|
97d0a6776f | ||
|
|
7e7e056f3a | ||
|
|
785f94d13f | ||
|
|
bfb6750b13 | ||
|
|
f5e1057127 | ||
|
|
ee393d0e62 | ||
|
|
0b8fc5da59 | ||
|
|
2d0a54f31a | ||
|
|
61ec8d67de | ||
|
|
76add0b9b2 | ||
|
|
a11341f57a | ||
|
|
b135d462d6 | ||
|
|
da37a28951 | ||
|
|
4f884d9f30 | ||
|
|
2bed8b641b | ||
|
|
b4f696272a | ||
|
|
6d937af7a0 | ||
|
|
db5b6cfbb7 | ||
|
|
e75948753a | ||
|
|
047cc958b5 | ||
|
|
cd005ef9a9 | ||
|
|
44ed0c1992 | ||
|
|
d6d3fa95c7 | ||
|
|
fa90283781 | ||
|
|
8bf13b0d0c | ||
|
|
a8541a1529 | ||
|
|
94068d3ebc | ||
|
|
738c585ee7 | ||
|
|
9b5541d17d | ||
|
|
7123e6d1f4 | ||
|
|
62cf9e873b | ||
|
|
9f0aa1ce26 |
106
.github/workflows/proto-version-check.yml
vendored
106
.github/workflows/proto-version-check.yml
vendored
@@ -3,74 +3,60 @@ name: Proto Version Check
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- "**/*.proto"
|
||||
- "**/*.pb.go"
|
||||
- "**/generate.sh"
|
||||
- "proto-tools.env"
|
||||
- ".github/workflows/proto-version-check.yml"
|
||||
|
||||
jobs:
|
||||
regenerate-and-diff:
|
||||
name: Regenerate proto and verify no drift
|
||||
check-proto-versions:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Load pinned proto toolchain versions
|
||||
run: |
|
||||
# shellcheck source=/dev/null
|
||||
. ./proto-tools.env
|
||||
{
|
||||
echo "PROTOC_VERSION=${PROTOC_VERSION}"
|
||||
echo "PROTOC_GEN_GO_VERSION=${PROTOC_GEN_GO_VERSION}"
|
||||
echo "PROTOC_GEN_GO_GRPC_VERSION=${PROTOC_GEN_GO_GRPC_VERSION}"
|
||||
} >> "$GITHUB_ENV"
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
- name: Check for proto tool version changes
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
script: |
|
||||
const files = await github.paginate(github.rest.pulls.listFiles, {
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
pull_number: context.issue.number,
|
||||
per_page: 100,
|
||||
});
|
||||
|
||||
- name: Setup protoc
|
||||
uses: arduino/setup-protoc@f4d5893b897028ff5739576ea0409746887fa536 # v3.0.0
|
||||
with:
|
||||
version: ${{ env.PROTOC_VERSION }}
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
const pbFiles = files.filter(f => f.filename.endsWith('.pb.go'));
|
||||
const missingPatch = pbFiles.filter(f => !f.patch).map(f => f.filename);
|
||||
if (missingPatch.length > 0) {
|
||||
core.setFailed(
|
||||
`Cannot inspect patch data for:\n` +
|
||||
missingPatch.map(f => `- ${f}`).join('\n') +
|
||||
`\nThis can happen with very large PRs. Verify proto versions manually.`
|
||||
);
|
||||
return;
|
||||
}
|
||||
const versionPattern = /^[+-]\s*\/\/\s+protoc(?:-gen-go)?\s+v[\d.]+/;
|
||||
const violations = [];
|
||||
|
||||
- name: Install protoc plugins
|
||||
run: |
|
||||
go install "google.golang.org/protobuf/cmd/protoc-gen-go@${PROTOC_GEN_GO_VERSION}"
|
||||
go install "google.golang.org/grpc/cmd/protoc-gen-go-grpc@${PROTOC_GEN_GO_GRPC_VERSION}"
|
||||
echo "$(go env GOPATH)/bin" >> "$GITHUB_PATH"
|
||||
for (const file of pbFiles) {
|
||||
const changed = file.patch
|
||||
.split('\n')
|
||||
.filter(line => versionPattern.test(line));
|
||||
if (changed.length > 0) {
|
||||
violations.push({
|
||||
file: file.filename,
|
||||
lines: changed,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
- name: Verify protoc version matches pin
|
||||
run: |
|
||||
actual=$(protoc --version | awk '{print $2}')
|
||||
if [[ "$actual" != "$PROTOC_VERSION" ]]; then
|
||||
echo "::error::protoc $actual does not match pinned $PROTOC_VERSION"
|
||||
exit 1
|
||||
fi
|
||||
if (violations.length > 0) {
|
||||
const details = violations.map(v =>
|
||||
`${v.file}:\n${v.lines.map(l => ' ' + l).join('\n')}`
|
||||
).join('\n\n');
|
||||
|
||||
- name: Regenerate all proto bindings
|
||||
run: |
|
||||
set -euo pipefail
|
||||
for script in \
|
||||
client/proto/generate.sh \
|
||||
shared/signal/proto/generate.sh \
|
||||
shared/management/proto/generate.sh \
|
||||
flow/proto/generate.sh \
|
||||
encryption/testprotos/generate.sh; do
|
||||
echo "::group::$script"
|
||||
bash "$script"
|
||||
echo "::endgroup::"
|
||||
done
|
||||
core.setFailed(
|
||||
`Proto version strings changed in generated files.\n` +
|
||||
`This usually means the wrong protoc or protoc-gen-go version was used.\n` +
|
||||
`Regenerate with the matching tool versions.\n\n` +
|
||||
details
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
- name: Fail if regeneration changed any tracked or untracked file
|
||||
run: |
|
||||
if [[ -n "$(git status --porcelain --untracked-files=all)" ]]; then
|
||||
echo "::error::Generated proto files drift from .proto sources or pinned tool versions."
|
||||
echo "Run the generate.sh scripts locally with the toolchain in proto-tools.env and commit the result."
|
||||
git status --short
|
||||
exit 1
|
||||
fi
|
||||
console.log('No proto version string changes detected');
|
||||
|
||||
4
.github/workflows/wasm-build-validation.yml
vendored
4
.github/workflows/wasm-build-validation.yml
vendored
@@ -61,8 +61,8 @@ jobs:
|
||||
|
||||
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
|
||||
|
||||
if [ ${SIZE} -gt 58720256 ]; then
|
||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 56MB limit!"
|
||||
if [ ${SIZE} -gt 62914560 ]; then
|
||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 60MB limit!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
153
README.md
153
README.md
@@ -1,134 +1,147 @@
|
||||
|
||||
<div align="center">
|
||||
<p align="center">
|
||||
<img width="234" src="docs/media/logo-full.png" alt="NetBird logo"/>
|
||||
</p>
|
||||
<p align="center">
|
||||
<a href="https://sonarcloud.io/dashboard?id=netbirdio_netbird">
|
||||
<img src="https://sonarcloud.io/api/project_badges/measure?project=netbirdio_netbird&metric=alert_status" alt="SonarCloud alert status"/>
|
||||
</a>
|
||||
<a href="https://github.com/netbirdio/netbird/blob/main/LICENSE">
|
||||
<img src="https://img.shields.io/badge/license-BSD--3-blue" alt="BSD-3 License"/>
|
||||
</a>
|
||||
<br/>
|
||||
<br/>
|
||||
<p align="center">
|
||||
<img width="234" src="docs/media/logo-full.png"/>
|
||||
</p>
|
||||
<p>
|
||||
<a href="https://img.shields.io/badge/license-BSD--3-blue)">
|
||||
<img src="https://sonarcloud.io/api/project_badges/measure?project=netbirdio_netbird&metric=alert_status" />
|
||||
</a>
|
||||
<a href="https://github.com/netbirdio/netbird/blob/main/LICENSE">
|
||||
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
|
||||
</a>
|
||||
<br>
|
||||
<a href="https://docs.netbird.io/slack-url">
|
||||
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack" alt="NetBird Slack"/>
|
||||
</a>
|
||||
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
||||
</a>
|
||||
<a href="https://forum.netbird.io">
|
||||
<img src="https://img.shields.io/badge/community%20forum-@netbird-red.svg?logo=discourse" alt="Community forum"/>
|
||||
</a>
|
||||
<img src="https://img.shields.io/badge/community forum-@netbird-red.svg?logo=discourse"/>
|
||||
</a>
|
||||
<br>
|
||||
<a href="https://gurubase.io/g/netbird">
|
||||
<img src="https://img.shields.io/badge/Gurubase-Ask%20NetBird%20Guru-006BFF" alt="Gurubase: Ask NetBird Guru"/>
|
||||
</a>
|
||||
<img src="https://img.shields.io/badge/Gurubase-Ask%20NetBird%20Guru-006BFF"/>
|
||||
</a>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
|
||||
<p align="center">
|
||||
<strong>
|
||||
Start using NetBird at <a href="https://netbird.io/pricing">netbird.io</a>
|
||||
<br/>
|
||||
See <a href="https://netbird.io/docs/">Documentation</a>
|
||||
<br/>
|
||||
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a> or our <a href="https://forum.netbird.io">Community forum</a>
|
||||
</strong>
|
||||
<strong>
|
||||
Start using NetBird at <a href="https://netbird.io/pricing">netbird.io</a>
|
||||
<br/>
|
||||
See <a href="https://netbird.io/docs/">Documentation</a>
|
||||
<br/>
|
||||
<strong>
|
||||
🚀 <a href="https://careers.netbird.io">We are hiring! Join us at careers.netbird.io</a>
|
||||
</strong>
|
||||
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a> or our <a href="https://forum.netbird.io">Community forum</a>
|
||||
<br/>
|
||||
|
||||
</strong>
|
||||
<br>
|
||||
<strong>
|
||||
🚀 <a href="https://careers.netbird.io">We are hiring! Join us at careers.netbird.io</a>
|
||||
</strong>
|
||||
<br>
|
||||
<br>
|
||||
<a href="https://registry.terraform.io/providers/netbirdio/netbird/latest">
|
||||
New: NetBird terraform provider
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<br>
|
||||
|
||||
**NetBird combines a configuration-free peer-to-peer private network and a centralized access control system in a single platform, making it easy to create secure private networks for your organization or home.**
|
||||
|
||||
**Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
|
||||
|
||||
**Secure.** NetBird enables secure remote access by applying granular access policies while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
|
||||
|
||||
### Open Source Network Security in a Single Platform
|
||||
|
||||
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
|
||||
|
||||
### Self-host NetBird (video)
|
||||
|
||||
### Self-Host NetBird (Video)
|
||||
[](https://youtu.be/bZAgpT6nzaQ)
|
||||
|
||||
### Key features
|
||||
|
||||
| Connectivity | Management | Security | Automation | Platforms |
|
||||
|---|---|---|---|---|
|
||||
| ✓ [Kernel WireGuard](https://docs.netbird.io/about-netbird/why-wireguard-with-netbird) | ✓ [Admin Web UI](https://github.com/netbirdio/dashboard) | ✓ [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login) | ✓ [Public API](https://docs.netbird.io/api) | ✓ [Linux](https://docs.netbird.io/get-started/install/linux) |
|
||||
| ✓ [Peer-to-peer connections](https://docs.netbird.io/about-netbird/how-netbird-works) | ✓ Auto peer discovery and configuration | ✓ [Access control: groups & rules](https://docs.netbird.io/how-to/manage-network-access) | ✓ [Setup keys for bulk provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys) | ✓ [macOS](https://docs.netbird.io/get-started/install/macos) |
|
||||
| ✓ Connection relay fallback | ✓ [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers) | ✓ [Activity logging](https://docs.netbird.io/how-to/audit-events-logging) | ✓ [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart) | ✓ [Windows](https://docs.netbird.io/get-started/install/windows) |
|
||||
| ✓ [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks) | ✓ [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network) | ✓ [Traffic events](https://docs.netbird.io/manage/activity/traffic-events-logging) | ✓ [IdP groups sync with JWT](https://docs.netbird.io/manage/team/idp-sync) | ✓ [Android](https://docs.netbird.io/get-started/install/android) |
|
||||
| ✓ [Domain-based DNS routes](https://docs.netbird.io/manage/dns/dns-aliases-for-routed-networks) | ✓ [Custom DNS zones](https://docs.netbird.io/manage/dns/custom-zones) | ✓ [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks) | ✓ [Terraform provider](https://registry.terraform.io/providers/netbirdio/netbird/latest) | ✓ [Android TV](https://docs.netbird.io/get-started/install/android-tv) |
|
||||
| ✓ [Exit nodes](https://docs.netbird.io/manage/network-routes/use-cases/exit-nodes) | ✓ [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network) | ✓ Peer-to-peer encryption | ✓ [Ansible collection](https://github.com/netbirdio/ansible-netbird) | ✓ [iOS](https://docs.netbird.io/get-started/install/ios) |
|
||||
| ✓ [IPv6 dual-stack overlay](https://docs.netbird.io/manage/settings/ipv6) | ✓ [Multi-account profile switching](https://docs.netbird.io/client/profiles) | ✓ [SSH with central access policies](https://docs.netbird.io/manage/peers/ssh) | | ✓ [Apple TV](https://docs.netbird.io/get-started/install/tvos) |
|
||||
| ✓ [Browser SSH & RDP](https://docs.netbird.io/manage/peers/browser-client) | | ✓ [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn) | | ✓ FreeBSD |
|
||||
| ✓ [Reverse proxy with auto-TLS](https://docs.netbird.io/manage/reverse-proxy) | | ✓ [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication) | | ✓ [pfSense](https://docs.netbird.io/get-started/install/pfsense) |
|
||||
| | | | | ✓ [OPNsense](https://docs.netbird.io/get-started/install/opnsense) |
|
||||
| | | | | ✓ [MikroTik RouterOS](https://docs.netbird.io/use-cases/homelab/client-on-mikrotik-router) |
|
||||
| | | | | ✓ OpenWRT |
|
||||
| | | | | ✓ [Synology](https://docs.netbird.io/get-started/install/synology) |
|
||||
| | | | | ✓ [TrueNAS](https://docs.netbird.io/get-started/install/truenas) |
|
||||
| | | | | ✓ [Proxmox](https://docs.netbird.io/get-started/install/proxmox-ve) |
|
||||
| | | | | ✓ [Raspberry Pi](https://docs.netbird.io/get-started/install/raspberrypi) |
|
||||
| | | | | ✓ [Serverless](https://docs.netbird.io/how-to/netbird-on-faas) |
|
||||
| | | | | ✓ [Container](https://docs.netbird.io/get-started/install/docker) |
|
||||
| Connectivity | Management | Security | Automation| Platforms |
|
||||
|----|----|----|----|----|
|
||||
| <ul><li>- \[x] Kernel WireGuard</ul></li> | <ul><li>- \[x] [Admin Web UI](https://github.com/netbirdio/dashboard)</ul></li> | <ul><li>- \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login)</ul></li> | <ul><li>- \[x] [Public API](https://docs.netbird.io/api)</ul></li> | <ul><li>- \[x] Linux</ul></li> |
|
||||
| <ul><li>- \[x] Peer-to-peer connections</ul></li> | <ul><li>- \[x] Auto peer discovery and configuration</ui></li> | <ul><li>- \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access)</ui></li> | <ul><li>- \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys)</ui></li> | <ul><li>- \[x] Mac</ui></li> |
|
||||
| <ul><li>- \[x] Connection relay fallback</ui></li> | <ul><li>- \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers)</ui></li> | <ul><li>- \[x] [Activity logging](https://docs.netbird.io/how-to/audit-events-logging)</ui></li> | <ul><li>- \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart)</ui></li> | <ul><li>- \[x] Windows</ui></li> |
|
||||
| <ul><li>- \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks)</ui></li> | <ul><li>- \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network)</ui></li> | <ul><li>- \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks)</ui></li> | <ul><li>- \[x] IdP groups sync with JWT</ui></li> | <ul><li>- \[x] Android</ui></li> |
|
||||
| <ul><li>- \[x] NAT traversal with BPF</ui></li> | <ul><li>- \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network)</ui></li> | <ul><li>- \[x] Peer-to-peer encryption</ui></li> || <ul><li>- \[x] iOS</ui></li> |
|
||||
||| <ul><li>- \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn)</ui></li> || <ul><li>- \[x] OpenWRT</ui></li> |
|
||||
||| <ul><li>- \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ui></li> || <ul><li>- \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas)</ui></li> |
|
||||
||||| <ul><li>- \[x] Docker</ui></li> |
|
||||
|
||||
### Quickstart with NetBird Cloud
|
||||
|
||||
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install).
|
||||
- Follow the steps to sign up with Google, Microsoft, GitHub or your email address.
|
||||
- Check the NetBird [admin UI](https://app.netbird.io/).
|
||||
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install)
|
||||
- Follow the steps to sign-up with Google, Microsoft, GitHub or your email address.
|
||||
- Check NetBird [admin UI](https://app.netbird.io/).
|
||||
- Add more machines.
|
||||
|
||||
### Quickstart with self-hosted NetBird
|
||||
|
||||
This is the quickest way to try self-hosted NetBird. It should take around 5 minutes to get started if you already have a public domain and a VM. Follow the [Advanced guide with a custom identity provider](https://docs.netbird.io/selfhosted/selfhosted-guide#advanced-guide-with-a-custom-identity-provider) for installations with different IdPs.
|
||||
> This is the quickest way to try self-hosted NetBird. It should take around 5 minutes to get started if you already have a public domain and a VM.
|
||||
Follow the [Advanced guide with a custom identity provider](https://docs.netbird.io/selfhosted/selfhosted-guide#advanced-guide-with-a-custom-identity-provider) for installations with different IDPs.
|
||||
|
||||
**Infrastructure requirements:**
|
||||
- A Linux VM with at least **1 CPU** and **2 GB** of memory.
|
||||
- The VM should be publicly accessible on TCP ports **80** and **443** and UDP port **3478**.
|
||||
- A **public domain** name pointing to the VM.
|
||||
- A Linux VM with at least **1CPU** and **2GB** of memory.
|
||||
- The VM should be publicly accessible on TCP ports **80** and **443** and UDP port: **3478**.
|
||||
- **Public domain** name pointing to the VM.
|
||||
|
||||
**Software requirements:**
|
||||
- Docker with the Compose plugin (Compose v2 or higher). See the [Docker installation guide](https://docs.docker.com/engine/install/).
|
||||
- Docker installed on the VM with the docker-compose plugin ([Docker installation guide](https://docs.docker.com/engine/install/)) or docker with docker-compose in version 2 or higher.
|
||||
- [jq](https://jqlang.github.io/jq/) installed. In most distributions
|
||||
Usually available in the official repositories and can be installed with `sudo apt install jq` or `sudo yum install jq`
|
||||
- [curl](https://curl.se/) installed.
|
||||
Usually available in the official repositories and can be installed with `sudo apt install curl` or `sudo yum install curl`
|
||||
|
||||
**Steps**
|
||||
- Download and run the installation script:
|
||||
```bash
|
||||
export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbirdio/netbird/releases/latest/download/getting-started.sh | bash
|
||||
```
|
||||
- Once finished, you can manage the resources via `docker-compose`
|
||||
|
||||
### A bit on NetBird internals
|
||||
- Every machine in the network runs the [NetBird agent](client/), which manages WireGuard.
|
||||
- Every agent connects to the [Management Service](management/), which holds network state, manages peer IPs, and distributes updates to agents.
|
||||
- Agents use ICE (via [pion/ice](https://github.com/pion/ice)) to discover connection candidates for peer-to-peer connections.
|
||||
- Candidates are discovered with the help of [STUN](https://en.wikipedia.org/wiki/STUN) servers.
|
||||
- Agents negotiate a connection through the [Signal Service](signal/), exchanging end-to-end encrypted messages with candidates.
|
||||
- When NAT traversal fails (e.g. mobile carrier-grade NAT) and a direct p2p connection isn't possible, the system falls back to a [Relay Service](relay/) and a secure WireGuard tunnel is established through it.
|
||||
- Every machine in the network runs [NetBird Agent (or Client)](client/) that manages WireGuard.
|
||||
- Every agent connects to [Management Service](management/) that holds network state, manages peer IPs, and distributes network updates to agents (peers).
|
||||
- NetBird agent uses WebRTC ICE implemented in [pion/ice library](https://github.com/pion/ice) to discover connection candidates when establishing a peer-to-peer connection between machines.
|
||||
- Connection candidates are discovered with the help of [STUN](https://en.wikipedia.org/wiki/STUN) servers.
|
||||
- Agents negotiate a connection through [Signal Service](signal/) passing p2p encrypted messages with candidates.
|
||||
- Sometimes the NAT traversal is unsuccessful due to strict NATs (e.g. mobile carrier-grade NAT) and a p2p connection isn't possible. When this occurs the system falls back to a relay server called [TURN](https://en.wikipedia.org/wiki/Traversal_Using_Relays_around_NAT), and a secure WireGuard tunnel is established via the TURN server.
|
||||
|
||||
[Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups.
|
||||
|
||||
<p float="left" align="middle">
|
||||
<img src="https://docs.netbird.io/docs-static/img/about-netbird/high-level-dia.png" width="700" alt="NetBird high-level architecture diagram"/>
|
||||
<img src="https://docs.netbird.io/docs-static/img/about-netbird/high-level-dia.png" width="700"/>
|
||||
</p>
|
||||
|
||||
See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.
|
||||
|
||||
### Community projects
|
||||
- [NetBird installer script](https://github.com/physk/netbird-installer)
|
||||
- [netbird-tui](https://github.com/n0pashkov/netbird-tui) - terminal UI for managing NetBird peers, routes, and settings
|
||||
- [caddy-netbird](https://github.com/lixmal/caddy-netbird) - Caddy plugin that embeds a NetBird client for proxying HTTP and TCP/UDP traffic through NetBird networks
|
||||
- [NetBird installer script](https://github.com/physk/netbird-installer)
|
||||
- [NetBird ansible collection by Dominion Solutions](https://galaxy.ansible.com/ui/repo/published/dominion_solutions/netbird/)
|
||||
- [netbird-tui](https://github.com/n0pashkov/netbird-tui) — terminal UI for managing NetBird peers, routes, and settings
|
||||
|
||||
**Note**: The `main` branch may be in an *unstable or even broken state* during development.
|
||||
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).
|
||||
|
||||
### Support acknowledgement
|
||||
|
||||
In November 2022, NetBird joined the [StartUpSecure program](https://www.forschung-it-sicherheit-kommunikationssysteme.de/foerderung/bekanntmachungen/startup-secure) sponsored by the Federal Ministry of Education and Research of the Federal Republic of Germany. Together with the [CISPA Helmholtz Center for Information Security](https://cispa.de/en), NetBird brings security best practices and simplicity to private networking.
|
||||
In November 2022, NetBird joined the [StartUpSecure program](https://www.forschung-it-sicherheit-kommunikationssysteme.de/foerderung/bekanntmachungen/startup-secure) sponsored by The Federal Ministry of Education and Research of The Federal Republic of Germany. Together with [CISPA Helmholtz Center for Information Security](https://cispa.de/en) NetBird brings the security best practices and simplicity to private networking.
|
||||
|
||||

|
||||
|
||||
### Acknowledgements
|
||||
We build on open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE](https://github.com/pion/ice), and [Rosenpass](https://rosenpass.eu). We greatly appreciate the work these projects are doing, and we'd love it if you could support them too (e.g., by starring or contributing).
|
||||
### Testimonials
|
||||
We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), [Coturn](https://github.com/coturn/coturn), and [Rosenpass](https://rosenpass.eu). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g., by giving a star or a contribution).
|
||||
|
||||
### Legal
|
||||
This repository is licensed under the BSD-3-Clause license, which applies to all parts of the repository except for the directories management/, signal/ and relay/.
|
||||
This repository is licensed under BSD-3-Clause license that applies to all parts of the repository except for the directories management/, signal/ and relay/.
|
||||
Those directories are licensed under the GNU Affero General Public License version 3.0 (AGPLv3). See the respective LICENSE files inside each directory.
|
||||
|
||||
_WireGuard_ and the _WireGuard_ logo are [registered trademarks](https://www.wireguard.com/trademark-policy/) of Jason A. Donenfeld.
|
||||
|
||||
@@ -361,6 +361,9 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
|
||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||
req.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
if cmd.Flag(serverVNCAllowedFlag).Changed {
|
||||
req.ServerVNCAllowed = &serverVNCAllowed
|
||||
}
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
req.EnableSSHRoot = &enableSSHRoot
|
||||
}
|
||||
@@ -467,6 +470,9 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
|
||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||
ic.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
if cmd.Flag(serverVNCAllowedFlag).Changed {
|
||||
ic.ServerVNCAllowed = &serverVNCAllowed
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
ic.EnableSSHRoot = &enableSSHRoot
|
||||
@@ -595,6 +601,9 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
|
||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||
loginRequest.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
if cmd.Flag(serverVNCAllowedFlag).Changed {
|
||||
loginRequest.ServerVNCAllowed = &serverVNCAllowed
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
loginRequest.EnableSSHRoot = &enableSSHRoot
|
||||
|
||||
73
client/cmd/vnc_agent.go
Normal file
73
client/cmd/vnc_agent.go
Normal file
@@ -0,0 +1,73 @@
|
||||
//go:build windows || (darwin && !ios)
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
)
|
||||
|
||||
var vncAgentPort uint16
|
||||
|
||||
func init() {
|
||||
vncAgentCmd.Flags().Uint16Var(&vncAgentPort, "port", 15900, "Port for the VNC agent to listen on")
|
||||
rootCmd.AddCommand(vncAgentCmd)
|
||||
}
|
||||
|
||||
// vncAgentCmd runs a VNC server inside the user's interactive session,
|
||||
// listening on localhost. The NetBird service spawns it: on Windows via
|
||||
// CreateProcessAsUser into the console session, on macOS via
|
||||
// launchctl asuser into the Aqua session.
|
||||
var vncAgentCmd = &cobra.Command{
|
||||
Use: "vnc-agent",
|
||||
Short: "Run VNC capture agent (internal, spawned by service)",
|
||||
Hidden: true,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
log.SetReportCaller(true)
|
||||
log.SetFormatter(&log.JSONFormatter{})
|
||||
log.SetOutput(os.Stderr)
|
||||
|
||||
log.Infof("VNC agent starting on 127.0.0.1:%d", vncAgentPort)
|
||||
|
||||
token := os.Getenv("NB_VNC_AGENT_TOKEN")
|
||||
if token == "" {
|
||||
return fmt.Errorf("NB_VNC_AGENT_TOKEN not set; agent requires a token from the service")
|
||||
}
|
||||
// Drop the token from our process environment so any child the
|
||||
// agent spawns does not inherit it, and casual debugging tools
|
||||
// that dump /proc/<pid>/environ (or the Windows equivalent) on a
|
||||
// running agent don't surface the loopback shared secret.
|
||||
if err := os.Unsetenv("NB_VNC_AGENT_TOKEN"); err != nil {
|
||||
log.Debugf("unset NB_VNC_AGENT_TOKEN: %v", err)
|
||||
}
|
||||
|
||||
capturer, injector, err := newAgentResources()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// The per-user agent listens only on loopback and is gated by an
|
||||
// agent token shared with the daemon, so no X25519 identity key
|
||||
// is needed; auth is disabled at the RFB layer.
|
||||
srv := vncserver.New(capturer, injector, nil)
|
||||
srv.SetDisableAuth(true)
|
||||
srv.SetAgentToken(token)
|
||||
|
||||
addr := netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), vncAgentPort)
|
||||
loopback := netip.PrefixFrom(netip.AddrFrom4([4]byte{127, 0, 0, 0}), 8)
|
||||
if err := srv.Start(cmd.Context(), addr, loopback); err != nil {
|
||||
return fmt.Errorf("start vnc server: %w", err)
|
||||
}
|
||||
log.Infof("vnc-agent listening on 127.0.0.1:%d, ready", vncAgentPort)
|
||||
|
||||
<-cmd.Context().Done()
|
||||
log.Info("vnc-agent context cancelled, shutting down")
|
||||
return srv.Stop()
|
||||
},
|
||||
SilenceUsage: true,
|
||||
}
|
||||
18
client/cmd/vnc_agent_darwin.go
Normal file
18
client/cmd/vnc_agent_darwin.go
Normal file
@@ -0,0 +1,18 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
)
|
||||
|
||||
func newAgentResources() (vncserver.ScreenCapturer, vncserver.InputInjector, error) {
|
||||
capturer := vncserver.NewMacPoller()
|
||||
injector, err := vncserver.NewMacInputInjector()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("macOS input injector: %w", err)
|
||||
}
|
||||
return capturer, injector, nil
|
||||
}
|
||||
15
client/cmd/vnc_agent_windows.go
Normal file
15
client/cmd/vnc_agent_windows.go
Normal file
@@ -0,0 +1,15 @@
|
||||
//go:build windows
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
)
|
||||
|
||||
func newAgentResources() (vncserver.ScreenCapturer, vncserver.InputInjector, error) {
|
||||
sessionID := vncserver.GetCurrentSessionID()
|
||||
log.Infof("VNC agent running in Windows session %d", sessionID)
|
||||
return vncserver.NewDesktopCapturer(), vncserver.NewWindowsInputInjector(), nil
|
||||
}
|
||||
9
client/cmd/vnc_flags.go
Normal file
9
client/cmd/vnc_flags.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package cmd
|
||||
|
||||
const serverVNCAllowedFlag = "allow-server-vnc"
|
||||
|
||||
var serverVNCAllowed bool
|
||||
|
||||
func init() {
|
||||
upCmd.PersistentFlags().BoolVar(&serverVNCAllowed, serverVNCAllowedFlag, false, "Allow embedded VNC server on peer")
|
||||
}
|
||||
@@ -52,10 +52,9 @@ func (m *externalChainMonitor) start() {
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
m.cancel = cancel
|
||||
done := make(chan struct{})
|
||||
m.done = done
|
||||
m.done = make(chan struct{})
|
||||
|
||||
go m.run(ctx, done)
|
||||
go m.run(ctx)
|
||||
}
|
||||
|
||||
func (m *externalChainMonitor) stop() {
|
||||
@@ -73,8 +72,8 @@ func (m *externalChainMonitor) stop() {
|
||||
<-done
|
||||
}
|
||||
|
||||
func (m *externalChainMonitor) run(ctx context.Context, done chan struct{}) {
|
||||
defer close(done)
|
||||
func (m *externalChainMonitor) run(ctx context.Context) {
|
||||
defer close(m.done)
|
||||
|
||||
bo := &backoff.ExponentialBackOff{
|
||||
InitialInterval: externalMonitorInitInterval,
|
||||
|
||||
@@ -315,6 +315,7 @@ func (a *Auth) setSystemInfoFlags(info *system.Info) {
|
||||
a.config.RosenpassEnabled,
|
||||
a.config.RosenpassPermissive,
|
||||
a.config.ServerSSHAllowed,
|
||||
a.config.ServerVNCAllowed,
|
||||
a.config.DisableClientRoutes,
|
||||
a.config.DisableServerRoutes,
|
||||
a.config.DisableDNS,
|
||||
|
||||
@@ -562,6 +562,7 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
|
||||
RosenpassEnabled: config.RosenpassEnabled,
|
||||
RosenpassPermissive: config.RosenpassPermissive,
|
||||
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
|
||||
ServerVNCAllowed: config.ServerVNCAllowed != nil && *config.ServerVNCAllowed,
|
||||
EnableSSHRoot: config.EnableSSHRoot,
|
||||
EnableSSHSFTP: config.EnableSSHSFTP,
|
||||
EnableSSHLocalPortForwarding: config.EnableSSHLocalPortForwarding,
|
||||
@@ -644,6 +645,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
|
||||
config.RosenpassEnabled,
|
||||
config.RosenpassPermissive,
|
||||
config.ServerSSHAllowed,
|
||||
config.ServerVNCAllowed,
|
||||
config.DisableClientRoutes,
|
||||
config.DisableServerRoutes,
|
||||
config.DisableDNS,
|
||||
|
||||
@@ -636,6 +636,9 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
|
||||
if g.internalConfig.SSHJWTCacheTTL != nil {
|
||||
configContent.WriteString(fmt.Sprintf("SSHJWTCacheTTL: %d\n", *g.internalConfig.SSHJWTCacheTTL))
|
||||
}
|
||||
if g.internalConfig.ServerVNCAllowed != nil {
|
||||
configContent.WriteString(fmt.Sprintf("ServerVNCAllowed: %v\n", *g.internalConfig.ServerVNCAllowed))
|
||||
}
|
||||
|
||||
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
|
||||
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))
|
||||
|
||||
@@ -862,6 +862,7 @@ func TestAddConfig_AllFieldsCovered(t *testing.T) {
|
||||
RosenpassEnabled: true,
|
||||
RosenpassPermissive: true,
|
||||
ServerSSHAllowed: &bTrue,
|
||||
ServerVNCAllowed: &bTrue,
|
||||
EnableSSHRoot: &bTrue,
|
||||
EnableSSHSFTP: &bTrue,
|
||||
EnableSSHLocalPortForwarding: &bTrue,
|
||||
|
||||
@@ -123,6 +123,7 @@ type EngineConfig struct {
|
||||
RosenpassPermissive bool
|
||||
|
||||
ServerSSHAllowed bool
|
||||
ServerVNCAllowed bool
|
||||
EnableSSHRoot *bool
|
||||
EnableSSHSFTP *bool
|
||||
EnableSSHLocalPortForwarding *bool
|
||||
@@ -205,6 +206,7 @@ type Engine struct {
|
||||
networkMonitor *networkmonitor.NetworkMonitor
|
||||
|
||||
sshServer sshServer
|
||||
vncSrv vncServer
|
||||
|
||||
statusRecorder *peer.Status
|
||||
|
||||
@@ -320,6 +322,10 @@ func (e *Engine) Stop() error {
|
||||
log.Warnf("failed to stop SSH server: %v", err)
|
||||
}
|
||||
|
||||
if err := e.stopVNCServer(); err != nil {
|
||||
log.Warnf("failed to stop VNC server: %v", err)
|
||||
}
|
||||
|
||||
e.cleanupSSHConfig()
|
||||
|
||||
if e.ingressGatewayMgr != nil {
|
||||
@@ -1010,6 +1016,7 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
||||
e.config.RosenpassEnabled,
|
||||
e.config.RosenpassPermissive,
|
||||
&e.config.ServerSSHAllowed,
|
||||
&e.config.ServerVNCAllowed,
|
||||
e.config.DisableClientRoutes,
|
||||
e.config.DisableServerRoutes,
|
||||
e.config.DisableDNS,
|
||||
@@ -1057,6 +1064,10 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
}
|
||||
}
|
||||
|
||||
if err := e.updateVNC(); err != nil {
|
||||
log.Warnf("failed handling VNC server setup: %v", err)
|
||||
}
|
||||
|
||||
state := e.statusRecorder.GetLocalPeerState()
|
||||
state.IP = e.wgInterface.Address().String()
|
||||
state.IPv6 = e.wgInterface.Address().IPv6String()
|
||||
@@ -1182,6 +1193,7 @@ func (e *Engine) receiveManagementEvents() {
|
||||
e.config.RosenpassEnabled,
|
||||
e.config.RosenpassPermissive,
|
||||
&e.config.ServerSSHAllowed,
|
||||
&e.config.ServerVNCAllowed,
|
||||
e.config.DisableClientRoutes,
|
||||
e.config.DisableServerRoutes,
|
||||
e.config.DisableDNS,
|
||||
@@ -1371,6 +1383,11 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
e.updateSSHServerAuth(networkMap.GetSshAuth())
|
||||
}
|
||||
|
||||
// VNC auth: always sync, including nil so cleared auth on the management
|
||||
// side is applied locally, and so it isn't skipped on the RemotePeersIsEmpty
|
||||
// cleanup path.
|
||||
e.updateVNCServerAuth(networkMap.GetVncAuth())
|
||||
|
||||
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
|
||||
excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, remotePeers)
|
||||
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
|
||||
@@ -1826,6 +1843,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
|
||||
e.config.RosenpassEnabled,
|
||||
e.config.RosenpassPermissive,
|
||||
&e.config.ServerSSHAllowed,
|
||||
&e.config.ServerVNCAllowed,
|
||||
e.config.DisableClientRoutes,
|
||||
e.config.DisableServerRoutes,
|
||||
e.config.DisableDNS,
|
||||
|
||||
236
client/internal/engine_vnc.go
Normal file
236
client/internal/engine_vnc.go
Normal file
@@ -0,0 +1,236 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/metrics"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
|
||||
)
|
||||
|
||||
const (
|
||||
vncExternalPort uint16 = 5900
|
||||
vncInternalPort uint16 = 25900
|
||||
)
|
||||
|
||||
type vncServer interface {
|
||||
Start(ctx context.Context, addr netip.AddrPort, network netip.Prefix) error
|
||||
Stop() error
|
||||
ActiveSessions() []vncserver.ActiveSessionInfo
|
||||
}
|
||||
|
||||
func (e *Engine) setupVNCPortRedirection() error {
|
||||
if e.firewall == nil || e.wgInterface == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
localAddr := e.wgInterface.Address().IP
|
||||
if !localAddr.IsValid() {
|
||||
return errors.New("invalid local NetBird address")
|
||||
}
|
||||
|
||||
if err := e.firewall.AddInboundDNAT(localAddr, firewallManager.ProtocolTCP, vncExternalPort, vncInternalPort); err != nil {
|
||||
return fmt.Errorf("add VNC port redirection: %w", err)
|
||||
}
|
||||
log.Infof("VNC port redirection: %s:%d -> %s:%d", localAddr, vncExternalPort, localAddr, vncInternalPort)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) cleanupVNCPortRedirection() error {
|
||||
if e.firewall == nil || e.wgInterface == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
localAddr := e.wgInterface.Address().IP
|
||||
if !localAddr.IsValid() {
|
||||
return errors.New("invalid local NetBird address")
|
||||
}
|
||||
|
||||
if err := e.firewall.RemoveInboundDNAT(localAddr, firewallManager.ProtocolTCP, vncExternalPort, vncInternalPort); err != nil {
|
||||
return fmt.Errorf("remove VNC port redirection: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateVNC handles starting/stopping the VNC server based on the config flag.
|
||||
func (e *Engine) updateVNC() error {
|
||||
if !e.config.ServerVNCAllowed {
|
||||
if e.vncSrv != nil {
|
||||
log.Info("VNC server disabled, stopping")
|
||||
}
|
||||
return e.stopVNCServer()
|
||||
}
|
||||
|
||||
if e.config.BlockInbound {
|
||||
log.Info("VNC server disabled because inbound connections are blocked")
|
||||
return e.stopVNCServer()
|
||||
}
|
||||
|
||||
if e.vncSrv != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return e.startVNCServer()
|
||||
}
|
||||
|
||||
func (e *Engine) startVNCServer() error {
|
||||
if e.wgInterface == nil {
|
||||
return errors.New("wg interface not initialized")
|
||||
}
|
||||
|
||||
capturer, injector, ok := newPlatformVNC()
|
||||
if !ok {
|
||||
log.Debug("VNC server not supported on this platform")
|
||||
return nil
|
||||
}
|
||||
|
||||
netbirdIP := e.wgInterface.Address().IP
|
||||
|
||||
srv := vncserver.New(capturer, injector, e.config.WgPrivateKey[:])
|
||||
if e.clientMetrics != nil {
|
||||
srv.SetSessionRecorder(func(t vncserver.SessionTick) {
|
||||
e.clientMetrics.RecordVNCSessionTick(e.ctx, metrics.VNCSessionTick{
|
||||
Period: t.Period,
|
||||
BytesOut: t.BytesOut,
|
||||
Writes: t.Writes,
|
||||
FBUs: t.FBUs,
|
||||
MaxFBUBytes: t.MaxFBUBytes,
|
||||
MaxFBURects: t.MaxFBURects,
|
||||
MaxWriteBytes: t.MaxWriteBytes,
|
||||
WriteNanos: t.WriteNanos,
|
||||
})
|
||||
})
|
||||
}
|
||||
if vncNeedsServiceMode() {
|
||||
log.Info("VNC: running in Session 0, enabling service mode (agent proxy)")
|
||||
srv.SetServiceMode(true)
|
||||
}
|
||||
|
||||
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||
srv.SetNetstackNet(netstackNet)
|
||||
}
|
||||
|
||||
listenAddr := netip.AddrPortFrom(netbirdIP, vncInternalPort)
|
||||
network := e.wgInterface.Address().Network
|
||||
if err := srv.Start(e.ctx, listenAddr, network); err != nil {
|
||||
return fmt.Errorf("start VNC server: %w", err)
|
||||
}
|
||||
|
||||
e.vncSrv = srv
|
||||
|
||||
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||
if registrar, ok := e.firewall.(interface {
|
||||
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||
}); ok {
|
||||
registrar.RegisterNetstackService(nftypes.TCP, vncInternalPort)
|
||||
log.Debugf("registered VNC service with netstack for TCP:%d", vncInternalPort)
|
||||
}
|
||||
}
|
||||
|
||||
if err := e.setupVNCPortRedirection(); err != nil {
|
||||
log.Warnf("setup VNC port redirection: %v", err)
|
||||
}
|
||||
|
||||
log.Info("VNC server enabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
// updateVNCServerAuth updates VNC fine-grained access control from management.
|
||||
func (e *Engine) updateVNCServerAuth(vncAuth *mgmProto.VNCAuth) {
|
||||
if vncAuth == nil || e.vncSrv == nil {
|
||||
return
|
||||
}
|
||||
|
||||
vncSrv, ok := e.vncSrv.(*vncserver.Server)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
protoUsers := vncAuth.GetAuthorizedUsers()
|
||||
authorizedUsers := make([]sshuserhash.UserIDHash, len(protoUsers))
|
||||
for i, hash := range protoUsers {
|
||||
if len(hash) != 16 {
|
||||
log.Warnf("invalid VNC auth hash length %d, expected 16", len(hash))
|
||||
return
|
||||
}
|
||||
authorizedUsers[i] = sshuserhash.UserIDHash(hash)
|
||||
}
|
||||
|
||||
machineUsers := make(map[string][]uint32)
|
||||
for osUser, indexes := range vncAuth.GetMachineUsers() {
|
||||
machineUsers[osUser] = indexes.GetIndexes()
|
||||
}
|
||||
|
||||
sessionPubKeys := make([]sshauth.SessionPubKey, 0, len(vncAuth.GetSessionPubKeys()))
|
||||
for _, e := range vncAuth.GetSessionPubKeys() {
|
||||
pub := e.GetPubKey()
|
||||
if len(pub) != 32 {
|
||||
log.Warnf("VNC session pubkey wrong length %d", len(pub))
|
||||
continue
|
||||
}
|
||||
hash := e.GetUserIdHash()
|
||||
if len(hash) != 16 {
|
||||
log.Warnf("VNC session user id hash wrong length %d", len(hash))
|
||||
continue
|
||||
}
|
||||
sessionPubKeys = append(sessionPubKeys, sshauth.SessionPubKey{
|
||||
PubKey: pub,
|
||||
UserIDHash: sshuserhash.UserIDHash(hash),
|
||||
})
|
||||
}
|
||||
|
||||
vncSrv.UpdateVNCAuth(&sshauth.Config{
|
||||
AuthorizedUsers: authorizedUsers,
|
||||
MachineUsers: machineUsers,
|
||||
SessionPubKeys: sessionPubKeys,
|
||||
})
|
||||
}
|
||||
|
||||
// GetVNCServerStatus returns whether the VNC server is running and the list
|
||||
// of active VNC sessions.
|
||||
func (e *Engine) GetVNCServerStatus() (enabled bool, sessions []vncserver.ActiveSessionInfo) {
|
||||
if e.vncSrv == nil {
|
||||
return false, nil
|
||||
}
|
||||
return true, e.vncSrv.ActiveSessions()
|
||||
}
|
||||
|
||||
func (e *Engine) stopVNCServer() error {
|
||||
if e.vncSrv == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := e.cleanupVNCPortRedirection(); err != nil {
|
||||
log.Warnf("cleanup VNC port redirection: %v", err)
|
||||
}
|
||||
|
||||
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||
if registrar, ok := e.firewall.(interface {
|
||||
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||
}); ok {
|
||||
registrar.UnregisterNetstackService(nftypes.TCP, vncInternalPort)
|
||||
}
|
||||
}
|
||||
|
||||
log.Info("stopping VNC server")
|
||||
err := e.vncSrv.Stop()
|
||||
e.vncSrv = nil
|
||||
if err != nil {
|
||||
return fmt.Errorf("stop VNC server: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
31
client/internal/engine_vnc_console_freebsd.go
Normal file
31
client/internal/engine_vnc_console_freebsd.go
Normal file
@@ -0,0 +1,31 @@
|
||||
//go:build freebsd
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
)
|
||||
|
||||
// newConsoleVNC builds the FreeBSD console fallback: vt(4) framebuffer
|
||||
// for capture, /dev/uinput for input. The uinput device requires the
|
||||
// `uinput` kernel module (`kldload uinput`); without it, input init
|
||||
// fails and we drop to a stub injector so the user still gets a
|
||||
// view-only screen mirror.
|
||||
func newConsoleVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, error) {
|
||||
poller := vncserver.NewFBPoller("")
|
||||
w, h := poller.Width(), poller.Height()
|
||||
if w == 0 || h == 0 {
|
||||
poller.Close()
|
||||
return nil, nil, fmt.Errorf("vt framebuffer init failed (vt may not allow mmap on this driver)")
|
||||
}
|
||||
if inj, err := vncserver.NewUInputInjector(w, h); err == nil {
|
||||
return poller, inj, nil
|
||||
} else {
|
||||
log.Infof("VNC console: uinput unavailable (%v); view-only mode. Run `kldload uinput` to enable input.", err)
|
||||
return poller, &vncserver.StubInputInjector{}, nil
|
||||
}
|
||||
}
|
||||
30
client/internal/engine_vnc_console_linux.go
Normal file
30
client/internal/engine_vnc_console_linux.go
Normal file
@@ -0,0 +1,30 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
)
|
||||
|
||||
// newConsoleVNC builds a framebuffer + uinput VNC backend for boxes
|
||||
// without a running X server. Used as the auto-fallback when
|
||||
// newPlatformVNC can't reach X. Returns an error when /dev/fb0 or
|
||||
// /dev/uinput aren't usable so the caller can drop back to a stub.
|
||||
func newConsoleVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, error) {
|
||||
poller := vncserver.NewFBPoller("")
|
||||
w, h := poller.Width(), poller.Height()
|
||||
if w == 0 || h == 0 {
|
||||
poller.Close()
|
||||
return nil, nil, fmt.Errorf("framebuffer capturer init failed (is /dev/fb0 readable?)")
|
||||
}
|
||||
inj, err := vncserver.NewUInputInjector(w, h)
|
||||
if err != nil {
|
||||
log.Debugf("uinput unavailable, falling back to view-only VNC: %v", err)
|
||||
return poller, &vncserver.StubInputInjector{}, nil
|
||||
}
|
||||
return poller, inj, nil
|
||||
}
|
||||
34
client/internal/engine_vnc_darwin.go
Normal file
34
client/internal/engine_vnc_darwin.go
Normal file
@@ -0,0 +1,34 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
)
|
||||
|
||||
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, bool) {
|
||||
capturer := vncserver.NewMacPoller()
|
||||
// Prompt for Screen Recording at server-enable time rather than first
|
||||
// client-connect. The native prompt is far easier for users to act on
|
||||
// in the moment they toggled VNC on than later when "the screen looks
|
||||
// like wallpaper" would otherwise be the only clue.
|
||||
vncserver.PrimeScreenCapturePermission()
|
||||
injector, err := vncserver.NewMacInputInjector()
|
||||
if err != nil {
|
||||
log.Debugf("VNC: macOS input injector: %v", err)
|
||||
return capturer, &vncserver.StubInputInjector{}, true
|
||||
}
|
||||
return capturer, injector, true
|
||||
}
|
||||
|
||||
// vncNeedsServiceMode reports whether the running process is a system
|
||||
// LaunchDaemon (root, parented by launchd). Daemons sit in the global
|
||||
// bootstrap namespace and cannot talk to WindowServer; we route capture
|
||||
// through a per-user agent in that case.
|
||||
func vncNeedsServiceMode() bool {
|
||||
return os.Geteuid() == 0 && os.Getppid() == 1
|
||||
}
|
||||
17
client/internal/engine_vnc_stub.go
Normal file
17
client/internal/engine_vnc_stub.go
Normal file
@@ -0,0 +1,17 @@
|
||||
//go:build js || ios || android
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
type vncServer interface{}
|
||||
|
||||
func (e *Engine) updateVNC() error { return nil }
|
||||
|
||||
func (e *Engine) updateVNCServerAuth(_ *mgmProto.VNCAuth) {
|
||||
// no-op on platforms without a VNC server
|
||||
}
|
||||
|
||||
func (e *Engine) stopVNCServer() error { return nil }
|
||||
13
client/internal/engine_vnc_windows.go
Normal file
13
client/internal/engine_vnc_windows.go
Normal file
@@ -0,0 +1,13 @@
|
||||
//go:build windows
|
||||
|
||||
package internal
|
||||
|
||||
import vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
|
||||
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, bool) {
|
||||
return vncserver.NewDesktopCapturer(), vncserver.NewWindowsInputInjector(), true
|
||||
}
|
||||
|
||||
func vncNeedsServiceMode() bool {
|
||||
return vncserver.GetCurrentSessionID() == 0
|
||||
}
|
||||
35
client/internal/engine_vnc_x11.go
Normal file
35
client/internal/engine_vnc_x11.go
Normal file
@@ -0,0 +1,35 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
)
|
||||
|
||||
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, bool) {
|
||||
// Prefer X11 when an X server is reachable. NewX11InputInjector probes
|
||||
// DISPLAY (and /proc) eagerly, so a non-nil error here means no X.
|
||||
injector, err := vncserver.NewX11InputInjector("")
|
||||
if err == nil {
|
||||
return vncserver.NewX11Poller(""), injector, true
|
||||
}
|
||||
log.Debugf("VNC: X11 not available: %v", err)
|
||||
|
||||
// Fallback for headless / pre-X states (kernel console, login manager
|
||||
// without X, physical server in recovery): stream the framebuffer and
|
||||
// inject input via /dev/uinput.
|
||||
consoleCap, consoleInj, err := newConsoleVNC()
|
||||
if err == nil {
|
||||
log.Infof("VNC: using framebuffer console capture (%dx%d)", consoleCap.Width(), consoleCap.Height())
|
||||
return consoleCap, consoleInj, true
|
||||
}
|
||||
log.Debugf("VNC: framebuffer console fallback unavailable: %v", err)
|
||||
|
||||
return &vncserver.StubCapturer{}, &vncserver.StubInputInjector{}, false
|
||||
}
|
||||
|
||||
func vncNeedsServiceMode() bool {
|
||||
return false
|
||||
}
|
||||
@@ -120,6 +120,36 @@ func (m *influxDBMetrics) RecordSyncDuration(_ context.Context, agentInfo AgentI
|
||||
m.trimLocked()
|
||||
}
|
||||
|
||||
func (m *influxDBMetrics) RecordVNCSessionTick(_ context.Context, agentInfo AgentInfo, tick VNCSessionTick) {
|
||||
tags := fmt.Sprintf("deployment_type=%s,version=%s,os=%s,arch=%s,peer_id=%s",
|
||||
agentInfo.DeploymentType.String(),
|
||||
agentInfo.Version,
|
||||
agentInfo.OS,
|
||||
agentInfo.Arch,
|
||||
agentInfo.peerID,
|
||||
)
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.samples = append(m.samples, influxSample{
|
||||
measurement: "netbird_vnc_traffic",
|
||||
tags: tags,
|
||||
fields: map[string]float64{
|
||||
"period_seconds": tick.Period.Seconds(),
|
||||
"bytes_out": float64(tick.BytesOut),
|
||||
"writes": float64(tick.Writes),
|
||||
"fbus": float64(tick.FBUs),
|
||||
"max_fbu_bytes": float64(tick.MaxFBUBytes),
|
||||
"max_fbu_rects": float64(tick.MaxFBURects),
|
||||
"max_write_bytes": float64(tick.MaxWriteBytes),
|
||||
"write_time_seconds": float64(tick.WriteNanos) / 1e9,
|
||||
},
|
||||
timestamp: time.Now(),
|
||||
})
|
||||
m.trimLocked()
|
||||
}
|
||||
|
||||
func (m *influxDBMetrics) RecordLoginDuration(_ context.Context, agentInfo AgentInfo, duration time.Duration, success bool) {
|
||||
result := "success"
|
||||
if !success {
|
||||
|
||||
@@ -59,6 +59,11 @@ type metricsImplementation interface {
|
||||
// RecordLoginDuration records how long the login to management took
|
||||
RecordLoginDuration(ctx context.Context, agentInfo AgentInfo, duration time.Duration, success bool)
|
||||
|
||||
// RecordVNCSessionTick records a periodic snapshot of one VNC
|
||||
// session's wire activity. Called once per metricsConn tick interval
|
||||
// (and once at session close), only when the tick saw activity.
|
||||
RecordVNCSessionTick(ctx context.Context, agentInfo AgentInfo, tick VNCSessionTick)
|
||||
|
||||
// Export exports metrics in InfluxDB line protocol format
|
||||
Export(w io.Writer) error
|
||||
|
||||
@@ -78,6 +83,21 @@ type ClientMetrics struct {
|
||||
pushCancel context.CancelFunc
|
||||
}
|
||||
|
||||
// VNCSessionTick is one sampling slice of a VNC session's wire activity.
|
||||
// BytesOut / Writes / FBUs / WriteNanos are deltas observed during this
|
||||
// tick; Max* fields are the high-water marks observed during the tick.
|
||||
// Period is the wall-clock duration the deltas cover.
|
||||
type VNCSessionTick struct {
|
||||
Period time.Duration
|
||||
BytesOut uint64
|
||||
Writes uint64
|
||||
FBUs uint64
|
||||
MaxFBUBytes uint64
|
||||
MaxFBURects uint64
|
||||
MaxWriteBytes uint64
|
||||
WriteNanos uint64
|
||||
}
|
||||
|
||||
// ConnectionStageTimestamps holds timestamps for each connection stage
|
||||
type ConnectionStageTimestamps struct {
|
||||
SignalingReceived time.Time // First signal received from remote peer (both initial and reconnection)
|
||||
@@ -127,6 +147,17 @@ func (c *ClientMetrics) RecordSyncDuration(ctx context.Context, duration time.Du
|
||||
c.impl.RecordSyncDuration(ctx, agentInfo, duration)
|
||||
}
|
||||
|
||||
// RecordVNCSessionTick records a periodic snapshot of one VNC session.
|
||||
func (c *ClientMetrics) RecordVNCSessionTick(ctx context.Context, tick VNCSessionTick) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.mu.RLock()
|
||||
agentInfo := c.agentInfo
|
||||
c.mu.RUnlock()
|
||||
c.impl.RecordVNCSessionTick(ctx, agentInfo, tick)
|
||||
}
|
||||
|
||||
// RecordLoginDuration records how long the login to management server took
|
||||
func (c *ClientMetrics) RecordLoginDuration(ctx context.Context, duration time.Duration, success bool) {
|
||||
if c == nil {
|
||||
|
||||
@@ -73,6 +73,9 @@ func (m *mockMetrics) RecordSyncDuration(_ context.Context, _ AgentInfo, _ time.
|
||||
func (m *mockMetrics) RecordLoginDuration(_ context.Context, _ AgentInfo, _ time.Duration, _ bool) {
|
||||
}
|
||||
|
||||
func (m *mockMetrics) RecordVNCSessionTick(_ context.Context, _ AgentInfo, _ VNCSessionTick) {
|
||||
}
|
||||
|
||||
func (m *mockMetrics) Export(w io.Writer) error {
|
||||
if m.exportData != "" {
|
||||
_, err := w.Write([]byte(m.exportData))
|
||||
|
||||
@@ -65,6 +65,7 @@ type ConfigInput struct {
|
||||
StateFilePath string
|
||||
PreSharedKey *string
|
||||
ServerSSHAllowed *bool
|
||||
ServerVNCAllowed *bool
|
||||
EnableSSHRoot *bool
|
||||
EnableSSHSFTP *bool
|
||||
EnableSSHLocalPortForwarding *bool
|
||||
@@ -116,6 +117,7 @@ type Config struct {
|
||||
RosenpassEnabled bool
|
||||
RosenpassPermissive bool
|
||||
ServerSSHAllowed *bool
|
||||
ServerVNCAllowed *bool
|
||||
EnableSSHRoot *bool
|
||||
EnableSSHSFTP *bool
|
||||
EnableSSHLocalPortForwarding *bool
|
||||
@@ -418,6 +420,21 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.ServerVNCAllowed != nil {
|
||||
if config.ServerVNCAllowed == nil || *input.ServerVNCAllowed != *config.ServerVNCAllowed {
|
||||
if *input.ServerVNCAllowed {
|
||||
log.Infof("enabling VNC server")
|
||||
} else {
|
||||
log.Infof("disabling VNC server")
|
||||
}
|
||||
config.ServerVNCAllowed = input.ServerVNCAllowed
|
||||
updated = true
|
||||
}
|
||||
} else if config.ServerVNCAllowed == nil {
|
||||
config.ServerVNCAllowed = util.True()
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.EnableSSHRoot != nil && input.EnableSSHRoot != config.EnableSSHRoot {
|
||||
if *input.EnableSSHRoot {
|
||||
log.Infof("enabling SSH root login")
|
||||
|
||||
@@ -74,6 +74,14 @@ func New(filePath string) *Manager {
|
||||
}
|
||||
}
|
||||
|
||||
// FilePath returns the path of the underlying state file.
|
||||
func (m *Manager) FilePath() string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
return m.filePath
|
||||
}
|
||||
|
||||
// Start starts the state manager periodic save routine
|
||||
func (m *Manager) Start() {
|
||||
if m == nil {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -205,6 +205,8 @@ message LoginRequest {
|
||||
optional bool disableSSHAuth = 38;
|
||||
optional int32 sshJWTCacheTTL = 39;
|
||||
optional bool disable_ipv6 = 40;
|
||||
|
||||
optional bool serverVNCAllowed = 41;
|
||||
}
|
||||
|
||||
message LoginResponse {
|
||||
@@ -314,6 +316,8 @@ message GetConfigResponse {
|
||||
int32 sshJWTCacheTTL = 26;
|
||||
|
||||
bool disable_ipv6 = 27;
|
||||
|
||||
bool serverVNCAllowed = 28;
|
||||
}
|
||||
|
||||
// PeerState contains the latest state of a peer
|
||||
@@ -394,6 +398,22 @@ message SSHServerState {
|
||||
repeated SSHSessionInfo sessions = 2;
|
||||
}
|
||||
|
||||
// VNCSessionInfo contains information about an active VNC session
|
||||
message VNCSessionInfo {
|
||||
string remoteAddress = 1;
|
||||
string mode = 2;
|
||||
string username = 3;
|
||||
// userID is the Noise-verified session identity (hashed user ID from
|
||||
// the ACL session-key entry), empty when auth is disabled.
|
||||
string userID = 4;
|
||||
}
|
||||
|
||||
// VNCServerState contains the latest state of the VNC server
|
||||
message VNCServerState {
|
||||
bool enabled = 1;
|
||||
repeated VNCSessionInfo sessions = 2;
|
||||
}
|
||||
|
||||
// FullStatus contains the full state held by the Status instance
|
||||
message FullStatus {
|
||||
ManagementState managementState = 1;
|
||||
@@ -408,6 +428,7 @@ message FullStatus {
|
||||
|
||||
bool lazyConnectionEnabled = 9;
|
||||
SSHServerState sshServerState = 10;
|
||||
VNCServerState vncServerState = 11;
|
||||
}
|
||||
|
||||
// Networks
|
||||
@@ -678,6 +699,8 @@ message SetConfigRequest {
|
||||
optional bool disableSSHAuth = 33;
|
||||
optional int32 sshJWTCacheTTL = 34;
|
||||
optional bool disable_ipv6 = 35;
|
||||
|
||||
optional bool serverVNCAllowed = 36;
|
||||
}
|
||||
|
||||
message SetConfigResponse{}
|
||||
|
||||
@@ -9,21 +9,9 @@ then
|
||||
fi
|
||||
|
||||
old_pwd=$(pwd)
|
||||
script_path=$(dirname "$(realpath "$0")")
|
||||
script_path=$(dirname $(realpath "$0"))
|
||||
cd "$script_path"
|
||||
|
||||
repo_root=$(git rev-parse --show-toplevel)
|
||||
# shellcheck source=/dev/null
|
||||
. "$repo_root/proto-tools.env"
|
||||
|
||||
actual_protoc=$(protoc --version | awk '{print $2}')
|
||||
if [[ "$actual_protoc" != "$PROTOC_VERSION" ]]; then
|
||||
echo "ERROR: protoc version $actual_protoc differs from pinned $PROTOC_VERSION" >&2
|
||||
echo "Install protoc $PROTOC_VERSION from https://github.com/protocolbuffers/protobuf/releases" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
go install "google.golang.org/protobuf/cmd/protoc-gen-go@${PROTOC_GEN_GO_VERSION}"
|
||||
go install "google.golang.org/grpc/cmd/protoc-gen-go-grpc@${PROTOC_GEN_GO_GRPC_VERSION}"
|
||||
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.36.6
|
||||
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.1
|
||||
protoc -I ./ ./daemon.proto --go_out=../ --go-grpc_out=../ --experimental_allow_proto3_optional
|
||||
cd "$old_pwd"
|
||||
|
||||
@@ -376,6 +376,7 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
|
||||
config.RosenpassPermissive = msg.RosenpassPermissive
|
||||
config.DisableAutoConnect = msg.DisableAutoConnect
|
||||
config.ServerSSHAllowed = msg.ServerSSHAllowed
|
||||
config.ServerVNCAllowed = msg.ServerVNCAllowed
|
||||
config.NetworkMonitor = msg.NetworkMonitor
|
||||
config.DisableClientRoutes = msg.DisableClientRoutes
|
||||
config.DisableServerRoutes = msg.DisableServerRoutes
|
||||
@@ -1136,6 +1137,7 @@ func (s *Server) Status(
|
||||
pbFullStatus := fullStatus.ToProto()
|
||||
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
|
||||
pbFullStatus.SshServerState = s.getSSHServerState()
|
||||
pbFullStatus.VncServerState = s.getVNCServerState()
|
||||
statusResponse.FullStatus = pbFullStatus
|
||||
}
|
||||
|
||||
@@ -1175,6 +1177,37 @@ func (s *Server) getSSHServerState() *proto.SSHServerState {
|
||||
return sshServerState
|
||||
}
|
||||
|
||||
// getVNCServerState retrieves the current VNC server state.
|
||||
func (s *Server) getVNCServerState() *proto.VNCServerState {
|
||||
s.mutex.Lock()
|
||||
connectClient := s.connectClient
|
||||
s.mutex.Unlock()
|
||||
|
||||
if connectClient == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
engine := connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
enabled, sessions := engine.GetVNCServerStatus()
|
||||
pbSessions := make([]*proto.VNCSessionInfo, 0, len(sessions))
|
||||
for _, sess := range sessions {
|
||||
pbSessions = append(pbSessions, &proto.VNCSessionInfo{
|
||||
RemoteAddress: sess.RemoteAddress,
|
||||
Mode: sess.Mode,
|
||||
Username: sess.Username,
|
||||
UserID: sess.UserID,
|
||||
})
|
||||
}
|
||||
return &proto.VNCServerState{
|
||||
Enabled: enabled,
|
||||
Sessions: pbSessions,
|
||||
}
|
||||
}
|
||||
|
||||
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
|
||||
func (s *Server) GetPeerSSHHostKey(
|
||||
ctx context.Context,
|
||||
@@ -1531,6 +1564,7 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
|
||||
Mtu: int64(cfg.MTU),
|
||||
DisableAutoConnect: cfg.DisableAutoConnect,
|
||||
ServerSSHAllowed: *cfg.ServerSSHAllowed,
|
||||
ServerVNCAllowed: cfg.ServerVNCAllowed != nil && *cfg.ServerVNCAllowed,
|
||||
RosenpassEnabled: cfg.RosenpassEnabled,
|
||||
RosenpassPermissive: cfg.RosenpassPermissive,
|
||||
LazyConnectionEnabled: cfg.LazyConnectionEnabled,
|
||||
|
||||
@@ -58,6 +58,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
rosenpassEnabled := true
|
||||
rosenpassPermissive := true
|
||||
serverSSHAllowed := true
|
||||
serverVNCAllowed := true
|
||||
interfaceName := "utun100"
|
||||
wireguardPort := int64(51820)
|
||||
preSharedKey := "test-psk"
|
||||
@@ -83,6 +84,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
RosenpassEnabled: &rosenpassEnabled,
|
||||
RosenpassPermissive: &rosenpassPermissive,
|
||||
ServerSSHAllowed: &serverSSHAllowed,
|
||||
ServerVNCAllowed: &serverVNCAllowed,
|
||||
InterfaceName: &interfaceName,
|
||||
WireguardPort: &wireguardPort,
|
||||
OptionalPreSharedKey: &preSharedKey,
|
||||
@@ -127,6 +129,8 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
require.Equal(t, rosenpassPermissive, cfg.RosenpassPermissive)
|
||||
require.NotNil(t, cfg.ServerSSHAllowed)
|
||||
require.Equal(t, serverSSHAllowed, *cfg.ServerSSHAllowed)
|
||||
require.NotNil(t, cfg.ServerVNCAllowed)
|
||||
require.Equal(t, serverVNCAllowed, *cfg.ServerVNCAllowed)
|
||||
require.Equal(t, interfaceName, cfg.WgIface)
|
||||
require.Equal(t, int(wireguardPort), cfg.WgPort)
|
||||
require.Equal(t, preSharedKey, cfg.PreSharedKey)
|
||||
@@ -179,6 +183,7 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
|
||||
"RosenpassEnabled": true,
|
||||
"RosenpassPermissive": true,
|
||||
"ServerSSHAllowed": true,
|
||||
"ServerVNCAllowed": true,
|
||||
"InterfaceName": true,
|
||||
"WireguardPort": true,
|
||||
"OptionalPreSharedKey": true,
|
||||
@@ -240,6 +245,7 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) {
|
||||
"enable-rosenpass": "RosenpassEnabled",
|
||||
"rosenpass-permissive": "RosenpassPermissive",
|
||||
"allow-server-ssh": "ServerSSHAllowed",
|
||||
"allow-server-vnc": "ServerVNCAllowed",
|
||||
"interface-name": "InterfaceName",
|
||||
"wireguard-port": "WireguardPort",
|
||||
"preshared-key": "OptionalPreSharedKey",
|
||||
|
||||
@@ -15,13 +15,16 @@ const (
|
||||
DefaultUserIDClaim = "sub"
|
||||
// Wildcard is a special user ID that matches all users
|
||||
Wildcard = "*"
|
||||
// sessionPubKeyLen is the size of an X25519 static public key in bytes.
|
||||
sessionPubKeyLen = 32
|
||||
)
|
||||
|
||||
var (
|
||||
ErrEmptyUserID = errors.New("JWT user ID is empty")
|
||||
ErrUserNotAuthorized = errors.New("user is not authorized to access this peer")
|
||||
ErrNoMachineUserMapping = errors.New("no authorization mapping for OS user")
|
||||
ErrUserNotMappedToOSUser = errors.New("user is not authorized to login as OS user")
|
||||
ErrEmptyUserID = errors.New("JWT user ID is empty")
|
||||
ErrUserNotAuthorized = errors.New("user is not authorized to access this peer")
|
||||
ErrNoMachineUserMapping = errors.New("no authorization mapping for OS user")
|
||||
ErrUserNotMappedToOSUser = errors.New("user is not authorized to login as OS user")
|
||||
ErrSessionKeyNotKnown = errors.New("session pubkey not registered")
|
||||
)
|
||||
|
||||
// Authorizer handles SSH fine-grained access control authorization
|
||||
@@ -35,6 +38,12 @@ type Authorizer struct {
|
||||
// machineUsers maps OS login usernames to lists of authorized user indexes
|
||||
machineUsers map[string][]uint32
|
||||
|
||||
// sessionPubKeys maps an X25519 static public key (as map-safe
|
||||
// array) to the hashed user identity that key authenticates as.
|
||||
// Populated from management's temporary-access flow; used by VNC to
|
||||
// authenticate via the Noise_IK handshake.
|
||||
sessionPubKeys map[[sessionPubKeyLen]byte]sshuserhash.UserIDHash
|
||||
|
||||
// mu protects the list of users
|
||||
mu sync.RWMutex
|
||||
}
|
||||
@@ -50,13 +59,25 @@ type Config struct {
|
||||
// MachineUsers maps OS login usernames to indexes in AuthorizedUsers
|
||||
// If a user wants to login as a specific OS user, their index must be in the corresponding list
|
||||
MachineUsers map[string][]uint32
|
||||
|
||||
// SessionPubKeys binds ephemeral X25519 static public keys to hashed
|
||||
// user identities. Populated for VNC; ignored on the SSH side.
|
||||
SessionPubKeys []SessionPubKey
|
||||
}
|
||||
|
||||
// SessionPubKey is a single ephemeral-key entry: the 32-byte X25519
|
||||
// static public key plus the hashed user identity it authenticates as.
|
||||
type SessionPubKey struct {
|
||||
PubKey []byte
|
||||
UserIDHash sshuserhash.UserIDHash
|
||||
}
|
||||
|
||||
// NewAuthorizer creates a new SSH authorizer with empty configuration
|
||||
func NewAuthorizer() *Authorizer {
|
||||
a := &Authorizer{
|
||||
userIDClaim: DefaultUserIDClaim,
|
||||
machineUsers: make(map[string][]uint32),
|
||||
userIDClaim: DefaultUserIDClaim,
|
||||
machineUsers: make(map[string][]uint32),
|
||||
sessionPubKeys: make(map[[sessionPubKeyLen]byte]sshuserhash.UserIDHash),
|
||||
}
|
||||
|
||||
return a
|
||||
@@ -72,6 +93,7 @@ func (a *Authorizer) Update(config *Config) {
|
||||
a.userIDClaim = DefaultUserIDClaim
|
||||
a.authorizedUsers = []sshuserhash.UserIDHash{}
|
||||
a.machineUsers = make(map[string][]uint32)
|
||||
a.sessionPubKeys = make(map[[sessionPubKeyLen]byte]sshuserhash.UserIDHash)
|
||||
log.Info("SSH authorization cleared")
|
||||
return
|
||||
}
|
||||
@@ -94,8 +116,29 @@ func (a *Authorizer) Update(config *Config) {
|
||||
}
|
||||
a.machineUsers = machineUsers
|
||||
|
||||
log.Debugf("SSH auth: updated with %d authorized users, %d machine user mappings",
|
||||
len(config.AuthorizedUsers), len(machineUsers))
|
||||
sessionPubKeys := make(map[[sessionPubKeyLen]byte]sshuserhash.UserIDHash, len(config.SessionPubKeys))
|
||||
conflicted := make(map[[sessionPubKeyLen]byte]struct{})
|
||||
for _, e := range config.SessionPubKeys {
|
||||
if len(e.PubKey) != sessionPubKeyLen {
|
||||
continue
|
||||
}
|
||||
var key [sessionPubKeyLen]byte
|
||||
copy(key[:], e.PubKey)
|
||||
if _, bad := conflicted[key]; bad {
|
||||
continue
|
||||
}
|
||||
if existing, ok := sessionPubKeys[key]; ok && existing != e.UserIDHash {
|
||||
log.Warnf("SSH auth: session pubkey bound to conflicting user hashes; dropping binding")
|
||||
delete(sessionPubKeys, key)
|
||||
conflicted[key] = struct{}{}
|
||||
continue
|
||||
}
|
||||
sessionPubKeys[key] = e.UserIDHash
|
||||
}
|
||||
a.sessionPubKeys = sessionPubKeys
|
||||
|
||||
log.Debugf("SSH auth: updated with %d authorized users, %d machine user mappings, %d session pubkeys",
|
||||
len(config.AuthorizedUsers), len(machineUsers), len(sessionPubKeys))
|
||||
}
|
||||
|
||||
// Authorize validates if a user is authorized to login as the specified OS user.
|
||||
@@ -155,6 +198,38 @@ func (a *Authorizer) GetUserIDClaim() string {
|
||||
return a.userIDClaim
|
||||
}
|
||||
|
||||
// LookupSessionKey resolves a Noise-verified static public key to the
|
||||
// hashed user identity registered with it. Fails closed when the key is
|
||||
// unknown.
|
||||
func (a *Authorizer) LookupSessionKey(pubKey []byte) (sshuserhash.UserIDHash, error) {
|
||||
var zero sshuserhash.UserIDHash
|
||||
if len(pubKey) != sessionPubKeyLen {
|
||||
return zero, fmt.Errorf("session pubkey wrong length: %d", len(pubKey))
|
||||
}
|
||||
var key [sessionPubKeyLen]byte
|
||||
copy(key[:], pubKey)
|
||||
a.mu.RLock()
|
||||
hash, ok := a.sessionPubKeys[key]
|
||||
a.mu.RUnlock()
|
||||
if !ok {
|
||||
return zero, ErrSessionKeyNotKnown
|
||||
}
|
||||
return hash, nil
|
||||
}
|
||||
|
||||
// AuthorizeOSUserBySessionKey resolves the OS-user mapping for a session
|
||||
// key. Mirrors Authorize but skips the JWT-hash step since the key has
|
||||
// already been verified and the user identity hash is in hand.
|
||||
func (a *Authorizer) AuthorizeOSUserBySessionKey(userIDHash sshuserhash.UserIDHash, osUsername string) (string, error) {
|
||||
a.mu.RLock()
|
||||
defer a.mu.RUnlock()
|
||||
userIndex, found := a.findUserIndex(userIDHash)
|
||||
if !found {
|
||||
return "", fmt.Errorf("session user (hash: %s) not in authorized list for OS user %q: %w", userIDHash, osUsername, ErrUserNotAuthorized)
|
||||
}
|
||||
return a.checkMachineUserMapping("session", osUsername, userIndex)
|
||||
}
|
||||
|
||||
// findUserIndex finds the index of a hashed user ID in the authorized users list
|
||||
// Returns the index and true if found, 0 and false if not found
|
||||
func (a *Authorizer) findUserIndex(hashedUserID sshuserhash.UserIDHash) (int, bool) {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -610,3 +611,61 @@ func TestAuthorizer_Wildcard_WithPartialIndexes_AllowsAllUsers(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrUserNotAuthorized, "unauthorized user should be denied")
|
||||
}
|
||||
|
||||
func TestAuthorizer_LookupSessionKey_Valid(t *testing.T) {
|
||||
pub := bytesRepeat(0x11, sessionPubKeyLen)
|
||||
userHash, err := sshauth.HashUserID("alice")
|
||||
require.NoError(t, err)
|
||||
|
||||
a := NewAuthorizer()
|
||||
a.Update(&Config{
|
||||
AuthorizedUsers: []sshauth.UserIDHash{userHash},
|
||||
MachineUsers: map[string][]uint32{Wildcard: {0}},
|
||||
SessionPubKeys: []SessionPubKey{{PubKey: pub, UserIDHash: userHash}},
|
||||
})
|
||||
|
||||
got, err := a.LookupSessionKey(pub)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, userHash, got)
|
||||
|
||||
if _, err := a.AuthorizeOSUserBySessionKey(got, "alice"); err != nil {
|
||||
t.Fatalf("AuthorizeOSUserBySessionKey: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorizer_LookupSessionKey_UnknownPub(t *testing.T) {
|
||||
a := NewAuthorizer()
|
||||
a.Update(&Config{})
|
||||
_, err := a.LookupSessionKey(bytesRepeat(0x22, sessionPubKeyLen))
|
||||
require.ErrorIs(t, err, ErrSessionKeyNotKnown)
|
||||
}
|
||||
|
||||
func TestAuthorizer_LookupSessionKey_WrongLength(t *testing.T) {
|
||||
a := NewAuthorizer()
|
||||
_, err := a.LookupSessionKey([]byte("short"))
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestAuthorizer_LookupSessionKey_UpdateClears(t *testing.T) {
|
||||
pub := bytesRepeat(0x33, sessionPubKeyLen)
|
||||
userHash, err := sshauth.HashUserID("alice")
|
||||
require.NoError(t, err)
|
||||
|
||||
a := NewAuthorizer()
|
||||
a.Update(&Config{SessionPubKeys: []SessionPubKey{{PubKey: pub, UserIDHash: userHash}}})
|
||||
if _, err := a.LookupSessionKey(pub); err != nil {
|
||||
t.Fatalf("setup lookup: %v", err)
|
||||
}
|
||||
a.Update(&Config{})
|
||||
if _, err := a.LookupSessionKey(pub); !errors.Is(err, ErrSessionKeyNotKnown) {
|
||||
t.Fatalf("expected ErrSessionKeyNotKnown, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func bytesRepeat(b byte, n int) []byte {
|
||||
out := make([]byte, n)
|
||||
for i := range out {
|
||||
out[i] = b
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -131,6 +131,18 @@ type SSHServerStateOutput struct {
|
||||
Sessions []SSHSessionOutput `json:"sessions" yaml:"sessions"`
|
||||
}
|
||||
|
||||
type VNCSessionOutput struct {
|
||||
RemoteAddress string `json:"remoteAddress" yaml:"remoteAddress"`
|
||||
Mode string `json:"mode" yaml:"mode"`
|
||||
Username string `json:"username,omitempty" yaml:"username,omitempty"`
|
||||
UserID string `json:"userID,omitempty" yaml:"userID,omitempty"`
|
||||
}
|
||||
|
||||
type VNCServerStateOutput struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
Sessions []VNCSessionOutput `json:"sessions" yaml:"sessions"`
|
||||
}
|
||||
|
||||
type OutputOverview struct {
|
||||
Peers PeersStateOutput `json:"peers" yaml:"peers"`
|
||||
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
|
||||
@@ -153,6 +165,7 @@ type OutputOverview struct {
|
||||
LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"`
|
||||
ProfileName string `json:"profileName" yaml:"profileName"`
|
||||
SSHServerState SSHServerStateOutput `json:"sshServer" yaml:"sshServer"`
|
||||
VNCServerState VNCServerStateOutput `json:"vncServer" yaml:"vncServer"`
|
||||
}
|
||||
|
||||
// ConvertToStatusOutputOverview converts protobuf status to the output overview.
|
||||
@@ -173,6 +186,7 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, opts ConvertO
|
||||
|
||||
relayOverview := mapRelays(pbFullStatus.GetRelays())
|
||||
sshServerOverview := mapSSHServer(pbFullStatus.GetSshServerState())
|
||||
vncServerOverview := mapVNCServer(pbFullStatus.GetVncServerState())
|
||||
peersOverview := mapPeers(pbFullStatus.GetPeers(), opts.StatusFilter, opts.PrefixNamesFilter, opts.PrefixNamesFilterMap, opts.IPsFilter, opts.ConnectionTypeFilter)
|
||||
|
||||
overview := OutputOverview{
|
||||
@@ -197,6 +211,7 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, opts ConvertO
|
||||
LazyConnectionEnabled: pbFullStatus.GetLazyConnectionEnabled(),
|
||||
ProfileName: opts.ProfileName,
|
||||
SSHServerState: sshServerOverview,
|
||||
VNCServerState: vncServerOverview,
|
||||
}
|
||||
|
||||
if opts.Anonymize {
|
||||
@@ -271,6 +286,25 @@ func mapSSHServer(sshServerState *proto.SSHServerState) SSHServerStateOutput {
|
||||
}
|
||||
}
|
||||
|
||||
func mapVNCServer(state *proto.VNCServerState) VNCServerStateOutput {
|
||||
if state == nil {
|
||||
return VNCServerStateOutput{Sessions: []VNCSessionOutput{}}
|
||||
}
|
||||
sessions := make([]VNCSessionOutput, 0, len(state.GetSessions()))
|
||||
for _, sess := range state.GetSessions() {
|
||||
sessions = append(sessions, VNCSessionOutput{
|
||||
RemoteAddress: sess.GetRemoteAddress(),
|
||||
Mode: sess.GetMode(),
|
||||
Username: sess.GetUsername(),
|
||||
UserID: sess.GetUserID(),
|
||||
})
|
||||
}
|
||||
return VNCServerStateOutput{
|
||||
Enabled: state.GetEnabled(),
|
||||
Sessions: sessions,
|
||||
}
|
||||
}
|
||||
|
||||
func mapPeers(
|
||||
peers []*proto.PeerState,
|
||||
statusFilter string,
|
||||
@@ -533,6 +567,34 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
||||
}
|
||||
}
|
||||
|
||||
vncServerStatus := "Disabled"
|
||||
if o.VNCServerState.Enabled {
|
||||
vncSessionCount := len(o.VNCServerState.Sessions)
|
||||
if vncSessionCount > 0 {
|
||||
sessionWord := "session"
|
||||
if vncSessionCount > 1 {
|
||||
sessionWord = "sessions"
|
||||
}
|
||||
vncServerStatus = fmt.Sprintf("Enabled (%d active %s)", vncSessionCount, sessionWord)
|
||||
} else {
|
||||
vncServerStatus = "Enabled"
|
||||
}
|
||||
|
||||
if showSSHSessions && vncSessionCount > 0 {
|
||||
for _, sess := range o.VNCServerState.Sessions {
|
||||
var line string
|
||||
if sess.UserID != "" {
|
||||
line = fmt.Sprintf("[%s@%s -> %s] mode=%s",
|
||||
sess.UserID, sess.RemoteAddress, sess.Username, sess.Mode)
|
||||
} else {
|
||||
line = fmt.Sprintf("[%s] mode=%s user=%s",
|
||||
sess.RemoteAddress, sess.Mode, sess.Username)
|
||||
}
|
||||
vncServerStatus += "\n " + line
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
peersCountString := fmt.Sprintf("%d/%d Connected", o.Peers.Connected, o.Peers.Total)
|
||||
|
||||
var forwardingRulesString string
|
||||
@@ -563,6 +625,7 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
||||
"Quantum resistance: %s\n"+
|
||||
"Lazy connection: %s\n"+
|
||||
"SSH Server: %s\n"+
|
||||
"VNC Server: %s\n"+
|
||||
"Networks: %s\n"+
|
||||
"%s"+
|
||||
"Peers count: %s\n",
|
||||
@@ -581,6 +644,7 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
||||
rosenpassEnabledStatus,
|
||||
lazyConnectionEnabledStatus,
|
||||
sshServerStatus,
|
||||
vncServerStatus,
|
||||
networks,
|
||||
forwardingRulesString,
|
||||
peersCountString,
|
||||
@@ -960,6 +1024,19 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *OutputOverview) {
|
||||
overview.Relays.Details[i] = detail
|
||||
}
|
||||
|
||||
anonymizeNSServerGroups(a, overview)
|
||||
|
||||
for i, route := range overview.Networks {
|
||||
overview.Networks[i] = a.AnonymizeRoute(route)
|
||||
}
|
||||
|
||||
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
|
||||
|
||||
anonymizeEvents(a, overview)
|
||||
anonymizeServerSessions(a, overview)
|
||||
}
|
||||
|
||||
func anonymizeNSServerGroups(a *anonymize.Anonymizer, overview *OutputOverview) {
|
||||
for i, nsGroup := range overview.NSServerGroups {
|
||||
for j, domain := range nsGroup.Domains {
|
||||
overview.NSServerGroups[i].Domains[j] = a.AnonymizeDomain(domain)
|
||||
@@ -971,13 +1048,9 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *OutputOverview) {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for i, route := range overview.Networks {
|
||||
overview.Networks[i] = a.AnonymizeRoute(route)
|
||||
}
|
||||
|
||||
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
|
||||
|
||||
func anonymizeEvents(a *anonymize.Anonymizer, overview *OutputOverview) {
|
||||
for i, event := range overview.Events {
|
||||
overview.Events[i].Message = a.AnonymizeString(event.Message)
|
||||
overview.Events[i].UserMessage = a.AnonymizeString(event.UserMessage)
|
||||
@@ -986,13 +1059,23 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *OutputOverview) {
|
||||
event.Metadata[k] = a.AnonymizeString(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func anonymizeRemoteAddress(a *anonymize.Anonymizer, addr string) string {
|
||||
if host, port, err := net.SplitHostPort(addr); err == nil {
|
||||
return fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port)
|
||||
}
|
||||
return a.AnonymizeIPString(addr)
|
||||
}
|
||||
|
||||
func anonymizeServerSessions(a *anonymize.Anonymizer, overview *OutputOverview) {
|
||||
for i, session := range overview.SSHServerState.Sessions {
|
||||
if host, port, err := net.SplitHostPort(session.RemoteAddress); err == nil {
|
||||
overview.SSHServerState.Sessions[i].RemoteAddress = fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port)
|
||||
} else {
|
||||
overview.SSHServerState.Sessions[i].RemoteAddress = a.AnonymizeIPString(session.RemoteAddress)
|
||||
}
|
||||
overview.SSHServerState.Sessions[i].RemoteAddress = anonymizeRemoteAddress(a, session.RemoteAddress)
|
||||
overview.SSHServerState.Sessions[i].Command = a.AnonymizeString(session.Command)
|
||||
}
|
||||
for i, sess := range overview.VNCServerState.Sessions {
|
||||
overview.VNCServerState.Sessions[i].RemoteAddress = anonymizeRemoteAddress(a, sess.RemoteAddress)
|
||||
overview.VNCServerState.Sessions[i].Username = a.AnonymizeString(sess.Username)
|
||||
overview.VNCServerState.Sessions[i].UserID = a.AnonymizeString(sess.UserID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -240,6 +240,10 @@ var overview = OutputOverview{
|
||||
Enabled: false,
|
||||
Sessions: []SSHSessionOutput{},
|
||||
},
|
||||
VNCServerState: VNCServerStateOutput{
|
||||
Enabled: false,
|
||||
Sessions: []VNCSessionOutput{},
|
||||
},
|
||||
}
|
||||
|
||||
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
|
||||
@@ -404,6 +408,10 @@ func TestParsingToJSON(t *testing.T) {
|
||||
"sshServer":{
|
||||
"enabled":false,
|
||||
"sessions":[]
|
||||
},
|
||||
"vncServer":{
|
||||
"enabled":false,
|
||||
"sessions":[]
|
||||
}
|
||||
}`
|
||||
// @formatter:on
|
||||
@@ -513,6 +521,9 @@ profileName: ""
|
||||
sshServer:
|
||||
enabled: false
|
||||
sessions: []
|
||||
vncServer:
|
||||
enabled: false
|
||||
sessions: []
|
||||
`
|
||||
|
||||
assert.Equal(t, expectedYAML, yaml)
|
||||
@@ -582,6 +593,7 @@ Interface type: Kernel
|
||||
Quantum resistance: false
|
||||
Lazy connection: false
|
||||
SSH Server: Disabled
|
||||
VNC Server: Disabled
|
||||
Networks: 10.10.0.0/24
|
||||
Peers count: 2/2 Connected
|
||||
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
|
||||
@@ -607,6 +619,7 @@ Interface type: Kernel
|
||||
Quantum resistance: false
|
||||
Lazy connection: false
|
||||
SSH Server: Disabled
|
||||
VNC Server: Disabled
|
||||
Networks: 10.10.0.0/24
|
||||
Peers count: 2/2 Connected
|
||||
`
|
||||
|
||||
@@ -62,6 +62,7 @@ type Info struct {
|
||||
RosenpassEnabled bool
|
||||
RosenpassPermissive bool
|
||||
ServerSSHAllowed bool
|
||||
ServerVNCAllowed bool
|
||||
|
||||
DisableClientRoutes bool
|
||||
DisableServerRoutes bool
|
||||
@@ -83,6 +84,7 @@ type Info struct {
|
||||
func (i *Info) SetFlags(
|
||||
rosenpassEnabled, rosenpassPermissive bool,
|
||||
serverSSHAllowed *bool,
|
||||
serverVNCAllowed *bool,
|
||||
disableClientRoutes, disableServerRoutes,
|
||||
disableDNS, disableFirewall, blockLANAccess, blockInbound, disableIPv6, lazyConnectionEnabled bool,
|
||||
enableSSHRoot, enableSSHSFTP, enableSSHLocalPortForwarding, enableSSHRemotePortForwarding *bool,
|
||||
@@ -93,6 +95,9 @@ func (i *Info) SetFlags(
|
||||
if serverSSHAllowed != nil {
|
||||
i.ServerSSHAllowed = *serverSSHAllowed
|
||||
}
|
||||
if serverVNCAllowed != nil {
|
||||
i.ServerVNCAllowed = *serverVNCAllowed
|
||||
}
|
||||
|
||||
i.DisableClientRoutes = disableClientRoutes
|
||||
i.DisableServerRoutes = disableServerRoutes
|
||||
|
||||
@@ -249,6 +249,7 @@ type serviceClient struct {
|
||||
mQuit *systray.MenuItem
|
||||
mNetworks *systray.MenuItem
|
||||
mAllowSSH *systray.MenuItem
|
||||
mAllowVNC *systray.MenuItem
|
||||
mAutoConnect *systray.MenuItem
|
||||
mEnableRosenpass *systray.MenuItem
|
||||
mLazyConnEnabled *systray.MenuItem
|
||||
@@ -1045,6 +1046,7 @@ func (s *serviceClient) onTrayReady() {
|
||||
|
||||
s.mSettings = systray.AddMenuItem("Settings", disabledMenuDescr)
|
||||
s.mAllowSSH = s.mSettings.AddSubMenuItemCheckbox("Allow SSH", allowSSHMenuDescr, false)
|
||||
s.mAllowVNC = s.mSettings.AddSubMenuItemCheckbox("Allow VNC", allowVNCMenuDescr, false)
|
||||
s.mAutoConnect = s.mSettings.AddSubMenuItemCheckbox("Connect on Startup", autoConnectMenuDescr, false)
|
||||
s.mEnableRosenpass = s.mSettings.AddSubMenuItemCheckbox("Enable Quantum-Resistance", quantumResistanceMenuDescr, false)
|
||||
s.mLazyConnEnabled = s.mSettings.AddSubMenuItemCheckbox("Enable Lazy Connections", lazyConnMenuDescr, false)
|
||||
@@ -1452,6 +1454,7 @@ func protoConfigToConfig(cfg *proto.GetConfigResponse) *profilemanager.Config {
|
||||
|
||||
config.DisableAutoConnect = cfg.DisableAutoConnect
|
||||
config.ServerSSHAllowed = &cfg.ServerSSHAllowed
|
||||
config.ServerVNCAllowed = &cfg.ServerVNCAllowed
|
||||
config.RosenpassEnabled = cfg.RosenpassEnabled
|
||||
config.RosenpassPermissive = cfg.RosenpassPermissive
|
||||
config.DisableNotifications = &cfg.DisableNotifications
|
||||
@@ -1547,6 +1550,12 @@ func (s *serviceClient) loadSettings() {
|
||||
s.mAllowSSH.Uncheck()
|
||||
}
|
||||
|
||||
if cfg.ServerVNCAllowed {
|
||||
s.mAllowVNC.Check()
|
||||
} else {
|
||||
s.mAllowVNC.Uncheck()
|
||||
}
|
||||
|
||||
if cfg.DisableAutoConnect {
|
||||
s.mAutoConnect.Uncheck()
|
||||
} else {
|
||||
@@ -1586,6 +1595,7 @@ func (s *serviceClient) loadSettings() {
|
||||
func (s *serviceClient) updateConfig() error {
|
||||
disableAutoStart := !s.mAutoConnect.Checked()
|
||||
sshAllowed := s.mAllowSSH.Checked()
|
||||
vncAllowed := s.mAllowVNC.Checked()
|
||||
rosenpassEnabled := s.mEnableRosenpass.Checked()
|
||||
lazyConnectionEnabled := s.mLazyConnEnabled.Checked()
|
||||
blockInbound := s.mBlockInbound.Checked()
|
||||
@@ -1614,6 +1624,7 @@ func (s *serviceClient) updateConfig() error {
|
||||
Username: currUser.Username,
|
||||
DisableAutoConnect: &disableAutoStart,
|
||||
ServerSSHAllowed: &sshAllowed,
|
||||
ServerVNCAllowed: &vncAllowed,
|
||||
RosenpassEnabled: &rosenpassEnabled,
|
||||
LazyConnectionEnabled: &lazyConnectionEnabled,
|
||||
BlockInbound: &blockInbound,
|
||||
|
||||
@@ -2,6 +2,7 @@ package main
|
||||
|
||||
const (
|
||||
allowSSHMenuDescr = "Allow SSH connections"
|
||||
allowVNCMenuDescr = "Allow embedded VNC server"
|
||||
autoConnectMenuDescr = "Connect automatically when the service starts"
|
||||
quantumResistanceMenuDescr = "Enable post-quantum security via Rosenpass"
|
||||
lazyConnMenuDescr = "[Experimental] Enable lazy connections"
|
||||
|
||||
@@ -39,6 +39,8 @@ func (h *eventHandler) listen(ctx context.Context) {
|
||||
h.handleDisconnectClick()
|
||||
case <-h.client.mAllowSSH.ClickedCh:
|
||||
h.handleAllowSSHClick()
|
||||
case <-h.client.mAllowVNC.ClickedCh:
|
||||
h.handleAllowVNCClick()
|
||||
case <-h.client.mAutoConnect.ClickedCh:
|
||||
h.handleAutoConnectClick()
|
||||
case <-h.client.mEnableRosenpass.ClickedCh:
|
||||
@@ -134,6 +136,15 @@ func (h *eventHandler) handleAllowSSHClick() {
|
||||
|
||||
}
|
||||
|
||||
func (h *eventHandler) handleAllowVNCClick() {
|
||||
h.toggleCheckbox(h.client.mAllowVNC)
|
||||
if err := h.updateConfigWithErr(); err != nil {
|
||||
h.toggleCheckbox(h.client.mAllowVNC) // revert checkbox state on error
|
||||
log.Errorf("failed to update config: %v", err)
|
||||
h.client.notifier.Send("Error", "Failed to update VNC settings")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *eventHandler) handleAutoConnectClick() {
|
||||
h.toggleCheckbox(h.client.mAutoConnect)
|
||||
if err := h.updateConfigWithErr(); err != nil {
|
||||
|
||||
327
client/vnc/server/agent_darwin.go
Normal file
327
client/vnc/server/agent_darwin.go
Normal file
@@ -0,0 +1,327 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// darwinAgentManager spawns a per-user VNC agent on demand and keeps it
|
||||
// alive across multiple client connections within the same console-user
|
||||
// session. A new agent is spawned the first time a client connects, or
|
||||
// whenever the console user changes underneath us.
|
||||
//
|
||||
// Lifecycle is lazy by design: a daemon that never receives a VNC
|
||||
// connection never spawns anything. The trade-off versus an eager spawn
|
||||
// (the Windows model) is that the first VNC client pays the launchctl
|
||||
// asuser + listen-readiness wait, ~hundreds of milliseconds in practice.
|
||||
// That cost only repeats on user switch.
|
||||
type darwinAgentManager struct {
|
||||
mu sync.Mutex
|
||||
authToken string
|
||||
port uint16
|
||||
uid uint32
|
||||
running bool
|
||||
}
|
||||
|
||||
func newDarwinAgentManager(ctx context.Context) *darwinAgentManager {
|
||||
m := &darwinAgentManager{port: agentPort}
|
||||
go m.watchConsoleUser(ctx)
|
||||
return m
|
||||
}
|
||||
|
||||
// watchConsoleUser kills the cached agent whenever the console user
|
||||
// changes (logout, fast user switch, login window). Without it the daemon
|
||||
// keeps proxying to an agent whose TCC grant and WindowServer access
|
||||
// belong to a user who is no longer at the screen, so the new user only
|
||||
// ever sees the locked-screen wallpaper. Killing the agent breaks the
|
||||
// loopback TCP that the daemon proxies into, the client disconnects, and
|
||||
// the next reconnect runs ensure() against the new console uid.
|
||||
func (m *darwinAgentManager) watchConsoleUser(ctx context.Context) {
|
||||
t := time.NewTicker(2 * time.Second)
|
||||
defer t.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-t.C:
|
||||
uid, err := consoleUserID()
|
||||
m.mu.Lock()
|
||||
if !m.running {
|
||||
m.mu.Unlock()
|
||||
continue
|
||||
}
|
||||
if err != nil || uid != m.uid {
|
||||
prev := m.uid
|
||||
m.killLocked()
|
||||
m.mu.Unlock()
|
||||
if err != nil {
|
||||
log.Infof("console user gone (was uid=%d): %v; agent stopped", prev, err)
|
||||
} else {
|
||||
log.Infof("console user changed %d -> %d; agent stopped, will respawn on next connect", prev, uid)
|
||||
}
|
||||
continue
|
||||
}
|
||||
m.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ensure returns a token good for proxyToAgent. It spawns or respawns the
|
||||
// per-user agent process as needed and waits until it is listening on the
|
||||
// loopback port. Each ensure call is serialized so concurrent VNC clients
|
||||
// share the same agent.
|
||||
func (m *darwinAgentManager) ensure(ctx context.Context) (string, error) {
|
||||
consoleUID, err := consoleUserID()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("no console user: %w", err)
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if m.running && m.uid == consoleUID && vncAgentRunning() {
|
||||
return m.authToken, nil
|
||||
}
|
||||
m.killLocked()
|
||||
// Reap any stray external vnc-agent so the new token is the only one
|
||||
// the freshly spawned agent will accept on the loopback port.
|
||||
killAllVNCAgents()
|
||||
|
||||
token, err := generateAuthToken()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("generate agent auth token: %w", err)
|
||||
}
|
||||
if err := spawnAgentForUser(consoleUID, m.port, token); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := waitForAgent(ctx, m.port, 5*time.Second); err != nil {
|
||||
killAllVNCAgents()
|
||||
return "", fmt.Errorf("agent did not start listening: %w", err)
|
||||
}
|
||||
m.authToken = token
|
||||
m.uid = consoleUID
|
||||
m.running = true
|
||||
log.Infof("spawned VNC agent for console uid=%d on port %d", consoleUID, m.port)
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// stop terminates the spawned agent, if any. Intended for daemon shutdown.
|
||||
func (m *darwinAgentManager) stop() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.killLocked()
|
||||
}
|
||||
|
||||
func (m *darwinAgentManager) killLocked() {
|
||||
if !m.running {
|
||||
return
|
||||
}
|
||||
killAllVNCAgents()
|
||||
m.running = false
|
||||
m.authToken = ""
|
||||
m.uid = 0
|
||||
}
|
||||
|
||||
// errNoConsoleUser is the sentinel callers use to recognise the
|
||||
// "login window showing, no user signed in" state and surface it as a
|
||||
// distinct condition to the VNC client.
|
||||
var errNoConsoleUser = errors.New("no user logged into console")
|
||||
|
||||
// consoleUserID returns the uid of the user currently sitting at the
|
||||
// console (the one whose Aqua session is active). Returns
|
||||
// errNoConsoleUser when nobody is logged in: at the login window
|
||||
// /dev/console is owned by root.
|
||||
func consoleUserID() (uint32, error) {
|
||||
info, err := os.Stat("/dev/console")
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("stat /dev/console: %w", err)
|
||||
}
|
||||
st, ok := info.Sys().(*syscall.Stat_t)
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("/dev/console stat has unexpected type")
|
||||
}
|
||||
if st.Uid == 0 {
|
||||
return 0, errNoConsoleUser
|
||||
}
|
||||
return st.Uid, nil
|
||||
}
|
||||
|
||||
// spawnAgentForUser uses launchctl asuser to start a netbird vnc-agent
|
||||
// process inside the target user's launchd bootstrap namespace. That is
|
||||
// the only spawn mode on macOS that gives the child access to the user's
|
||||
// WindowServer. The agent's stderr is relogged into the daemon log so
|
||||
// startup failures are not silently lost when the readiness check times
|
||||
// out.
|
||||
func spawnAgentForUser(uid uint32, port uint16, token string) error {
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolve own executable: %w", err)
|
||||
}
|
||||
cmd := exec.Command(
|
||||
"/bin/launchctl", "asuser", strconv.FormatUint(uint64(uid), 10),
|
||||
exe, vncAgentSubcommand, "--port", strconv.FormatUint(uint64(port), 10),
|
||||
)
|
||||
cmd.Env = append(os.Environ(), agentTokenEnvVar+"="+token)
|
||||
stderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("agent stderr pipe: %w", err)
|
||||
}
|
||||
if err := cmd.Start(); err != nil {
|
||||
return fmt.Errorf("launchctl asuser: %w", err)
|
||||
}
|
||||
go func() {
|
||||
defer stderr.Close()
|
||||
relogAgentStream(stderr)
|
||||
}()
|
||||
go func() { _ = cmd.Wait() }()
|
||||
return nil
|
||||
}
|
||||
|
||||
// waitForAgent dials the loopback port until the agent answers. Used to
|
||||
// gate proxy attempts until the spawned process has finished its Start.
|
||||
func waitForAgent(ctx context.Context, port uint16, wait time.Duration) error {
|
||||
addr := fmt.Sprintf("127.0.0.1:%d", port)
|
||||
deadline := time.Now().Add(wait)
|
||||
for time.Now().Before(deadline) {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
c, err := net.DialTimeout("tcp", addr, 200*time.Millisecond)
|
||||
if err == nil {
|
||||
_ = c.Close()
|
||||
return nil
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
return fmt.Errorf("timeout dialing %s", addr)
|
||||
}
|
||||
|
||||
// vncAgentRunning reports whether any vnc-agent process exists on the
|
||||
// system. The daemon owns the only port-15900 listener model, so any
|
||||
// match is "the" agent.
|
||||
func vncAgentRunning() bool {
|
||||
pids, err := vncAgentPIDs()
|
||||
if err != nil {
|
||||
log.Debugf("scan for vnc-agent: %v", err)
|
||||
return false
|
||||
}
|
||||
return len(pids) > 0
|
||||
}
|
||||
|
||||
// killAllVNCAgents sends SIGTERM to every process whose argv contains
|
||||
// "vnc-agent", waits briefly for them to exit, and escalates to SIGKILL
|
||||
// for any that remain. We enumerate kern.proc.all rather than
|
||||
// kern.proc.uid because launchctl asuser preserves the caller's uid
|
||||
// (root) on the spawned child, so a uid-scoped filter would never match.
|
||||
func killAllVNCAgents() {
|
||||
pids, err := vncAgentPIDs()
|
||||
if err != nil {
|
||||
log.Debugf("scan for vnc-agent: %v", err)
|
||||
return
|
||||
}
|
||||
for _, pid := range pids {
|
||||
_ = syscall.Kill(pid, syscall.SIGTERM)
|
||||
}
|
||||
if len(pids) == 0 {
|
||||
return
|
||||
}
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
remaining, _ := vncAgentPIDs()
|
||||
if len(remaining) == 0 {
|
||||
return
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
leftover, _ := vncAgentPIDs()
|
||||
for _, pid := range leftover {
|
||||
_ = syscall.Kill(pid, syscall.SIGKILL)
|
||||
}
|
||||
}
|
||||
|
||||
// vncAgentPIDs returns the pids of vnc-agent subprocesses spawned from
|
||||
// this binary. Matches exactly on argv[0] == our own executable path
|
||||
// AND argv[1] == "vnc-agent" so unrelated processes that happen to have
|
||||
// the same name elsewhere in argv are not targeted. Skips pid 0 and 1
|
||||
// defensively.
|
||||
func vncAgentPIDs() ([]int, error) {
|
||||
procs, err := unix.SysctlKinfoProcSlice("kern.proc.all")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sysctl kern.proc.all: %w", err)
|
||||
}
|
||||
ownExe, err := os.Executable()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolve own executable: %w", err)
|
||||
}
|
||||
var out []int
|
||||
for i := range procs {
|
||||
pid := int(procs[i].Proc.P_pid)
|
||||
if pid <= 1 {
|
||||
continue
|
||||
}
|
||||
argv, err := procArgv(pid)
|
||||
if err != nil || !argvIsVNCAgent(argv, ownExe) {
|
||||
continue
|
||||
}
|
||||
out = append(out, pid)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// procArgv reads the kernel's stored argv for pid via the kern.procargs2
|
||||
// sysctl. Format: 4-byte argc, then argv[0..argc) each NUL-terminated,
|
||||
// then envp, then padding. We only need argv so we stop after argc.
|
||||
func procArgv(pid int) ([]string, error) {
|
||||
raw, err := unix.SysctlRaw("kern.procargs2", pid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(raw) < 4 {
|
||||
return nil, fmt.Errorf("procargs2 truncated")
|
||||
}
|
||||
argc := int(raw[0]) | int(raw[1])<<8 | int(raw[2])<<16 | int(raw[3])<<24
|
||||
body := raw[4:]
|
||||
// Skip the executable path (NUL-terminated) and any zero padding that
|
||||
// follows before argv[0].
|
||||
end := bytes.IndexByte(body, 0)
|
||||
if end < 0 {
|
||||
return nil, fmt.Errorf("procargs2 path unterminated")
|
||||
}
|
||||
body = body[end+1:]
|
||||
for len(body) > 0 && body[0] == 0 {
|
||||
body = body[1:]
|
||||
}
|
||||
args := make([]string, 0, argc)
|
||||
for i := 0; i < argc; i++ {
|
||||
end := bytes.IndexByte(body, 0)
|
||||
if end < 0 {
|
||||
break
|
||||
}
|
||||
args = append(args, string(body[:end]))
|
||||
body = body[end+1:]
|
||||
}
|
||||
return args, nil
|
||||
}
|
||||
|
||||
// argvIsVNCAgent reports whether argv belongs to a vnc-agent subprocess
|
||||
// spawned from our binary. Requires argv[0] to match ownExe exactly and
|
||||
// argv[1] to be the vnc-agent subcommand. Matches the spawn shape in
|
||||
// spawnAgentForUser and rejects anything else.
|
||||
func argvIsVNCAgent(argv []string, ownExe string) bool {
|
||||
if len(argv) < 2 || ownExe == "" {
|
||||
return false
|
||||
}
|
||||
return argv[0] == ownExe && argv[1] == vncAgentSubcommand
|
||||
}
|
||||
178
client/vnc/server/agent_ipc.go
Normal file
178
client/vnc/server/agent_ipc.go
Normal file
@@ -0,0 +1,178 @@
|
||||
//go:build darwin || windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
crand "crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// agentPort is the TCP loopback port on which a per-session VNC agent
|
||||
// listens. The daemon dials this port and presents agentToken before
|
||||
// proxying VNC bytes. The choice of TCP (rather than a Unix socket or
|
||||
// named pipe) is intentional: it lets the same proxy/handshake code
|
||||
// run on every platform; the token does the access control.
|
||||
agentPort uint16 = 15900
|
||||
|
||||
// agentTokenLen is the size of the random per-spawn token in bytes.
|
||||
agentTokenLen = 32
|
||||
|
||||
// agentTokenEnvVar names the environment variable the daemon uses to
|
||||
// hand the per-spawn token to the agent child. Out-of-band channels
|
||||
// like this keep the secret out of the command line, where listings
|
||||
// such as `ps` or Windows tasklist would expose it.
|
||||
agentTokenEnvVar = "NB_VNC_AGENT_TOKEN" // #nosec G101 -- env var name, not a credential
|
||||
|
||||
// vncAgentSubcommand is the CLI subcommand the daemon invokes to start
|
||||
// the per-session agent process. Must match cmd.vncAgentCmd.Use in
|
||||
// client/cmd/vnc_agent.go.
|
||||
vncAgentSubcommand = "vnc-agent"
|
||||
)
|
||||
|
||||
// generateAuthToken returns a fresh hex-encoded random token for one
|
||||
// daemon→agent session. The daemon hands this to the spawned agent
|
||||
// out-of-band (env var on Windows) and verifies it on every connection
|
||||
// the agent accepts.
|
||||
func generateAuthToken() (string, error) {
|
||||
b := make([]byte, agentTokenLen)
|
||||
if _, err := crand.Read(b); err != nil {
|
||||
return "", fmt.Errorf("read random: %w", err)
|
||||
}
|
||||
return hex.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// proxyToAgent dials the per-session agent on TCP loopback, writes the
|
||||
// raw token bytes, and then copies bytes in both directions until either
|
||||
// side closes. The token has to land on the wire before any VNC byte so
|
||||
// the agent's listening Server can apply verifyAgentToken before letting
|
||||
// real RFB traffic through.
|
||||
func proxyToAgent(ctx context.Context, client net.Conn, port uint16, authToken string) {
|
||||
defer client.Close()
|
||||
|
||||
addr := fmt.Sprintf("127.0.0.1:%d", port)
|
||||
agentConn, err := dialAgentWithRetry(ctx, addr)
|
||||
if err != nil {
|
||||
log.Warnf("proxy cannot reach agent at %s: %v", addr, err)
|
||||
return
|
||||
}
|
||||
defer agentConn.Close()
|
||||
|
||||
tokenBytes, err := hex.DecodeString(authToken)
|
||||
if err != nil || len(tokenBytes) != agentTokenLen {
|
||||
log.Warnf("invalid auth token (len=%d): %v", len(tokenBytes), err)
|
||||
return
|
||||
}
|
||||
if _, err := agentConn.Write(tokenBytes); err != nil {
|
||||
log.Warnf("send auth token to agent: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("proxy connected to agent, starting bidirectional copy")
|
||||
done := make(chan struct{}, 2)
|
||||
cp := func(label string, dst, src net.Conn) {
|
||||
n, err := io.Copy(dst, src)
|
||||
log.Debugf("proxy %s: %d bytes, err=%v", label, n, err)
|
||||
done <- struct{}{}
|
||||
}
|
||||
go cp("client→agent", agentConn, client)
|
||||
go cp("agent→client", client, agentConn)
|
||||
<-done
|
||||
}
|
||||
|
||||
// relogAgentStream reads log lines from the agent's stderr and re-emits
|
||||
// them through the daemon's logrus, so the merged log keeps a single
|
||||
// format. JSON lines (the agent's normal output) are parsed and dispatched
|
||||
// by level; plain-text lines (cobra errors, panic traces) are forwarded
|
||||
// verbatim so early-startup failures stay visible.
|
||||
func relogAgentStream(r io.Reader) {
|
||||
entry := log.WithField("component", "vnc-agent")
|
||||
scanner := bufio.NewScanner(r)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
if line[0] != '{' {
|
||||
entry.Warn(string(line))
|
||||
continue
|
||||
}
|
||||
var m map[string]any
|
||||
if err := json.Unmarshal(line, &m); err != nil {
|
||||
entry.Warn(string(line))
|
||||
continue
|
||||
}
|
||||
msg, _ := m["msg"].(string)
|
||||
if msg == "" {
|
||||
continue
|
||||
}
|
||||
fields := make(log.Fields)
|
||||
for k, v := range m {
|
||||
switch k {
|
||||
case "msg", "level", "time", "func":
|
||||
continue
|
||||
case "caller":
|
||||
fields["source"] = v
|
||||
default:
|
||||
fields[k] = v
|
||||
}
|
||||
}
|
||||
e := entry.WithFields(fields)
|
||||
switch m["level"] {
|
||||
case "error":
|
||||
e.Error(msg)
|
||||
case "warning":
|
||||
e.Warn(msg)
|
||||
case "debug":
|
||||
e.Debug(msg)
|
||||
case "trace":
|
||||
e.Trace(msg)
|
||||
default:
|
||||
e.Info(msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// dialAgentWithRetry retries the loopback connect for up to ~10 s so the
|
||||
// daemon does not race the agent's first listen. Returns the live conn or
|
||||
// the final error. Aborts early when ctx is cancelled so a Stop() during
|
||||
// service-mode startup doesn't leave a goroutine sleeping for 10 s.
|
||||
func dialAgentWithRetry(ctx context.Context, addr string) (net.Conn, error) {
|
||||
var d net.Dialer
|
||||
var lastErr error
|
||||
for range 50 {
|
||||
if err := ctx.Err(); err != nil {
|
||||
if lastErr == nil {
|
||||
lastErr = err
|
||||
}
|
||||
return nil, lastErr
|
||||
}
|
||||
dialCtx, cancel := context.WithTimeout(ctx, time.Second)
|
||||
c, err := d.DialContext(dialCtx, "tcp", addr)
|
||||
cancel()
|
||||
if err == nil {
|
||||
return c, nil
|
||||
}
|
||||
lastErr = err
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if errors.Is(lastErr, context.Canceled) || errors.Is(lastErr, context.DeadlineExceeded) {
|
||||
lastErr = ctx.Err()
|
||||
}
|
||||
return nil, lastErr
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
return nil, lastErr
|
||||
}
|
||||
692
client/vnc/server/agent_windows.go
Normal file
692
client/vnc/server/agent_windows.go
Normal file
@@ -0,0 +1,692 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
const (
|
||||
stillActive = 259
|
||||
|
||||
tokenPrimary = 1
|
||||
securityImpersonation = 2
|
||||
tokenSessionID = 12
|
||||
|
||||
createUnicodeEnvironment = 0x00000400
|
||||
createNoWindow = 0x08000000
|
||||
createSuspended = 0x00000004
|
||||
createBreakawayFromJob = 0x01000000
|
||||
)
|
||||
|
||||
var (
|
||||
kernel32 = windows.NewLazySystemDLL("kernel32.dll")
|
||||
advapi32 = windows.NewLazySystemDLL("advapi32.dll")
|
||||
userenv = windows.NewLazySystemDLL("userenv.dll")
|
||||
|
||||
procWTSGetActiveConsoleSessionId = kernel32.NewProc("WTSGetActiveConsoleSessionId")
|
||||
procCreateJobObjectW = kernel32.NewProc("CreateJobObjectW")
|
||||
procSetInformationJobObject = kernel32.NewProc("SetInformationJobObject")
|
||||
procAssignProcessToJobObject = kernel32.NewProc("AssignProcessToJobObject")
|
||||
procSetTokenInformation = advapi32.NewProc("SetTokenInformation")
|
||||
procCreateEnvironmentBlock = userenv.NewProc("CreateEnvironmentBlock")
|
||||
procDestroyEnvironmentBlock = userenv.NewProc("DestroyEnvironmentBlock")
|
||||
|
||||
wtsapi32 = windows.NewLazySystemDLL("wtsapi32.dll")
|
||||
procWTSEnumerateSessionsW = wtsapi32.NewProc("WTSEnumerateSessionsW")
|
||||
procWTSFreeMemory = wtsapi32.NewProc("WTSFreeMemory")
|
||||
procWTSQuerySessionInformation = wtsapi32.NewProc("WTSQuerySessionInformationW")
|
||||
|
||||
iphlpapi = windows.NewLazySystemDLL("iphlpapi.dll")
|
||||
procGetExtendedTcpTable = iphlpapi.NewProc("GetExtendedTcpTable")
|
||||
)
|
||||
|
||||
// GetCurrentSessionID returns the session ID of the current process.
|
||||
func GetCurrentSessionID() uint32 {
|
||||
var token windows.Token
|
||||
if err := windows.OpenProcessToken(windows.CurrentProcess(),
|
||||
windows.TOKEN_QUERY, &token); err != nil {
|
||||
return 0
|
||||
}
|
||||
defer token.Close()
|
||||
var id uint32
|
||||
var ret uint32
|
||||
_ = windows.GetTokenInformation(token, windows.TokenSessionId,
|
||||
(*byte)(unsafe.Pointer(&id)), 4, &ret)
|
||||
return id
|
||||
}
|
||||
|
||||
func getConsoleSessionID() uint32 {
|
||||
r, _, _ := procWTSGetActiveConsoleSessionId.Call()
|
||||
return uint32(r)
|
||||
}
|
||||
|
||||
const (
|
||||
wtsActive = 0
|
||||
wtsConnected = 1
|
||||
wtsDisconnected = 4
|
||||
)
|
||||
|
||||
// getActiveSessionID returns the session ID of the best session to attach to.
|
||||
// On a Windows Server with no console display attached, session 1 still
|
||||
// reports WTSActive (login screen "owns" the console), so a naive
|
||||
// first-active-wins pick lands on a session with no actual rendering.
|
||||
// Preference order:
|
||||
// 1. Active session with a user logged in (RDP user in session ≥2)
|
||||
// 2. Active session without a user (console at login screen)
|
||||
// 3. Console session ID
|
||||
func getActiveSessionID() uint32 {
|
||||
var sessionInfo uintptr
|
||||
var count uint32
|
||||
|
||||
r, _, _ := procWTSEnumerateSessionsW.Call(
|
||||
0, // WTS_CURRENT_SERVER_HANDLE
|
||||
0, // reserved
|
||||
1, // version
|
||||
uintptr(unsafe.Pointer(&sessionInfo)),
|
||||
uintptr(unsafe.Pointer(&count)),
|
||||
)
|
||||
if r == 0 || count == 0 {
|
||||
return getConsoleSessionID()
|
||||
}
|
||||
defer func() { _, _, _ = procWTSFreeMemory.Call(sessionInfo) }()
|
||||
|
||||
type wtsSession struct {
|
||||
SessionID uint32
|
||||
Station *uint16
|
||||
State uint32
|
||||
}
|
||||
sessions := unsafe.Slice((*wtsSession)(unsafe.Pointer(sessionInfo)), count)
|
||||
|
||||
var withUser uint32
|
||||
var withUserFound bool
|
||||
var anyActive uint32
|
||||
var anyActiveFound bool
|
||||
for _, s := range sessions {
|
||||
if s.SessionID == 0 {
|
||||
continue
|
||||
}
|
||||
if s.State != wtsActive {
|
||||
continue
|
||||
}
|
||||
if !anyActiveFound {
|
||||
anyActive = s.SessionID
|
||||
anyActiveFound = true
|
||||
}
|
||||
if !withUserFound && wtsSessionHasUser(s.SessionID) {
|
||||
withUser = s.SessionID
|
||||
withUserFound = true
|
||||
}
|
||||
}
|
||||
if withUserFound {
|
||||
return withUser
|
||||
}
|
||||
if anyActiveFound {
|
||||
return anyActive
|
||||
}
|
||||
return getConsoleSessionID()
|
||||
}
|
||||
|
||||
// reapOrphanOnPort finds any process listening on 127.0.0.1:port and, if
|
||||
// it's a netbird vnc-agent left over from a previous service instance,
|
||||
// terminates it. Verified by image-name match so we never kill an
|
||||
// unrelated process that happens to use the same port.
|
||||
func reapOrphanOnPort(port uint16) {
|
||||
pid := tcpListenerPID(port)
|
||||
if pid == 0 || pid == uint32(windows.GetCurrentProcessId()) {
|
||||
return
|
||||
}
|
||||
h, err := windows.OpenProcess(windows.PROCESS_QUERY_LIMITED_INFORMATION|windows.PROCESS_TERMINATE|windows.SYNCHRONIZE, false, pid)
|
||||
if err != nil {
|
||||
log.Warnf("reap on port %d: open PID=%d: %v", port, pid, err)
|
||||
return
|
||||
}
|
||||
defer func() { _ = windows.CloseHandle(h) }()
|
||||
if !isOurAgentProcess(h) {
|
||||
log.Warnf("reap on port %d: PID=%d is not a netbird vnc-agent, leaving it alone", port, pid)
|
||||
return
|
||||
}
|
||||
if err := windows.TerminateProcess(h, 0); err != nil {
|
||||
log.Warnf("reap on port %d: terminate PID=%d: %v", port, pid, err)
|
||||
return
|
||||
}
|
||||
log.Infof("reaped orphan vnc-agent PID=%d holding port %d", pid, port)
|
||||
}
|
||||
|
||||
// isOurAgentProcess returns true if the given process handle points at a
|
||||
// netbird.exe binary at the same path as the current process. We compare
|
||||
// full paths (case-insensitive on Windows) so co-installed netbird binaries
|
||||
// from a different install dir or unrelated apps named netbird.exe don't
|
||||
// get killed.
|
||||
func isOurAgentProcess(h windows.Handle) bool {
|
||||
var size uint32 = windows.MAX_PATH
|
||||
buf := make([]uint16, size)
|
||||
if err := windows.QueryFullProcessImageName(h, 0, &buf[0], &size); err != nil {
|
||||
return false
|
||||
}
|
||||
target := strings.ToLower(windows.UTF16ToString(buf[:size]))
|
||||
selfExe, err := os.Executable()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return target == strings.ToLower(selfExe)
|
||||
}
|
||||
|
||||
// tcpListenerPID returns the PID of the process listening on 127.0.0.1:port,
|
||||
// or 0 if none. Uses GetExtendedTcpTable with TCP_TABLE_OWNER_PID_LISTENER.
|
||||
func tcpListenerPID(port uint16) uint32 {
|
||||
const tcpTableOwnerPidListener = 3
|
||||
const afInet = 2
|
||||
|
||||
// MIB_TCPROW_OWNER_PID layout: state(4) + localAddr(4) + localPort(4) +
|
||||
// remoteAddr(4) + remotePort(4) + owningPid(4) = 24 bytes.
|
||||
const rowSize = 24
|
||||
|
||||
var size uint32
|
||||
_, _, _ = procGetExtendedTcpTable.Call(0, uintptr(unsafe.Pointer(&size)), 0, afInet, tcpTableOwnerPidListener, 0)
|
||||
if size == 0 {
|
||||
return 0
|
||||
}
|
||||
buf := make([]byte, size)
|
||||
r, _, _ := procGetExtendedTcpTable.Call(
|
||||
uintptr(unsafe.Pointer(&buf[0])),
|
||||
uintptr(unsafe.Pointer(&size)),
|
||||
0, afInet, tcpTableOwnerPidListener, 0,
|
||||
)
|
||||
if r != 0 {
|
||||
return 0
|
||||
}
|
||||
count := binary.LittleEndian.Uint32(buf[:4])
|
||||
for i := uint32(0); i < count; i++ {
|
||||
off := 4 + int(i)*rowSize
|
||||
if off+rowSize > len(buf) {
|
||||
break
|
||||
}
|
||||
// localPort is stored big-endian in the high 16 bits of a 32-bit field.
|
||||
localPort := uint16(buf[off+8])<<8 | uint16(buf[off+9])
|
||||
if localPort != port {
|
||||
continue
|
||||
}
|
||||
localAddr := binary.LittleEndian.Uint32(buf[off+4 : off+8])
|
||||
// 0x0100007f == 127.0.0.1 in network byte order on little-endian.
|
||||
// We accept 0.0.0.0 too in case the orphan bound to all interfaces.
|
||||
if localAddr != 0x0100007f && localAddr != 0 {
|
||||
continue
|
||||
}
|
||||
return binary.LittleEndian.Uint32(buf[off+20 : off+24])
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// wtsSessionHasUser returns true if the session has a non-empty user name,
|
||||
// i.e. someone is logged in (vs. the login/Welcome screen). The console
|
||||
// session at the lock screen has WTSUserName == "".
|
||||
const wtsUserName = 5
|
||||
|
||||
func wtsSessionHasUser(sessionID uint32) bool {
|
||||
var buf uintptr
|
||||
var bytesReturned uint32
|
||||
r, _, _ := procWTSQuerySessionInformation.Call(
|
||||
0, // WTS_CURRENT_SERVER_HANDLE
|
||||
uintptr(sessionID),
|
||||
uintptr(wtsUserName),
|
||||
uintptr(unsafe.Pointer(&buf)),
|
||||
uintptr(unsafe.Pointer(&bytesReturned)),
|
||||
)
|
||||
if r == 0 || buf == 0 {
|
||||
return false
|
||||
}
|
||||
defer func() { _, _, _ = procWTSFreeMemory.Call(buf) }()
|
||||
// First UTF-16 code unit non-zero ⇒ non-empty username.
|
||||
return *(*uint16)(unsafe.Pointer(buf)) != 0
|
||||
}
|
||||
|
||||
// getSystemTokenForSession duplicates the current SYSTEM token and sets its
|
||||
// session ID so the spawned process runs in the target session. Using a SYSTEM
|
||||
// token gives access to both Default and Winlogon desktops plus UIPI bypass.
|
||||
func getSystemTokenForSession(sessionID uint32) (windows.Token, error) {
|
||||
var cur windows.Token
|
||||
if err := windows.OpenProcessToken(windows.CurrentProcess(),
|
||||
windows.MAXIMUM_ALLOWED, &cur); err != nil {
|
||||
return 0, fmt.Errorf("OpenProcessToken: %w", err)
|
||||
}
|
||||
defer cur.Close()
|
||||
|
||||
var dup windows.Token
|
||||
if err := windows.DuplicateTokenEx(cur, windows.MAXIMUM_ALLOWED, nil,
|
||||
securityImpersonation, tokenPrimary, &dup); err != nil {
|
||||
return 0, fmt.Errorf("DuplicateTokenEx: %w", err)
|
||||
}
|
||||
|
||||
sid := sessionID
|
||||
r, _, err := procSetTokenInformation.Call(
|
||||
uintptr(dup),
|
||||
uintptr(tokenSessionID),
|
||||
uintptr(unsafe.Pointer(&sid)),
|
||||
unsafe.Sizeof(sid),
|
||||
)
|
||||
if r == 0 {
|
||||
dup.Close()
|
||||
return 0, fmt.Errorf("SetTokenInformation(SessionId=%d): %w", sessionID, err)
|
||||
}
|
||||
return dup, nil
|
||||
}
|
||||
|
||||
// injectEnvVar appends a KEY=VALUE entry to a Unicode environment block.
|
||||
// The block is a sequence of null-terminated UTF-16 strings, terminated by
|
||||
// an extra null. Returns the new []uint16 backing slice; the caller must
|
||||
// hold the returned slice alive until CreateProcessAsUser completes.
|
||||
func injectEnvVar(envBlock uintptr, key, value string) []uint16 {
|
||||
entry := key + "=" + value
|
||||
|
||||
// Walk the existing block to find its total length.
|
||||
ptr := (*uint16)(unsafe.Pointer(envBlock))
|
||||
var totalChars int
|
||||
for {
|
||||
ch := *(*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + uintptr(totalChars)*2))
|
||||
if ch == 0 {
|
||||
// Check for double-null terminator.
|
||||
next := *(*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + uintptr(totalChars+1)*2))
|
||||
totalChars++
|
||||
if next == 0 {
|
||||
// End of block (don't count the final null yet, we'll rebuild).
|
||||
break
|
||||
}
|
||||
} else {
|
||||
totalChars++
|
||||
}
|
||||
}
|
||||
|
||||
entryUTF16, _ := windows.UTF16FromString(entry)
|
||||
// New block: existing entries + new entry (null-terminated) + final null.
|
||||
newLen := totalChars + len(entryUTF16) + 1
|
||||
newBlock := make([]uint16, newLen)
|
||||
// Copy existing entries (up to but not including the final null).
|
||||
for i := range totalChars {
|
||||
newBlock[i] = *(*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + uintptr(i)*2))
|
||||
}
|
||||
copy(newBlock[totalChars:], entryUTF16)
|
||||
newBlock[newLen-1] = 0 // final null terminator
|
||||
|
||||
return newBlock
|
||||
}
|
||||
|
||||
func spawnAgentInSession(sessionID uint32, port uint16, authToken string, jobHandle windows.Handle) (windows.Handle, error) {
|
||||
token, err := getSystemTokenForSession(sessionID)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("get SYSTEM token for session %d: %w", sessionID, err)
|
||||
}
|
||||
defer token.Close()
|
||||
|
||||
var envBlock uintptr
|
||||
r, _, e := procCreateEnvironmentBlock.Call(
|
||||
uintptr(unsafe.Pointer(&envBlock)),
|
||||
uintptr(token),
|
||||
0,
|
||||
)
|
||||
if r == 0 {
|
||||
// Without an environment block we cannot inject NB_VNC_AGENT_TOKEN;
|
||||
// the agent would start unauthenticated. Abort instead of launching.
|
||||
return 0, fmt.Errorf("CreateEnvironmentBlock: %w", e)
|
||||
}
|
||||
defer func() { _, _, _ = procDestroyEnvironmentBlock.Call(envBlock) }()
|
||||
|
||||
// Inject the auth token into the environment block so it doesn't appear
|
||||
// in the process command line (visible via tasklist/wmic). injectedBlock
|
||||
// must stay alive until CreateProcessAsUser returns.
|
||||
injectedBlock := injectEnvVar(envBlock, agentTokenEnvVar, authToken)
|
||||
|
||||
exePath, err := os.Executable()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("get executable path: %w", err)
|
||||
}
|
||||
|
||||
cmdLine := fmt.Sprintf(`"%s" %s --port %d`, exePath, vncAgentSubcommand, port)
|
||||
cmdLineW, err := windows.UTF16PtrFromString(cmdLine)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("UTF16 cmdline: %w", err)
|
||||
}
|
||||
|
||||
// Create an inheritable pipe for the agent's stderr so we can relog
|
||||
// its output in the service process.
|
||||
var sa windows.SecurityAttributes
|
||||
sa.Length = uint32(unsafe.Sizeof(sa))
|
||||
sa.InheritHandle = 1
|
||||
|
||||
var stderrRead, stderrWrite windows.Handle
|
||||
if err := windows.CreatePipe(&stderrRead, &stderrWrite, &sa, 0); err != nil {
|
||||
return 0, fmt.Errorf("create stderr pipe: %w", err)
|
||||
}
|
||||
// The read end must NOT be inherited by the child.
|
||||
_ = windows.SetHandleInformation(stderrRead, windows.HANDLE_FLAG_INHERIT, 0)
|
||||
|
||||
desktop, _ := windows.UTF16PtrFromString(`WinSta0\Default`)
|
||||
si := windows.StartupInfo{
|
||||
Cb: uint32(unsafe.Sizeof(windows.StartupInfo{})),
|
||||
Desktop: desktop,
|
||||
Flags: windows.STARTF_USESHOWWINDOW | windows.STARTF_USESTDHANDLES,
|
||||
ShowWindow: 0,
|
||||
StdErr: stderrWrite,
|
||||
StdOutput: stderrWrite,
|
||||
}
|
||||
var pi windows.ProcessInformation
|
||||
|
||||
var envPtr *uint16
|
||||
if len(injectedBlock) > 0 {
|
||||
envPtr = &injectedBlock[0]
|
||||
} else if envBlock != 0 {
|
||||
envPtr = (*uint16)(unsafe.Pointer(envBlock))
|
||||
}
|
||||
|
||||
// CREATE_SUSPENDED so we can assign the process to our Job Object
|
||||
// before it executes. Without this the agent could spawn its own child
|
||||
// processes and have them inherit the SCM service-job (not ours), or
|
||||
// briefly listen on the agent port before we tear it down on rollback.
|
||||
// CREATE_BREAKAWAY_FROM_JOB lets the child leave the SCM-managed
|
||||
// service job; harmless if that job allows breakaway, and is required
|
||||
// before AssignProcessToJobObject can succeed in the no-nested-jobs case.
|
||||
err = windows.CreateProcessAsUser(
|
||||
token, nil, cmdLineW,
|
||||
nil, nil, true, // inheritHandles=true for the pipe
|
||||
createUnicodeEnvironment|createNoWindow|createSuspended|createBreakawayFromJob,
|
||||
envPtr, nil, &si, &pi,
|
||||
)
|
||||
runtime.KeepAlive(injectedBlock)
|
||||
// Close the write end in the parent so reads will get EOF when the child exits.
|
||||
_ = windows.CloseHandle(stderrWrite)
|
||||
if err != nil {
|
||||
_ = windows.CloseHandle(stderrRead)
|
||||
return 0, fmt.Errorf("CreateProcessAsUser: %w", err)
|
||||
}
|
||||
|
||||
if jobHandle != 0 {
|
||||
r, _, e := procAssignProcessToJobObject.Call(uintptr(jobHandle), uintptr(pi.Process))
|
||||
if r == 0 {
|
||||
log.Warnf("assign agent to job object: %v (orphan possible on service crash)", e)
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := windows.ResumeThread(pi.Thread); err != nil {
|
||||
log.Warnf("resume agent main thread: %v", err)
|
||||
}
|
||||
_ = windows.CloseHandle(pi.Thread)
|
||||
|
||||
// Relog agent output in the service with a [vnc-agent] prefix.
|
||||
go relogAgentOutput(stderrRead)
|
||||
|
||||
log.Infof("spawned agent PID=%d in session %d on port %d", pi.ProcessId, sessionID, port)
|
||||
return pi.Process, nil
|
||||
}
|
||||
|
||||
// sessionManager monitors the active console session and ensures a VNC agent
|
||||
// process is running in it. When the session changes (e.g., user switch, RDP
|
||||
// connect/disconnect), it kills the old agent and spawns a new one.
|
||||
type sessionManager struct {
|
||||
port uint16
|
||||
mu sync.Mutex
|
||||
agentProc windows.Handle
|
||||
everSpawned bool
|
||||
agentStartedAt time.Time
|
||||
spawnFailures int
|
||||
nextSpawnAt time.Time
|
||||
sessionID uint32
|
||||
authToken string
|
||||
done chan struct{}
|
||||
// jobHandle owns the agent processes via a Windows Job Object with
|
||||
// JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE. When the service exits or crashes,
|
||||
// the OS closes the handle and terminates every assigned agent: no
|
||||
// orphaned listeners holding the agent port across restarts.
|
||||
jobHandle windows.Handle
|
||||
}
|
||||
|
||||
func newSessionManager(port uint16) *sessionManager {
|
||||
m := &sessionManager{port: port, sessionID: ^uint32(0), done: make(chan struct{})}
|
||||
if h, err := createKillOnCloseJob(); err != nil {
|
||||
log.Warnf("create job object for vnc-agent (orphan agents possible after crash): %v", err)
|
||||
} else {
|
||||
m.jobHandle = h
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// createKillOnCloseJob returns a Job Object configured so that closing its
|
||||
// handle (process exit or explicit Close) terminates every process assigned
|
||||
// to it. Used to keep orphaned vnc-agent processes from outliving the service.
|
||||
func createKillOnCloseJob() (windows.Handle, error) {
|
||||
r, _, e := procCreateJobObjectW.Call(0, 0)
|
||||
if r == 0 {
|
||||
return 0, fmt.Errorf("CreateJobObject: %w", e)
|
||||
}
|
||||
job := windows.Handle(r)
|
||||
|
||||
// JOBOBJECT_EXTENDED_LIMIT_INFORMATION on amd64 = 144 bytes.
|
||||
//
|
||||
// JOBOBJECT_BASIC_LIMIT_INFORMATION (64 bytes with alignment padding)
|
||||
// PerProcessUserTimeLimit LARGE_INTEGER off 0
|
||||
// PerJobUserTimeLimit LARGE_INTEGER off 8
|
||||
// LimitFlags DWORD off 16
|
||||
// [4 byte pad to align SIZE_T]
|
||||
// MinimumWorkingSetSize SIZE_T off 24
|
||||
// MaximumWorkingSetSize SIZE_T off 32
|
||||
// ActiveProcessLimit DWORD off 40
|
||||
// [4 byte pad to align ULONG_PTR]
|
||||
// Affinity ULONG_PTR off 48
|
||||
// PriorityClass DWORD off 56
|
||||
// SchedulingClass DWORD off 60
|
||||
// IO_COUNTERS (48) + 4 * SIZE_T (32) = 144 total.
|
||||
//
|
||||
// We only set LimitFlags; the rest stays zero.
|
||||
const sizeofExtended = 144
|
||||
const offsetLimitFlags = 16
|
||||
const jobObjectExtendedLimitInformation = 9
|
||||
const jobObjectLimitKillOnJobClose = 0x00002000
|
||||
|
||||
var info [sizeofExtended]byte
|
||||
binary.LittleEndian.PutUint32(info[offsetLimitFlags:offsetLimitFlags+4], jobObjectLimitKillOnJobClose)
|
||||
|
||||
r, _, e = procSetInformationJobObject.Call(
|
||||
uintptr(job),
|
||||
uintptr(jobObjectExtendedLimitInformation),
|
||||
uintptr(unsafe.Pointer(&info[0])),
|
||||
uintptr(sizeofExtended),
|
||||
)
|
||||
if r == 0 {
|
||||
_ = windows.CloseHandle(job)
|
||||
return 0, fmt.Errorf("SetInformationJobObject(KILL_ON_JOB_CLOSE): %w", e)
|
||||
}
|
||||
return job, nil
|
||||
}
|
||||
|
||||
// AuthToken returns the current agent authentication token.
|
||||
func (m *sessionManager) AuthToken() string {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.authToken
|
||||
}
|
||||
|
||||
// Stop signals the session manager to exit its polling loop and closes the
|
||||
// Job Object handle, which Windows uses as the trigger to terminate every
|
||||
// agent process this manager spawned.
|
||||
func (m *sessionManager) Stop() {
|
||||
select {
|
||||
case <-m.done:
|
||||
default:
|
||||
close(m.done)
|
||||
}
|
||||
m.mu.Lock()
|
||||
if m.jobHandle != 0 {
|
||||
_ = windows.CloseHandle(m.jobHandle)
|
||||
m.jobHandle = 0
|
||||
}
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
func (m *sessionManager) run() {
|
||||
ticker := time.NewTicker(2 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
if !m.tick() {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-m.done:
|
||||
m.mu.Lock()
|
||||
m.killAgent()
|
||||
m.mu.Unlock()
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// tick performs one session/agent-state update. Returns false if the manager
|
||||
// should permanently stop (e.g. missing SYSTEM privileges).
|
||||
func (m *sessionManager) tick() bool {
|
||||
sid := getActiveSessionID()
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.handleSessionChange(sid)
|
||||
m.reapExitedAgent()
|
||||
return m.maybeSpawnAgent(sid)
|
||||
}
|
||||
|
||||
func (m *sessionManager) handleSessionChange(sid uint32) {
|
||||
if sid == m.sessionID {
|
||||
return
|
||||
}
|
||||
log.Infof("active session changed: %d -> %d", m.sessionID, sid)
|
||||
m.killAgent()
|
||||
m.sessionID = sid
|
||||
}
|
||||
|
||||
func (m *sessionManager) reapExitedAgent() {
|
||||
if m.agentProc == 0 {
|
||||
return
|
||||
}
|
||||
var code uint32
|
||||
if err := windows.GetExitCodeProcess(m.agentProc, &code); err != nil {
|
||||
log.Debugf("GetExitCodeProcess: %v", err)
|
||||
return
|
||||
}
|
||||
if code == stillActive {
|
||||
return
|
||||
}
|
||||
m.scheduleNextSpawn(code, time.Since(m.agentStartedAt))
|
||||
if err := windows.CloseHandle(m.agentProc); err != nil {
|
||||
log.Debugf("close agent handle: %v", err)
|
||||
}
|
||||
m.agentProc = 0
|
||||
}
|
||||
|
||||
// scheduleNextSpawn applies an exponential backoff on fast crashes (<5s) and
|
||||
// resets immediately otherwise.
|
||||
func (m *sessionManager) scheduleNextSpawn(exitCode uint32, lifetime time.Duration) {
|
||||
if lifetime < 5*time.Second {
|
||||
m.spawnFailures++
|
||||
backoff := time.Duration(1<<min(m.spawnFailures, 5)) * time.Second
|
||||
if backoff > 30*time.Second {
|
||||
backoff = 30 * time.Second
|
||||
}
|
||||
m.nextSpawnAt = time.Now().Add(backoff)
|
||||
log.Warnf("agent exited (code=%d) after %v, retrying in %v (failures=%d)", exitCode, lifetime.Round(time.Millisecond), backoff, m.spawnFailures)
|
||||
return
|
||||
}
|
||||
m.spawnFailures = 0
|
||||
m.nextSpawnAt = time.Time{}
|
||||
log.Infof("agent exited (code=%d) after %v, respawning", exitCode, lifetime.Round(time.Second))
|
||||
}
|
||||
|
||||
// maybeSpawnAgent spawns a new agent if there's no current one and the backoff
|
||||
// window has elapsed. Returns false to permanently stop the manager when the
|
||||
// service lacks the privileges needed to spawn cross-session.
|
||||
func (m *sessionManager) maybeSpawnAgent(sid uint32) bool {
|
||||
if m.agentProc != 0 || sid == 0xFFFFFFFF || !time.Now().After(m.nextSpawnAt) {
|
||||
return true
|
||||
}
|
||||
// Reap any orphan still holding the agent port from a previous
|
||||
// service instance, only on our very first spawn. Once we own
|
||||
// an agent, we manage its lifecycle ourselves and never need to
|
||||
// kill an unknown listener; if a kill+respawn races on port
|
||||
// release, the spawn-failure backoff handles it without forcing
|
||||
// a synchronous wait or duplicate kill.
|
||||
if !m.everSpawned {
|
||||
reapOrphanOnPort(m.port)
|
||||
}
|
||||
token, err := generateAuthToken()
|
||||
if err != nil {
|
||||
log.Warnf("generate agent auth token: %v", err)
|
||||
return true
|
||||
}
|
||||
m.authToken = token
|
||||
h, err := spawnAgentInSession(sid, m.port, m.authToken, m.jobHandle)
|
||||
if err != nil {
|
||||
m.authToken = ""
|
||||
if errors.Is(err, windows.ERROR_PRIVILEGE_NOT_HELD) {
|
||||
// SE_TCB_NAME (token-impersonation across sessions) is only
|
||||
// granted to SYSTEM. Without it spawnAgent will fail every 2
|
||||
// seconds forever: log once and give up.
|
||||
log.Warnf("VNC service mode disabled: agent spawn requires SYSTEM privileges (got: %v)", err)
|
||||
return false
|
||||
}
|
||||
log.Warnf("spawn agent in session %d: %v", sid, err)
|
||||
return true
|
||||
}
|
||||
m.agentProc = h
|
||||
m.agentStartedAt = time.Now()
|
||||
m.everSpawned = true
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *sessionManager) killAgent() {
|
||||
if m.agentProc == 0 {
|
||||
return
|
||||
}
|
||||
_ = windows.TerminateProcess(m.agentProc, 0)
|
||||
_ = windows.CloseHandle(m.agentProc)
|
||||
m.agentProc = 0
|
||||
log.Info("killed old agent")
|
||||
}
|
||||
|
||||
// relogAgentOutput reads log lines from the agent's stderr pipe and
|
||||
// relogs them with the service's formatter. The *os.File owns the
|
||||
// underlying handle, so closing it suffices.
|
||||
func relogAgentOutput(pipe windows.Handle) {
|
||||
f := os.NewFile(uintptr(pipe), "vnc-agent-stderr")
|
||||
defer func() { _ = f.Close() }()
|
||||
relogAgentStream(f)
|
||||
}
|
||||
|
||||
// logCleanupCall invokes a Windows syscall used solely as a cleanup primitive
|
||||
// (CloseClipboard, ReleaseDC, etc.) and logs failures at trace level. The
|
||||
// indirection lets us satisfy errcheck without scattering ignored returns at
|
||||
// each call site, while still capturing diagnostic info when the OS reports
|
||||
// a failure.
|
||||
func logCleanupCall(name string, proc *windows.LazyProc) {
|
||||
r, _, err := proc.Call()
|
||||
if r == 0 && err != nil && err != windows.NTE_OP_OK {
|
||||
log.Tracef("%s: %v", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
// logCleanupCallArgs is logCleanupCall with one argument; common pattern for
|
||||
// release-by-handle syscalls.
|
||||
func logCleanupCallArgs(name string, proc *windows.LazyProc, args ...uintptr) {
|
||||
r, _, err := proc.Call(args...)
|
||||
if r == 0 && err != nil && err != windows.NTE_OP_OK {
|
||||
log.Tracef("%s: %v", name, err)
|
||||
}
|
||||
}
|
||||
643
client/vnc/server/capture_darwin.go
Normal file
643
client/vnc/server/capture_darwin.go
Normal file
@@ -0,0 +1,643 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/maphash"
|
||||
"image"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var darwinCaptureOnce sync.Once
|
||||
|
||||
var (
|
||||
cgMainDisplayID func() uint32
|
||||
cgDisplayPixelsWide func(uint32) uintptr
|
||||
cgDisplayPixelsHigh func(uint32) uintptr
|
||||
cgDisplayCreateImage func(uint32) uintptr
|
||||
cgImageGetWidth func(uintptr) uintptr
|
||||
cgImageGetHeight func(uintptr) uintptr
|
||||
cgImageGetBytesPerRow func(uintptr) uintptr
|
||||
cgImageGetBitsPerPixel func(uintptr) uintptr
|
||||
cgImageGetDataProvider func(uintptr) uintptr
|
||||
cgDataProviderCopyData func(uintptr) uintptr
|
||||
cgImageRelease func(uintptr)
|
||||
cfDataGetLength func(uintptr) int64
|
||||
cfDataGetBytePtr func(uintptr) uintptr
|
||||
cfRelease func(uintptr)
|
||||
cgRequestScreenCaptureAccess func() bool
|
||||
cgEventCreate func(uintptr) uintptr
|
||||
cgEventGetLocation func(uintptr) cgPoint
|
||||
darwinCaptureReady bool
|
||||
)
|
||||
|
||||
// cgPoint mirrors CoreGraphics CGPoint: two doubles, 16 bytes, returned
|
||||
// in registers on Darwin amd64/arm64. Used to receive cursor coordinates
|
||||
// from CGEventGetLocation via purego.
|
||||
type cgPoint struct {
|
||||
X, Y float64
|
||||
}
|
||||
|
||||
func initDarwinCapture() {
|
||||
darwinCaptureOnce.Do(func() {
|
||||
cg, err := purego.Dlopen("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
log.Debugf("load CoreGraphics: %v", err)
|
||||
return
|
||||
}
|
||||
cf, err := purego.Dlopen("/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
log.Debugf("load CoreFoundation: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
purego.RegisterLibFunc(&cgMainDisplayID, cg, "CGMainDisplayID")
|
||||
purego.RegisterLibFunc(&cgDisplayPixelsWide, cg, "CGDisplayPixelsWide")
|
||||
purego.RegisterLibFunc(&cgDisplayPixelsHigh, cg, "CGDisplayPixelsHigh")
|
||||
purego.RegisterLibFunc(&cgDisplayCreateImage, cg, "CGDisplayCreateImage")
|
||||
purego.RegisterLibFunc(&cgImageGetWidth, cg, "CGImageGetWidth")
|
||||
purego.RegisterLibFunc(&cgImageGetHeight, cg, "CGImageGetHeight")
|
||||
purego.RegisterLibFunc(&cgImageGetBytesPerRow, cg, "CGImageGetBytesPerRow")
|
||||
purego.RegisterLibFunc(&cgImageGetBitsPerPixel, cg, "CGImageGetBitsPerPixel")
|
||||
purego.RegisterLibFunc(&cgImageGetDataProvider, cg, "CGImageGetDataProvider")
|
||||
purego.RegisterLibFunc(&cgDataProviderCopyData, cg, "CGDataProviderCopyData")
|
||||
purego.RegisterLibFunc(&cgImageRelease, cg, "CGImageRelease")
|
||||
purego.RegisterLibFunc(&cfDataGetLength, cf, "CFDataGetLength")
|
||||
purego.RegisterLibFunc(&cfDataGetBytePtr, cf, "CFDataGetBytePtr")
|
||||
purego.RegisterLibFunc(&cfRelease, cf, "CFRelease")
|
||||
|
||||
// CGRequestScreenCaptureAccess (macOS 11+) prompts on first call and
|
||||
// is a cheap no-op once granted. The Preflight companion is unreliable
|
||||
// on Sequoia (returns false even when access is granted), so we drive
|
||||
// the permission flow from actual capture failures instead.
|
||||
if sym, err := purego.Dlsym(cg, "CGRequestScreenCaptureAccess"); err == nil {
|
||||
purego.RegisterFunc(&cgRequestScreenCaptureAccess, sym)
|
||||
}
|
||||
// CGEventCreate / CGEventGetLocation feed the cursor position used
|
||||
// by remote-cursor compositing. Optional; absence reports as a
|
||||
// position-source error and disables that feature on this host.
|
||||
if sym, err := purego.Dlsym(cg, "CGEventCreate"); err == nil {
|
||||
purego.RegisterFunc(&cgEventCreate, sym)
|
||||
}
|
||||
if sym, err := purego.Dlsym(cg, "CGEventGetLocation"); err == nil {
|
||||
purego.RegisterFunc(&cgEventGetLocation, sym)
|
||||
}
|
||||
|
||||
darwinCaptureReady = true
|
||||
})
|
||||
}
|
||||
|
||||
// CGCapturer captures the macOS main display using Core Graphics.
|
||||
type CGCapturer struct {
|
||||
displayID uint32
|
||||
w, h int
|
||||
// downscale is 1 for pixel-perfect, 2 for Retina 2:1 box-filter downscale.
|
||||
downscale int
|
||||
hashSeed maphash.Seed
|
||||
lastHash uint64
|
||||
hasHash bool
|
||||
// cursor lazily binds the private CGSCreateCurrentCursorImage symbol
|
||||
// so we can emit the Cursor pseudo-encoding without a per-frame cost
|
||||
// on builds that never query it.
|
||||
cursorOnce sync.Once
|
||||
cursor *cgCursor
|
||||
}
|
||||
|
||||
// PrimeScreenCapturePermission triggers the macOS Screen Recording
|
||||
// permission prompt without creating a full capturer. The platform wiring
|
||||
// calls this at VNC-server enable time so the user sees the prompt the
|
||||
// moment they turn the feature on. CGRequestScreenCaptureAccess is a
|
||||
// no-op when the grant already exists, so calling it on every enable is
|
||||
// cheap and safe.
|
||||
func PrimeScreenCapturePermission() {
|
||||
initDarwinCapture()
|
||||
if !darwinCaptureReady {
|
||||
return
|
||||
}
|
||||
if cgRequestScreenCaptureAccess != nil {
|
||||
cgRequestScreenCaptureAccess()
|
||||
}
|
||||
}
|
||||
|
||||
// notifyScreenRecordingMissing nudges the user once per agent process to
|
||||
// approve Screen Recording. The capturer init retries on backoff when the
|
||||
// grant is missing; without the sync.Once we would reopen System Settings
|
||||
// every tick and flood the daemon log with the same warning.
|
||||
var screenRecordingNotifyOnce sync.Once
|
||||
|
||||
func notifyScreenRecordingMissing() {
|
||||
screenRecordingNotifyOnce.Do(func() {
|
||||
if cgRequestScreenCaptureAccess != nil {
|
||||
cgRequestScreenCaptureAccess()
|
||||
}
|
||||
openPrivacyPane("Privacy_ScreenCapture")
|
||||
log.Warn("Screen Recording permission not granted. " +
|
||||
"Opened System Settings > Privacy & Security > Screen Recording; enable netbird and restart.")
|
||||
})
|
||||
}
|
||||
|
||||
// NewCGCapturer creates a screen capturer for the main display.
|
||||
func NewCGCapturer() (*CGCapturer, error) {
|
||||
initDarwinCapture()
|
||||
if !darwinCaptureReady {
|
||||
return nil, fmt.Errorf("CoreGraphics not available")
|
||||
}
|
||||
|
||||
displayID := cgMainDisplayID()
|
||||
c := &CGCapturer{displayID: displayID, downscale: 1, hashSeed: maphash.MakeSeed()}
|
||||
|
||||
img, err := c.Capture()
|
||||
if err != nil {
|
||||
notifyScreenRecordingMissing()
|
||||
return nil, fmt.Errorf("probe capture: %w", err)
|
||||
}
|
||||
nativeW := img.Rect.Dx()
|
||||
nativeH := img.Rect.Dy()
|
||||
c.hasHash = false
|
||||
if nativeW == 0 || nativeH == 0 {
|
||||
return nil, errors.New("display dimensions are zero")
|
||||
}
|
||||
|
||||
logicalW := int(cgDisplayPixelsWide(displayID))
|
||||
logicalH := int(cgDisplayPixelsHigh(displayID))
|
||||
|
||||
// Enable 2:1 downscale on Retina unless explicitly disabled. Cuts pixel
|
||||
// count 4x, shrinking convert, diff, and wire data proportionally.
|
||||
if !retinaDownscaleDisabled() && nativeW >= 2*logicalW && nativeH >= 2*logicalH && nativeW%2 == 0 && nativeH%2 == 0 {
|
||||
c.downscale = 2
|
||||
}
|
||||
c.w = nativeW / c.downscale
|
||||
c.h = nativeH / c.downscale
|
||||
|
||||
log.Infof("macOS capturer ready: %dx%d (native %dx%d, logical %dx%d, downscale=%d, display=%d)",
|
||||
c.w, c.h, nativeW, nativeH, logicalW, logicalH, c.downscale, displayID)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func retinaDownscaleDisabled() bool {
|
||||
v := os.Getenv(EnvVNCDisableDownscale)
|
||||
if v == "" {
|
||||
return false
|
||||
}
|
||||
disabled, err := strconv.ParseBool(v)
|
||||
if err != nil {
|
||||
log.Warnf("parse %s: %v", EnvVNCDisableDownscale, err)
|
||||
return false
|
||||
}
|
||||
return disabled
|
||||
}
|
||||
|
||||
// Width returns the screen width.
|
||||
func (c *CGCapturer) Width() int { return c.w }
|
||||
|
||||
// Height returns the screen height.
|
||||
func (c *CGCapturer) Height() int { return c.h }
|
||||
|
||||
// Capture returns the current screen as an RGBA image.
|
||||
// CaptureInto writes a fresh frame directly into dst, skipping the
|
||||
// per-frame image.RGBA allocation that Capture() does. Returns
|
||||
// errFrameUnchanged when the screen hash matches the prior call.
|
||||
func (c *CGCapturer) CaptureInto(dst *image.RGBA) error {
|
||||
cgImage := cgDisplayCreateImage(c.displayID)
|
||||
if cgImage == 0 {
|
||||
return fmt.Errorf("CGDisplayCreateImage returned nil (screen recording permission?)")
|
||||
}
|
||||
defer cgImageRelease(cgImage)
|
||||
w := int(cgImageGetWidth(cgImage))
|
||||
h := int(cgImageGetHeight(cgImage))
|
||||
bytesPerRow := int(cgImageGetBytesPerRow(cgImage))
|
||||
bpp := int(cgImageGetBitsPerPixel(cgImage))
|
||||
provider := cgImageGetDataProvider(cgImage)
|
||||
if provider == 0 {
|
||||
return fmt.Errorf("CGImageGetDataProvider returned nil")
|
||||
}
|
||||
cfData := cgDataProviderCopyData(provider)
|
||||
if cfData == 0 {
|
||||
return fmt.Errorf("CGDataProviderCopyData returned nil")
|
||||
}
|
||||
defer cfRelease(cfData)
|
||||
dataLen := int(cfDataGetLength(cfData))
|
||||
dataPtr := cfDataGetBytePtr(cfData)
|
||||
if dataPtr == 0 || dataLen == 0 {
|
||||
return fmt.Errorf("empty image data")
|
||||
}
|
||||
src := unsafe.Slice((*byte)(unsafe.Pointer(dataPtr)), dataLen)
|
||||
hash := maphash.Bytes(c.hashSeed, src)
|
||||
if c.hasHash && hash == c.lastHash {
|
||||
return errFrameUnchanged
|
||||
}
|
||||
c.lastHash = hash
|
||||
c.hasHash = true
|
||||
|
||||
ds := c.downscale
|
||||
if ds < 1 {
|
||||
ds = 1
|
||||
}
|
||||
outW := w / ds
|
||||
outH := h / ds
|
||||
if dst.Rect.Dx() != outW || dst.Rect.Dy() != outH {
|
||||
return fmt.Errorf("dst size mismatch: dst=%dx%d capturer=%dx%d",
|
||||
dst.Rect.Dx(), dst.Rect.Dy(), outW, outH)
|
||||
}
|
||||
bytesPerPixel := bpp / 8
|
||||
if bytesPerPixel == 4 && ds == 1 {
|
||||
convertBGRAToRGBA(dst.Pix, dst.Stride, src, bytesPerRow, w, h)
|
||||
return nil
|
||||
}
|
||||
if bytesPerPixel == 4 && ds == 2 {
|
||||
convertBGRAToRGBADownscale2(dst.Pix, dst.Stride, src, bytesPerRow, outW, outH)
|
||||
return nil
|
||||
}
|
||||
for row := 0; row < outH; row++ {
|
||||
srcOff := row * ds * bytesPerRow
|
||||
dstOff := row * dst.Stride
|
||||
for col := 0; col < outW; col++ {
|
||||
si := srcOff + col*ds*bytesPerPixel
|
||||
di := dstOff + col*4
|
||||
dst.Pix[di+0] = src[si+2]
|
||||
dst.Pix[di+1] = src[si+1]
|
||||
dst.Pix[di+2] = src[si+0]
|
||||
dst.Pix[di+3] = 0xff
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CGCapturer) Capture() (*image.RGBA, error) {
|
||||
cgImage := cgDisplayCreateImage(c.displayID)
|
||||
if cgImage == 0 {
|
||||
return nil, fmt.Errorf("CGDisplayCreateImage returned nil (screen recording permission?)")
|
||||
}
|
||||
defer cgImageRelease(cgImage)
|
||||
|
||||
w := int(cgImageGetWidth(cgImage))
|
||||
h := int(cgImageGetHeight(cgImage))
|
||||
bytesPerRow := int(cgImageGetBytesPerRow(cgImage))
|
||||
bpp := int(cgImageGetBitsPerPixel(cgImage))
|
||||
|
||||
provider := cgImageGetDataProvider(cgImage)
|
||||
if provider == 0 {
|
||||
return nil, fmt.Errorf("CGImageGetDataProvider returned nil")
|
||||
}
|
||||
|
||||
cfData := cgDataProviderCopyData(provider)
|
||||
if cfData == 0 {
|
||||
return nil, fmt.Errorf("CGDataProviderCopyData returned nil")
|
||||
}
|
||||
defer cfRelease(cfData)
|
||||
|
||||
dataLen := int(cfDataGetLength(cfData))
|
||||
dataPtr := cfDataGetBytePtr(cfData)
|
||||
if dataPtr == 0 || dataLen == 0 {
|
||||
return nil, fmt.Errorf("empty image data")
|
||||
}
|
||||
|
||||
src := unsafe.Slice((*byte)(unsafe.Pointer(dataPtr)), dataLen)
|
||||
|
||||
hash := maphash.Bytes(c.hashSeed, src)
|
||||
if c.hasHash && hash == c.lastHash {
|
||||
return nil, errFrameUnchanged
|
||||
}
|
||||
c.lastHash = hash
|
||||
c.hasHash = true
|
||||
|
||||
ds := c.downscale
|
||||
if ds < 1 {
|
||||
ds = 1
|
||||
}
|
||||
outW := w / ds
|
||||
outH := h / ds
|
||||
img := image.NewRGBA(image.Rect(0, 0, outW, outH))
|
||||
|
||||
bytesPerPixel := bpp / 8
|
||||
switch {
|
||||
case bytesPerPixel == 4 && ds == 1:
|
||||
convertBGRAToRGBA(img.Pix, img.Stride, src, bytesPerRow, w, h)
|
||||
case bytesPerPixel == 4 && ds == 2:
|
||||
convertBGRAToRGBADownscale2(img.Pix, img.Stride, src, bytesPerRow, outW, outH)
|
||||
default:
|
||||
convertBGRAToRGBAGeneric(img.Pix, img.Stride, src, bytesPerRow, bgraDownscaleParams{outW: outW, outH: outH, bytesPerPixel: bytesPerPixel, ds: ds})
|
||||
}
|
||||
|
||||
return img, nil
|
||||
}
|
||||
|
||||
type bgraDownscaleParams struct {
|
||||
outW, outH, bytesPerPixel, ds int
|
||||
}
|
||||
|
||||
// convertBGRAToRGBAGeneric is the slow per-pixel fallback for non-4-bytes
|
||||
// or non-1/2 downscale formats. Always available regardless of the source
|
||||
// format quirks the fast paths optimize for.
|
||||
func convertBGRAToRGBAGeneric(dst []byte, dstStride int, src []byte, srcStride int, p bgraDownscaleParams) {
|
||||
for row := 0; row < p.outH; row++ {
|
||||
srcOff := row * p.ds * srcStride
|
||||
dstOff := row * dstStride
|
||||
for col := 0; col < p.outW; col++ {
|
||||
si := srcOff + col*p.ds*p.bytesPerPixel
|
||||
di := dstOff + col*4
|
||||
dst[di+0] = src[si+2]
|
||||
dst[di+1] = src[si+1]
|
||||
dst[di+2] = src[si+0]
|
||||
dst[di+3] = 0xff
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// convertBGRAToRGBADownscale2 averages every 2x2 BGRA block into one RGBA
|
||||
// output pixel, parallelised across GOMAXPROCS cores. outW and outH are the
|
||||
// destination dimensions (source is 2*outW by 2*outH).
|
||||
func convertBGRAToRGBADownscale2(dst []byte, dstStride int, src []byte, srcStride, outW, outH int) {
|
||||
workers := runtime.GOMAXPROCS(0)
|
||||
if workers > outH {
|
||||
workers = outH
|
||||
}
|
||||
if workers < 1 || outH < 32 {
|
||||
workers = 1
|
||||
}
|
||||
|
||||
convertRows := func(y0, y1 int) {
|
||||
for row := y0; row < y1; row++ {
|
||||
srcRow0 := 2 * row * srcStride
|
||||
srcRow1 := srcRow0 + srcStride
|
||||
dstOff := row * dstStride
|
||||
for col := 0; col < outW; col++ {
|
||||
s0 := srcRow0 + col*8
|
||||
s1 := srcRow1 + col*8
|
||||
b := (uint32(src[s0]) + uint32(src[s0+4]) + uint32(src[s1]) + uint32(src[s1+4])) >> 2
|
||||
g := (uint32(src[s0+1]) + uint32(src[s0+5]) + uint32(src[s1+1]) + uint32(src[s1+5])) >> 2
|
||||
r := (uint32(src[s0+2]) + uint32(src[s0+6]) + uint32(src[s1+2]) + uint32(src[s1+6])) >> 2
|
||||
di := dstOff + col*4
|
||||
dst[di+0] = byte(r)
|
||||
dst[di+1] = byte(g)
|
||||
dst[di+2] = byte(b)
|
||||
dst[di+3] = 0xff
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if workers == 1 {
|
||||
convertRows(0, outH)
|
||||
return
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
chunk := (outH + workers - 1) / workers
|
||||
for i := 0; i < workers; i++ {
|
||||
y0 := i * chunk
|
||||
y1 := y0 + chunk
|
||||
if y1 > outH {
|
||||
y1 = outH
|
||||
}
|
||||
if y0 >= y1 {
|
||||
break
|
||||
}
|
||||
wg.Add(1)
|
||||
go func(y0, y1 int) {
|
||||
defer wg.Done()
|
||||
convertRows(y0, y1)
|
||||
}(y0, y1)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// convertBGRAToRGBA swaps R/B channels using uint32 word operations, and
|
||||
// parallelises across GOMAXPROCS cores for large images.
|
||||
func convertBGRAToRGBA(dst []byte, dstStride int, src []byte, srcStride, w, h int) {
|
||||
workers := runtime.GOMAXPROCS(0)
|
||||
if workers > h {
|
||||
workers = h
|
||||
}
|
||||
if workers < 1 || h < 64 {
|
||||
workers = 1
|
||||
}
|
||||
|
||||
convertRows := func(y0, y1 int) {
|
||||
rowBytes := w * 4
|
||||
for row := y0; row < y1; row++ {
|
||||
dstRow := dst[row*dstStride : row*dstStride+rowBytes]
|
||||
srcRow := src[row*srcStride : row*srcStride+rowBytes]
|
||||
dstU := unsafe.Slice((*uint32)(unsafe.Pointer(&dstRow[0])), w)
|
||||
srcU := unsafe.Slice((*uint32)(unsafe.Pointer(&srcRow[0])), w)
|
||||
for i, p := range srcU {
|
||||
dstU[i] = (p & 0xff00ff00) | ((p & 0x000000ff) << 16) | ((p & 0x00ff0000) >> 16) | 0xff000000
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if workers == 1 {
|
||||
convertRows(0, h)
|
||||
return
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
chunk := (h + workers - 1) / workers
|
||||
for i := 0; i < workers; i++ {
|
||||
y0 := i * chunk
|
||||
y1 := y0 + chunk
|
||||
if y1 > h {
|
||||
y1 = h
|
||||
}
|
||||
if y0 >= y1 {
|
||||
break
|
||||
}
|
||||
wg.Add(1)
|
||||
go func(y0, y1 int) {
|
||||
defer wg.Done()
|
||||
convertRows(y0, y1)
|
||||
}(y0, y1)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// MacPoller wraps CGCapturer with a staleness-cached on-demand Capture:
|
||||
// sessions drive captures themselves from their encoder goroutine, so we
|
||||
// don't need a background ticker. The last result is cached for a short
|
||||
// window so concurrent sessions coalesce into one capture.
|
||||
//
|
||||
// The capturer is allocated lazily on first use and released when all
|
||||
// clients disconnect. Init is retried with backoff because the user may
|
||||
// grant Screen Recording permission while the server is already running.
|
||||
type MacPoller struct {
|
||||
mu sync.Mutex
|
||||
|
||||
capturer *CGCapturer
|
||||
w, h int
|
||||
|
||||
lastFrame *image.RGBA
|
||||
lastAt time.Time
|
||||
|
||||
clients atomic.Int32
|
||||
initFails int
|
||||
initBackoffUntil time.Time
|
||||
closed bool
|
||||
}
|
||||
|
||||
// macInitRetryBackoffFor returns the delay we wait between init attempts
|
||||
// after consecutive failures. Screen Recording permission is a one-shot
|
||||
// user grant, so after several failures we back off aggressively.
|
||||
func macInitRetryBackoffFor(fails int) time.Duration {
|
||||
switch {
|
||||
case fails > 15:
|
||||
return 30 * time.Second
|
||||
case fails > 5:
|
||||
return 10 * time.Second
|
||||
default:
|
||||
return 2 * time.Second
|
||||
}
|
||||
}
|
||||
|
||||
// NewMacPoller creates a lazy on-demand capturer for the macOS display.
|
||||
func NewMacPoller() *MacPoller {
|
||||
return &MacPoller{}
|
||||
}
|
||||
|
||||
// Wake is a no-op retained for API compatibility. With on-demand capture
|
||||
// there is no background retry loop to kick: init happens on the next
|
||||
// Capture/ClientConnect call.
|
||||
func (p *MacPoller) Wake() {
|
||||
// intentional no-op
|
||||
}
|
||||
|
||||
// ClientConnect increments the active client count and eagerly initialises
|
||||
// the capturer so the first FBUpdateRequest doesn't pay the init cost.
|
||||
func (p *MacPoller) ClientConnect() {
|
||||
if p.clients.Add(1) == 1 {
|
||||
p.mu.Lock()
|
||||
_ = p.ensureCapturerLocked()
|
||||
p.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// ClientDisconnect decrements the active client count. On the last
|
||||
// disconnect the capturer is released.
|
||||
func (p *MacPoller) ClientDisconnect() {
|
||||
if p.clients.Add(-1) == 0 {
|
||||
p.mu.Lock()
|
||||
p.capturer = nil
|
||||
p.lastFrame = nil
|
||||
p.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Close releases all resources.
|
||||
func (p *MacPoller) Close() {
|
||||
p.mu.Lock()
|
||||
p.closed = true
|
||||
p.capturer = nil
|
||||
p.lastFrame = nil
|
||||
p.mu.Unlock()
|
||||
}
|
||||
|
||||
// Width returns the screen width. Triggers lazy init if needed.
|
||||
func (p *MacPoller) Width() int {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
_ = p.ensureCapturerLocked()
|
||||
return p.w
|
||||
}
|
||||
|
||||
// Height returns the screen height. Triggers lazy init if needed.
|
||||
func (p *MacPoller) Height() int {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
_ = p.ensureCapturerLocked()
|
||||
return p.h
|
||||
}
|
||||
|
||||
// CaptureInto fills dst directly via the underlying capturer, bypassing
|
||||
// the freshness cache.
|
||||
func (p *MacPoller) CaptureInto(dst *image.RGBA) error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if err := p.ensureCapturerLocked(); err != nil {
|
||||
return err
|
||||
}
|
||||
err := p.capturer.CaptureInto(dst)
|
||||
if errors.Is(err, errFrameUnchanged) {
|
||||
// Caller (session) treats this as "no change"; the dst buffer
|
||||
// keeps its prior contents from the previous capture cycle so
|
||||
// the diff stays meaningful.
|
||||
return err
|
||||
}
|
||||
if err != nil {
|
||||
p.capturer = nil
|
||||
return fmt.Errorf("macos capture: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Capture returns a fresh frame, serving from the short-lived cache if a
|
||||
// previous caller captured within freshWindow. Handles the
|
||||
// errFrameUnchanged return from CGCapturer by reusing the cached frame.
|
||||
func (p *MacPoller) Capture() (*image.RGBA, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.lastFrame != nil && time.Since(p.lastAt) < freshWindow {
|
||||
return p.lastFrame, nil
|
||||
}
|
||||
if err := p.ensureCapturerLocked(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
img, err := p.capturer.Capture()
|
||||
if errors.Is(err, errFrameUnchanged) {
|
||||
if p.lastFrame != nil {
|
||||
p.lastAt = time.Now()
|
||||
return p.lastFrame, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if err != nil {
|
||||
// Drop the capturer so the next call retries init; the display stream
|
||||
// can die if the session changes or permissions are revoked.
|
||||
p.capturer = nil
|
||||
return nil, fmt.Errorf("macos capture: %w", err)
|
||||
}
|
||||
p.lastFrame = img
|
||||
p.lastAt = time.Now()
|
||||
return img, nil
|
||||
}
|
||||
|
||||
// ensureCapturerLocked initialises the underlying CGCapturer if needed.
|
||||
// Caller must hold p.mu.
|
||||
func (p *MacPoller) ensureCapturerLocked() error {
|
||||
if p.closed {
|
||||
return fmt.Errorf("poller closed")
|
||||
}
|
||||
if p.capturer != nil {
|
||||
return nil
|
||||
}
|
||||
if time.Now().Before(p.initBackoffUntil) {
|
||||
return fmt.Errorf("macOS capturer unavailable (retry scheduled)")
|
||||
}
|
||||
c, err := NewCGCapturer()
|
||||
if err != nil {
|
||||
p.initFails++
|
||||
p.initBackoffUntil = time.Now().Add(macInitRetryBackoffFor(p.initFails))
|
||||
if p.initFails == 1 || p.initFails%10 == 0 {
|
||||
log.Warnf("macOS capturer: %v (attempt %d)", err, p.initFails)
|
||||
} else {
|
||||
log.Debugf("macOS capturer: %v (attempt %d)", err, p.initFails)
|
||||
}
|
||||
return err
|
||||
}
|
||||
p.initFails = 0
|
||||
p.capturer = c
|
||||
p.w, p.h = c.Width(), c.Height()
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ ScreenCapturer = (*MacPoller)(nil)
|
||||
99
client/vnc/server/capture_dxgi_windows.go
Normal file
99
client/vnc/server/capture_dxgi_windows.go
Normal file
@@ -0,0 +1,99 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"image"
|
||||
|
||||
"github.com/kirides/go-d3d/d3d11"
|
||||
"github.com/kirides/go-d3d/outputduplication"
|
||||
)
|
||||
|
||||
// dxgiCapturer captures the desktop using DXGI Desktop Duplication.
|
||||
// Provides GPU-accelerated capture with native dirty rect tracking.
|
||||
// Only works from the interactive user session, not Session 0.
|
||||
//
|
||||
// Uses a double-buffer: DXGI writes into img, then we copy to the current
|
||||
// output buffer and hand it out. Alternating between two output buffers
|
||||
// avoids allocating a new image.RGBA per frame (~8MB at 1080p, 30fps).
|
||||
type dxgiCapturer struct {
|
||||
dup *outputduplication.OutputDuplicator
|
||||
device *d3d11.ID3D11Device
|
||||
ctx *d3d11.ID3D11DeviceContext
|
||||
img *image.RGBA
|
||||
out [2]*image.RGBA
|
||||
outIdx int
|
||||
width int
|
||||
height int
|
||||
}
|
||||
|
||||
func newDXGICapturer() (*dxgiCapturer, error) {
|
||||
device, deviceCtx, err := d3d11.NewD3D11Device()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create D3D11 device: %w", err)
|
||||
}
|
||||
|
||||
dup, err := outputduplication.NewIDXGIOutputDuplication(device, deviceCtx, 0)
|
||||
if err != nil {
|
||||
device.Release()
|
||||
deviceCtx.Release()
|
||||
return nil, fmt.Errorf("create output duplication: %w", err)
|
||||
}
|
||||
|
||||
w, h := screenSize()
|
||||
if w == 0 || h == 0 {
|
||||
dup.Release()
|
||||
device.Release()
|
||||
deviceCtx.Release()
|
||||
return nil, fmt.Errorf("screen dimensions are zero")
|
||||
}
|
||||
|
||||
rect := image.Rect(0, 0, w, h)
|
||||
c := &dxgiCapturer{
|
||||
dup: dup,
|
||||
device: device,
|
||||
ctx: deviceCtx,
|
||||
img: image.NewRGBA(rect),
|
||||
out: [2]*image.RGBA{image.NewRGBA(rect), image.NewRGBA(rect)},
|
||||
width: w,
|
||||
height: h,
|
||||
}
|
||||
|
||||
// Grab the initial frame with a longer timeout to ensure we have
|
||||
// a valid image before returning.
|
||||
_ = dup.GetImage(c.img, 2000)
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *dxgiCapturer) capture() (*image.RGBA, error) {
|
||||
err := c.dup.GetImage(c.img, 100)
|
||||
if err != nil && !errors.Is(err, outputduplication.ErrNoImageYet) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Copy into the next output buffer. The DesktopCapturer hands out the
|
||||
// returned pointer to VNC sessions that read pixels concurrently, so we
|
||||
// alternate between two pre-allocated buffers instead of allocating per frame.
|
||||
out := c.out[c.outIdx]
|
||||
c.outIdx ^= 1
|
||||
copy(out.Pix, c.img.Pix)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *dxgiCapturer) close() {
|
||||
if c.dup != nil {
|
||||
c.dup.Release()
|
||||
c.dup = nil
|
||||
}
|
||||
if c.ctx != nil {
|
||||
c.ctx.Release()
|
||||
c.ctx = nil
|
||||
}
|
||||
if c.device != nil {
|
||||
c.device.Release()
|
||||
c.device = nil
|
||||
}
|
||||
}
|
||||
148
client/vnc/server/capture_fb_freebsd.go
Normal file
148
client/vnc/server/capture_fb_freebsd.go
Normal file
@@ -0,0 +1,148 @@
|
||||
//go:build freebsd
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// FreeBSD vt(4) framebuffer ioctl numbers from sys/fbio.h.
|
||||
//
|
||||
// #define FBIOGTYPE _IOR('F', 0, struct fbtype)
|
||||
//
|
||||
// _IOR(g, n, t) on FreeBSD: dir=2 (read) <<30 | (sizeof(t) & 0x1fff)<<16
|
||||
// | (g<<8) | n. sizeof(struct fbtype)=24 → 0x40184600.
|
||||
const fbioGType = 0x40184600
|
||||
|
||||
func defaultFBPath() string { return "/dev/ttyv0" }
|
||||
|
||||
// fbType mirrors FreeBSD's struct fbtype.
|
||||
type fbType struct {
|
||||
FbType int32
|
||||
FbHeight int32
|
||||
FbWidth int32
|
||||
FbDepth int32
|
||||
FbCMSize int32
|
||||
FbSize int32
|
||||
}
|
||||
|
||||
// FBCapturer reads pixels from FreeBSD's vt(4) framebuffer device. The
|
||||
// vt(4) console exposes the active framebuffer via ttyv0 with FBIOGTYPE
|
||||
// for geometry and mmap for backing memory. Pixel layout is assumed to
|
||||
// be 32bpp BGRA (the common case for KMS-backed vt); fbtype doesn't
|
||||
// expose channel offsets, so we don't try to handle exotic layouts here.
|
||||
type FBCapturer struct {
|
||||
mu sync.Mutex
|
||||
path string
|
||||
fd int
|
||||
mmap []byte
|
||||
w, h int
|
||||
bpp int
|
||||
stride int
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
// NewFBCapturer opens the given vt(4) device and queries its geometry.
|
||||
func NewFBCapturer(path string) (*FBCapturer, error) {
|
||||
if path == "" {
|
||||
path = defaultFBPath()
|
||||
}
|
||||
fd, err := unix.Open(path, unix.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open %s: %w", path, err)
|
||||
}
|
||||
|
||||
var fbt fbType
|
||||
if _, _, e := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), fbioGType, uintptr(unsafe.Pointer(&fbt))); e != 0 {
|
||||
unix.Close(fd)
|
||||
return nil, fmt.Errorf("FBIOGTYPE: %v", e)
|
||||
}
|
||||
if fbt.FbDepth != 16 && fbt.FbDepth != 24 && fbt.FbDepth != 32 {
|
||||
unix.Close(fd)
|
||||
return nil, fmt.Errorf("unsupported framebuffer depth: %d", fbt.FbDepth)
|
||||
}
|
||||
if fbt.FbWidth <= 0 || fbt.FbHeight <= 0 || fbt.FbSize <= 0 {
|
||||
unix.Close(fd)
|
||||
return nil, fmt.Errorf("invalid framebuffer geometry: %dx%d size=%d", fbt.FbWidth, fbt.FbHeight, fbt.FbSize)
|
||||
}
|
||||
|
||||
mm, err := unix.Mmap(fd, 0, int(fbt.FbSize), unix.PROT_READ, unix.MAP_SHARED)
|
||||
if err != nil {
|
||||
unix.Close(fd)
|
||||
return nil, fmt.Errorf("mmap %s: %w (vt may not support mmap on this driver, e.g. virtio_gpu)", path, err)
|
||||
}
|
||||
|
||||
bpp := int(fbt.FbDepth)
|
||||
stride := int(fbt.FbWidth) * (bpp / 8)
|
||||
c := &FBCapturer{
|
||||
path: path,
|
||||
fd: fd, // valid fd >= 0; we use -1 as the closed sentinel
|
||||
mmap: mm,
|
||||
w: int(fbt.FbWidth),
|
||||
h: int(fbt.FbHeight),
|
||||
bpp: bpp,
|
||||
stride: stride,
|
||||
}
|
||||
log.Infof("framebuffer capturer ready: %s %dx%d bpp=%d (freebsd vt)", path, c.w, c.h, c.bpp)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Width returns the framebuffer width.
|
||||
func (c *FBCapturer) Width() int { return c.w }
|
||||
|
||||
// Height returns the framebuffer height.
|
||||
func (c *FBCapturer) Height() int { return c.h }
|
||||
|
||||
// Capture allocates a fresh image and fills it with the current
|
||||
// framebuffer contents.
|
||||
func (c *FBCapturer) Capture() (*image.RGBA, error) {
|
||||
img := image.NewRGBA(image.Rect(0, 0, c.w, c.h))
|
||||
if err := c.CaptureInto(img); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return img, nil
|
||||
}
|
||||
|
||||
// CaptureInto reads the framebuffer directly into dst.Pix. Assumes BGRA
|
||||
// for 32bpp; the FreeBSD fbtype struct doesn't expose channel offsets.
|
||||
func (c *FBCapturer) CaptureInto(dst *image.RGBA) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if dst.Rect.Dx() != c.w || dst.Rect.Dy() != c.h {
|
||||
return fmt.Errorf("dst size mismatch: dst=%dx%d fb=%dx%d",
|
||||
dst.Rect.Dx(), dst.Rect.Dy(), c.w, c.h)
|
||||
}
|
||||
switch c.bpp {
|
||||
case 32:
|
||||
// vt(4) on KMS framebuffers is BGRA: byte 0=B, 1=G, 2=R.
|
||||
swizzleBGRAtoRGBA(dst.Pix, c.mmap[:c.h*c.stride])
|
||||
case 24:
|
||||
swizzleFB24(dst.Pix, dst.Stride, c.mmap, c.stride, c.w, c.h)
|
||||
case 16:
|
||||
swizzleFB16RGB565(dst.Pix, dst.Stride, c.mmap, c.stride, c.w, c.h)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close releases the framebuffer mmap and file descriptor. Serialized with
|
||||
// CaptureInto via c.mu so an in-flight capture can't read freed memory.
|
||||
func (c *FBCapturer) Close() {
|
||||
c.closeOnce.Do(func() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.mmap != nil {
|
||||
_ = unix.Munmap(c.mmap)
|
||||
c.mmap = nil
|
||||
}
|
||||
if c.fd >= 0 {
|
||||
_ = unix.Close(c.fd)
|
||||
c.fd = -1
|
||||
}
|
||||
})
|
||||
}
|
||||
229
client/vnc/server/capture_fb_linux.go
Normal file
229
client/vnc/server/capture_fb_linux.go
Normal file
@@ -0,0 +1,229 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"image"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// Linux framebuffer ioctls (linux/fb.h).
|
||||
const (
|
||||
fbioGetVScreenInfo = 0x4600
|
||||
fbioGetFScreenInfo = 0x4602
|
||||
)
|
||||
|
||||
func defaultFBPath() string { return "/dev/fb0" }
|
||||
|
||||
// fbVarScreenInfo mirrors the kernel's fb_var_screeninfo. Only the
|
||||
// fields we use are mapped; the rest are absorbed into _padN.
|
||||
type fbVarScreenInfo struct {
|
||||
Xres, Yres uint32
|
||||
XresVirtual, YresVirtual uint32
|
||||
XOffset, YOffset uint32
|
||||
BitsPerPixel uint32
|
||||
Grayscale uint32
|
||||
RedOffset, RedLen, RedMSBR uint32
|
||||
GreenOffset, GreenLen, GreenMSBR uint32
|
||||
BlueOffset, BlueLen, BlueMSBR uint32
|
||||
TranspOffset, TranspLen, TranspM uint32
|
||||
NonStd uint32
|
||||
Activate uint32
|
||||
Height, Width uint32
|
||||
AccelFlags uint32
|
||||
PixClock uint32
|
||||
LeftMargin, RightMargin uint32
|
||||
UpperMargin, LowerMargin uint32
|
||||
HsyncLen, VsyncLen uint32
|
||||
Sync uint32
|
||||
Vmode uint32
|
||||
Rotate uint32
|
||||
Colorspace uint32
|
||||
_pad [4]uint32
|
||||
}
|
||||
|
||||
// fbFixScreenInfo mirrors fb_fix_screeninfo. We only need LineLength.
|
||||
type fbFixScreenInfo struct {
|
||||
IDStr [16]byte
|
||||
SmemStart uint64
|
||||
SmemLen uint32
|
||||
Type uint32
|
||||
TypeAux uint32
|
||||
Visual uint32
|
||||
XPanStep uint16
|
||||
YPanStep uint16
|
||||
YWrapStep uint16
|
||||
_pad0 uint16
|
||||
LineLength uint32
|
||||
MmioStart uint64
|
||||
MmioLen uint32
|
||||
Accel uint32
|
||||
Capabilities uint16
|
||||
_reserved [2]uint16
|
||||
}
|
||||
|
||||
// FBCapturer reads pixels straight from the Linux framebuffer device.
|
||||
// Used as a fallback when X11 isn't available, e.g. on a headless box at
|
||||
// the kernel console or the display manager's pre-login screen on machines
|
||||
// without an Xorg server. The framebuffer must be mmap()-able under our
|
||||
// process privileges (typically the netbird service runs as root).
|
||||
type FBCapturer struct {
|
||||
mu sync.Mutex
|
||||
path string
|
||||
fd int
|
||||
mmap []byte
|
||||
w, h int
|
||||
bpp int
|
||||
stride int
|
||||
rOff uint32
|
||||
gOff uint32
|
||||
bOff uint32
|
||||
rLen uint32
|
||||
gLen uint32
|
||||
bLen uint32
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
// NewFBCapturer opens the given framebuffer device (/dev/fbN) and
|
||||
// queries its current geometry + pixel format.
|
||||
func NewFBCapturer(path string) (*FBCapturer, error) {
|
||||
if path == "" {
|
||||
path = "/dev/fb0"
|
||||
}
|
||||
fd, err := unix.Open(path, unix.O_RDONLY, 0)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open %s: %w", path, err)
|
||||
}
|
||||
|
||||
var vinfo fbVarScreenInfo
|
||||
if _, _, e := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), fbioGetVScreenInfo, uintptr(unsafe.Pointer(&vinfo))); e != 0 {
|
||||
unix.Close(fd)
|
||||
return nil, fmt.Errorf("FBIOGET_VSCREENINFO: %v", e)
|
||||
}
|
||||
var finfo fbFixScreenInfo
|
||||
if _, _, e := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), fbioGetFScreenInfo, uintptr(unsafe.Pointer(&finfo))); e != 0 {
|
||||
unix.Close(fd)
|
||||
return nil, fmt.Errorf("FBIOGET_FSCREENINFO: %v", e)
|
||||
}
|
||||
|
||||
bpp := int(vinfo.BitsPerPixel)
|
||||
if bpp != 16 && bpp != 24 && bpp != 32 {
|
||||
unix.Close(fd)
|
||||
return nil, fmt.Errorf("unsupported framebuffer bpp: %d", bpp)
|
||||
}
|
||||
|
||||
size := int(finfo.LineLength) * int(vinfo.Yres)
|
||||
if size <= 0 {
|
||||
unix.Close(fd)
|
||||
return nil, fmt.Errorf("invalid framebuffer dimensions: stride=%d h=%d", finfo.LineLength, vinfo.Yres)
|
||||
}
|
||||
|
||||
mm, err := unix.Mmap(fd, 0, size, unix.PROT_READ, unix.MAP_SHARED)
|
||||
if err != nil {
|
||||
unix.Close(fd)
|
||||
return nil, fmt.Errorf("mmap %s: %w", path, err)
|
||||
}
|
||||
|
||||
c := &FBCapturer{
|
||||
path: path,
|
||||
fd: fd,
|
||||
mmap: mm,
|
||||
w: int(vinfo.Xres),
|
||||
h: int(vinfo.Yres),
|
||||
bpp: bpp,
|
||||
stride: int(finfo.LineLength),
|
||||
rOff: vinfo.RedOffset,
|
||||
gOff: vinfo.GreenOffset,
|
||||
bOff: vinfo.BlueOffset,
|
||||
rLen: vinfo.RedLen,
|
||||
gLen: vinfo.GreenLen,
|
||||
bLen: vinfo.BlueLen,
|
||||
}
|
||||
log.Infof("framebuffer capturer ready: %s %dx%d bpp=%d r=%d/%d g=%d/%d b=%d/%d",
|
||||
path, c.w, c.h, c.bpp, c.rOff, c.rLen, c.gOff, c.gLen, c.bOff, c.bLen)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Width returns the framebuffer width in pixels.
|
||||
func (c *FBCapturer) Width() int { return c.w }
|
||||
|
||||
// Height returns the framebuffer height in pixels.
|
||||
func (c *FBCapturer) Height() int { return c.h }
|
||||
|
||||
// Capture allocates a fresh image and fills it with the current
|
||||
// framebuffer contents.
|
||||
func (c *FBCapturer) Capture() (*image.RGBA, error) {
|
||||
img := image.NewRGBA(image.Rect(0, 0, c.w, c.h))
|
||||
if err := c.CaptureInto(img); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return img, nil
|
||||
}
|
||||
|
||||
// CaptureInto reads the framebuffer directly into dst.Pix.
|
||||
func (c *FBCapturer) CaptureInto(dst *image.RGBA) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if dst.Rect.Dx() != c.w || dst.Rect.Dy() != c.h {
|
||||
return fmt.Errorf("dst size mismatch: dst=%dx%d fb=%dx%d",
|
||||
dst.Rect.Dx(), dst.Rect.Dy(), c.w, c.h)
|
||||
}
|
||||
|
||||
switch c.bpp {
|
||||
case 32:
|
||||
swizzleFB32(dst.Pix, dst.Stride, c.mmap, c.stride, c.w, c.h, channelShifts{R: c.rOff, G: c.gOff, B: c.bOff})
|
||||
case 24:
|
||||
swizzleFB24(dst.Pix, dst.Stride, c.mmap, c.stride, c.w, c.h)
|
||||
case 16:
|
||||
swizzleFB16RGB565(dst.Pix, dst.Stride, c.mmap, c.stride, c.w, c.h)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close releases the framebuffer mmap and file descriptor. Serialized with
|
||||
// CaptureInto via c.mu so an in-flight capture can't read freed memory.
|
||||
func (c *FBCapturer) Close() {
|
||||
c.closeOnce.Do(func() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.mmap != nil {
|
||||
_ = unix.Munmap(c.mmap)
|
||||
c.mmap = nil
|
||||
}
|
||||
if c.fd >= 0 {
|
||||
_ = unix.Close(c.fd)
|
||||
c.fd = -1
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// channelShifts groups the bit offsets for the R/G/B channels in a packed
|
||||
// uint32 framebuffer pixel. Bundling avoids drowning per-row callers in a
|
||||
// 9-parameter signature.
|
||||
type channelShifts struct {
|
||||
R, G, B uint32
|
||||
}
|
||||
|
||||
// swizzleFB32 handles 32-bit framebuffers with arbitrary R/G/B channel
|
||||
// offsets. Pulls one pixel per uint32, then masks each channel into the
|
||||
// destination RGBA byte order.
|
||||
func swizzleFB32(dst []byte, dstStride int, src []byte, srcStride, w, h int, shifts channelShifts) {
|
||||
for y := 0; y < h; y++ {
|
||||
srcRow := src[y*srcStride : y*srcStride+w*4]
|
||||
dstRow := dst[y*dstStride:]
|
||||
for x := 0; x < w; x++ {
|
||||
pix := binary.LittleEndian.Uint32(srcRow[x*4 : x*4+4])
|
||||
dstRow[x*4+0] = byte(pix >> shifts.R)
|
||||
dstRow[x*4+1] = byte(pix >> shifts.G)
|
||||
dstRow[x*4+2] = byte(pix >> shifts.B)
|
||||
dstRow[x*4+3] = 0xff
|
||||
}
|
||||
}
|
||||
}
|
||||
149
client/vnc/server/capture_fb_unix.go
Normal file
149
client/vnc/server/capture_fb_unix.go
Normal file
@@ -0,0 +1,149 @@
|
||||
//go:build unix && !darwin && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"image"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// FBPoller wraps FBCapturer with the same lifecycle (ClientConnect /
|
||||
// ClientDisconnect, lazy init) as X11Poller, so it slots into the same
|
||||
// session plumbing without code changes upstream. The concrete
|
||||
// FBCapturer is platform-specific (capture_fb_linux.go / _freebsd.go);
|
||||
// this file owns the cross-platform glue.
|
||||
type FBPoller struct {
|
||||
mu sync.Mutex
|
||||
path string
|
||||
capturer *FBCapturer
|
||||
w, h int
|
||||
clients int32
|
||||
}
|
||||
|
||||
// NewFBPoller returns a poller that opens path on first use. Empty path
|
||||
// defaults to /dev/fb0 on Linux and /dev/ttyv0 on FreeBSD.
|
||||
func NewFBPoller(path string) *FBPoller {
|
||||
if path == "" {
|
||||
path = defaultFBPath()
|
||||
}
|
||||
return &FBPoller{path: path}
|
||||
}
|
||||
|
||||
// ClientConnect eagerly initialises the capturer on first connect.
|
||||
func (p *FBPoller) ClientConnect() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.clients++
|
||||
if p.clients == 1 {
|
||||
_ = p.ensureCapturerLocked()
|
||||
}
|
||||
}
|
||||
|
||||
// ClientDisconnect closes the capturer when the last client leaves.
|
||||
func (p *FBPoller) ClientDisconnect() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.clients--
|
||||
if p.clients <= 0 && p.capturer != nil {
|
||||
p.capturer.Close()
|
||||
p.capturer = nil
|
||||
}
|
||||
}
|
||||
|
||||
// Width returns the framebuffer width, doing lazy init if needed.
|
||||
func (p *FBPoller) Width() int {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
_ = p.ensureCapturerLocked()
|
||||
return p.w
|
||||
}
|
||||
|
||||
// Height returns the framebuffer height, doing lazy init if needed.
|
||||
func (p *FBPoller) Height() int {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
_ = p.ensureCapturerLocked()
|
||||
return p.h
|
||||
}
|
||||
|
||||
// Capture takes a fresh frame.
|
||||
func (p *FBPoller) Capture() (*image.RGBA, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if err := p.ensureCapturerLocked(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return p.capturer.Capture()
|
||||
}
|
||||
|
||||
// CaptureInto fills dst directly.
|
||||
func (p *FBPoller) CaptureInto(dst *image.RGBA) error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if err := p.ensureCapturerLocked(); err != nil {
|
||||
return err
|
||||
}
|
||||
return p.capturer.CaptureInto(dst)
|
||||
}
|
||||
|
||||
// Close releases all framebuffer resources.
|
||||
func (p *FBPoller) Close() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if p.capturer != nil {
|
||||
p.capturer.Close()
|
||||
p.capturer = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (p *FBPoller) ensureCapturerLocked() error {
|
||||
if p.capturer != nil {
|
||||
return nil
|
||||
}
|
||||
c, err := NewFBCapturer(p.path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.capturer = c
|
||||
p.w, p.h = c.Width(), c.Height()
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ ScreenCapturer = (*FBPoller)(nil)
|
||||
var _ captureIntoer = (*FBPoller)(nil)
|
||||
|
||||
// swizzleFB24 handles 24-bit packed framebuffers (B,G,R triplets).
|
||||
// Shared between Linux and FreeBSD framebuffer paths.
|
||||
func swizzleFB24(dst []byte, dstStride int, src []byte, srcStride, w, h int) {
|
||||
for y := 0; y < h; y++ {
|
||||
srcRow := src[y*srcStride : y*srcStride+w*3]
|
||||
dstRow := dst[y*dstStride:]
|
||||
for x := 0; x < w; x++ {
|
||||
b := srcRow[x*3+0]
|
||||
g := srcRow[x*3+1]
|
||||
r := srcRow[x*3+2]
|
||||
dstRow[x*4+0] = r
|
||||
dstRow[x*4+1] = g
|
||||
dstRow[x*4+2] = b
|
||||
dstRow[x*4+3] = 0xff
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// swizzleFB16RGB565 handles 16bpp RGB 565 framebuffers.
|
||||
func swizzleFB16RGB565(dst []byte, dstStride int, src []byte, srcStride, w, h int) {
|
||||
for y := 0; y < h; y++ {
|
||||
srcRow := src[y*srcStride : y*srcStride+w*2]
|
||||
dstRow := dst[y*dstStride:]
|
||||
for x := 0; x < w; x++ {
|
||||
pix := uint16(srcRow[x*2]) | uint16(srcRow[x*2+1])<<8
|
||||
r := byte((pix >> 11) & 0x1f)
|
||||
g := byte((pix >> 5) & 0x3f)
|
||||
b := byte(pix & 0x1f)
|
||||
dstRow[x*4+0] = (r << 3) | (r >> 2)
|
||||
dstRow[x*4+1] = (g << 2) | (g >> 4)
|
||||
dstRow[x*4+2] = (b << 3) | (b >> 2)
|
||||
dstRow[x*4+3] = 0xff
|
||||
}
|
||||
}
|
||||
}
|
||||
556
client/vnc/server/capture_windows.go
Normal file
556
client/vnc/server/capture_windows.go
Normal file
@@ -0,0 +1,556 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var (
|
||||
gdi32 = windows.NewLazySystemDLL("gdi32.dll")
|
||||
user32 = windows.NewLazySystemDLL("user32.dll")
|
||||
|
||||
procGetDC = user32.NewProc("GetDC")
|
||||
procReleaseDC = user32.NewProc("ReleaseDC")
|
||||
procCreateCompatDC = gdi32.NewProc("CreateCompatibleDC")
|
||||
procCreateDIBSection = gdi32.NewProc("CreateDIBSection")
|
||||
procSelectObject = gdi32.NewProc("SelectObject")
|
||||
procDeleteObject = gdi32.NewProc("DeleteObject")
|
||||
procDeleteDC = gdi32.NewProc("DeleteDC")
|
||||
procBitBlt = gdi32.NewProc("BitBlt")
|
||||
procGetSystemMetrics = user32.NewProc("GetSystemMetrics")
|
||||
|
||||
// Desktop switching for service/Session 0 capture.
|
||||
procOpenInputDesktop = user32.NewProc("OpenInputDesktop")
|
||||
procSetThreadDesktop = user32.NewProc("SetThreadDesktop")
|
||||
procCloseDesktop = user32.NewProc("CloseDesktop")
|
||||
procOpenWindowStation = user32.NewProc("OpenWindowStationW")
|
||||
procSetProcessWindowStation = user32.NewProc("SetProcessWindowStation")
|
||||
procCloseWindowStation = user32.NewProc("CloseWindowStation")
|
||||
procGetUserObjectInformationW = user32.NewProc("GetUserObjectInformationW")
|
||||
)
|
||||
|
||||
const uoiName = 2
|
||||
|
||||
const (
|
||||
smCxScreen = 0
|
||||
smCyScreen = 1
|
||||
srccopy = 0x00CC0020
|
||||
captureBlt = 0x40000000
|
||||
dibRgbColors = 0
|
||||
)
|
||||
|
||||
type bitmapInfoHeader struct {
|
||||
Size uint32
|
||||
Width int32
|
||||
Height int32
|
||||
Planes uint16
|
||||
BitCount uint16
|
||||
Compression uint32
|
||||
SizeImage uint32
|
||||
XPelsPerMeter int32
|
||||
YPelsPerMeter int32
|
||||
ClrUsed uint32
|
||||
ClrImportant uint32
|
||||
}
|
||||
|
||||
type bitmapInfo struct {
|
||||
Header bitmapInfoHeader
|
||||
}
|
||||
|
||||
// setupInteractiveWindowStation associates the current process with WinSta0,
|
||||
// the interactive window station. This is required for a SYSTEM service in
|
||||
// Session 0 to call OpenInputDesktop for screen capture and input injection.
|
||||
func setupInteractiveWindowStation() error {
|
||||
name, err := windows.UTF16PtrFromString("WinSta0")
|
||||
if err != nil {
|
||||
return fmt.Errorf("UTF16 WinSta0: %w", err)
|
||||
}
|
||||
hWinSta, _, err := procOpenWindowStation.Call(
|
||||
uintptr(unsafe.Pointer(name)),
|
||||
0,
|
||||
uintptr(windows.MAXIMUM_ALLOWED),
|
||||
)
|
||||
if hWinSta == 0 {
|
||||
return fmt.Errorf("OpenWindowStation(WinSta0): %w", err)
|
||||
}
|
||||
r, _, err := procSetProcessWindowStation.Call(hWinSta)
|
||||
if r == 0 {
|
||||
_, _, _ = procCloseWindowStation.Call(hWinSta)
|
||||
return fmt.Errorf("SetProcessWindowStation: %w", err)
|
||||
}
|
||||
log.Info("process window station set to WinSta0 (interactive)")
|
||||
return nil
|
||||
}
|
||||
|
||||
func screenSize() (int, int) {
|
||||
w, _, _ := procGetSystemMetrics.Call(uintptr(smCxScreen))
|
||||
h, _, _ := procGetSystemMetrics.Call(uintptr(smCyScreen))
|
||||
return int(w), int(h)
|
||||
}
|
||||
|
||||
func getDesktopName(hDesk uintptr) string {
|
||||
var buf [256]uint16
|
||||
var needed uint32
|
||||
_, _, _ = procGetUserObjectInformationW.Call(hDesk, uoiName,
|
||||
uintptr(unsafe.Pointer(&buf[0])), 512,
|
||||
uintptr(unsafe.Pointer(&needed)))
|
||||
return windows.UTF16ToString(buf[:])
|
||||
}
|
||||
|
||||
// switchToInputDesktop opens the desktop currently receiving user input
|
||||
// and sets it as the calling OS thread's desktop. Must be called from a
|
||||
// goroutine locked to its OS thread via runtime.LockOSThread().
|
||||
func switchToInputDesktop() (bool, string) {
|
||||
hDesk, _, _ := procOpenInputDesktop.Call(0, 0, uintptr(windows.MAXIMUM_ALLOWED))
|
||||
if hDesk == 0 {
|
||||
return false, ""
|
||||
}
|
||||
name := getDesktopName(hDesk)
|
||||
ret, _, _ := procSetThreadDesktop.Call(hDesk)
|
||||
_, _, _ = procCloseDesktop.Call(hDesk)
|
||||
return ret != 0, name
|
||||
}
|
||||
|
||||
// gdiCapturer captures the desktop screen using GDI BitBlt.
|
||||
// GDI objects (DC, DIBSection) are allocated once and reused across frames.
|
||||
type gdiCapturer struct {
|
||||
mu sync.Mutex
|
||||
width int
|
||||
height int
|
||||
|
||||
// Pre-allocated GDI resources, reused across captures.
|
||||
memDC uintptr
|
||||
bmp uintptr
|
||||
bits uintptr
|
||||
}
|
||||
|
||||
func newGDICapturer() (*gdiCapturer, error) {
|
||||
w, h := screenSize()
|
||||
if w == 0 || h == 0 {
|
||||
return nil, fmt.Errorf("screen dimensions are zero")
|
||||
}
|
||||
c := &gdiCapturer{width: w, height: h}
|
||||
if err := c.allocGDI(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// allocGDI pre-allocates the compatible DC and DIB section for reuse.
|
||||
func (c *gdiCapturer) allocGDI() error {
|
||||
screenDC, _, _ := procGetDC.Call(0)
|
||||
if screenDC == 0 {
|
||||
return fmt.Errorf("GetDC returned 0")
|
||||
}
|
||||
defer func() { _, _, _ = procReleaseDC.Call(0, screenDC) }()
|
||||
|
||||
memDC, _, _ := procCreateCompatDC.Call(screenDC)
|
||||
if memDC == 0 {
|
||||
return fmt.Errorf("CreateCompatibleDC returned 0")
|
||||
}
|
||||
|
||||
bi := bitmapInfo{
|
||||
Header: bitmapInfoHeader{
|
||||
Size: uint32(unsafe.Sizeof(bitmapInfoHeader{})),
|
||||
Width: int32(c.width),
|
||||
Height: -int32(c.height), // negative = top-down DIB
|
||||
Planes: 1,
|
||||
BitCount: 32,
|
||||
},
|
||||
}
|
||||
|
||||
var bits uintptr
|
||||
bmp, _, _ := procCreateDIBSection.Call(
|
||||
screenDC,
|
||||
uintptr(unsafe.Pointer(&bi)),
|
||||
dibRgbColors,
|
||||
uintptr(unsafe.Pointer(&bits)),
|
||||
0, 0,
|
||||
)
|
||||
if bmp == 0 || bits == 0 {
|
||||
_, _, _ = procDeleteDC.Call(memDC)
|
||||
return fmt.Errorf("CreateDIBSection returned 0")
|
||||
}
|
||||
|
||||
_, _, _ = procSelectObject.Call(memDC, bmp)
|
||||
|
||||
c.memDC = memDC
|
||||
c.bmp = bmp
|
||||
c.bits = bits
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *gdiCapturer) close() { c.freeGDI() }
|
||||
|
||||
// freeGDI releases pre-allocated GDI resources.
|
||||
func (c *gdiCapturer) freeGDI() {
|
||||
if c.bmp != 0 {
|
||||
_, _, _ = procDeleteObject.Call(c.bmp)
|
||||
c.bmp = 0
|
||||
}
|
||||
if c.memDC != 0 {
|
||||
_, _, _ = procDeleteDC.Call(c.memDC)
|
||||
c.memDC = 0
|
||||
}
|
||||
c.bits = 0
|
||||
}
|
||||
|
||||
func (c *gdiCapturer) capture() (*image.RGBA, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.memDC == 0 {
|
||||
return nil, fmt.Errorf("GDI resources not allocated")
|
||||
}
|
||||
|
||||
screenDC, _, _ := procGetDC.Call(0)
|
||||
if screenDC == 0 {
|
||||
return nil, fmt.Errorf("GetDC returned 0")
|
||||
}
|
||||
defer func() { _, _, _ = procReleaseDC.Call(0, screenDC) }()
|
||||
|
||||
// SRCCOPY|CAPTUREBLT: CAPTUREBLT forces inclusion of layered/topmost
|
||||
// windows in the capture and is required for GDI BitBlt to return live
|
||||
// pixels when the session is rendered through RDP / DWM-composited
|
||||
// surfaces. Without it BitBlt reads the backing-store DIB which is
|
||||
// often empty (all-black) on RDP and headless sessions.
|
||||
ret, _, _ := procBitBlt.Call(c.memDC, 0, 0, uintptr(c.width), uintptr(c.height),
|
||||
screenDC, 0, 0, srccopy|captureBlt)
|
||||
if ret == 0 {
|
||||
return nil, fmt.Errorf("BitBlt returned 0")
|
||||
}
|
||||
|
||||
n := c.width * c.height * 4
|
||||
raw := unsafe.Slice((*byte)(unsafe.Pointer(c.bits)), n)
|
||||
|
||||
// GDI gives BGRA, the RFB encoder expects RGBA (img.Pix layout).
|
||||
// Swap R and B in bulk using uint32 operations (one load + mask + shift
|
||||
// per pixel instead of three separate byte assignments).
|
||||
img := image.NewRGBA(image.Rect(0, 0, c.width, c.height))
|
||||
swizzleBGRAtoRGBA(img.Pix, raw)
|
||||
return img, nil
|
||||
}
|
||||
|
||||
// DesktopCapturer captures the interactive desktop, handling desktop transitions
|
||||
// (login screen, UAC prompts). A dedicated OS-locked goroutine continuously
|
||||
// captures frames on demand via a dedicated OS-locked goroutine (required
|
||||
// because DXGI's D3D11 device context is not thread-safe). Sessions drive
|
||||
// timing by calling Capture(); a short staleness cache coalesces concurrent
|
||||
// requests. Capture pauses automatically when no clients are connected.
|
||||
type DesktopCapturer struct {
|
||||
mu sync.Mutex
|
||||
w, h int
|
||||
|
||||
// lastFrame/lastAt implement a small staleness cache so multiple
|
||||
// near-simultaneous Capture calls share one DXGI round-trip.
|
||||
lastFrame *image.RGBA
|
||||
lastAt time.Time
|
||||
|
||||
// clients tracks the number of active VNC sessions. When zero, the
|
||||
// worker goroutine releases the underlying capturer.
|
||||
clients atomic.Int32
|
||||
|
||||
// reqCh carries capture requests from sessions to the OS-locked worker.
|
||||
reqCh chan captureReq
|
||||
// wake is signaled when a client connects and the worker should resume.
|
||||
wake chan struct{}
|
||||
// done is closed when Close is called, terminating the worker.
|
||||
done chan struct{}
|
||||
|
||||
// cursorState holds the latest cursor sprite sampled by the worker.
|
||||
// The worker calls GetCursorInfo every capture and decodes a new
|
||||
// sprite only when the HCURSOR changes.
|
||||
cursorState cursorState
|
||||
}
|
||||
|
||||
// captureReq is a single capture request awaiting a reply. Reply channel is
|
||||
// buffered to size 1 so the worker never blocks on a sender that's gone.
|
||||
type captureReq struct {
|
||||
reply chan captureReply
|
||||
}
|
||||
|
||||
type captureReply struct {
|
||||
img *image.RGBA
|
||||
err error
|
||||
}
|
||||
|
||||
// NewDesktopCapturer creates an on-demand capturer for the active desktop.
|
||||
func NewDesktopCapturer() *DesktopCapturer {
|
||||
c := &DesktopCapturer{
|
||||
wake: make(chan struct{}, 1),
|
||||
done: make(chan struct{}),
|
||||
reqCh: make(chan captureReq),
|
||||
}
|
||||
go c.worker()
|
||||
return c
|
||||
}
|
||||
|
||||
// ClientConnect increments the active client count, resuming capture if needed.
|
||||
func (c *DesktopCapturer) ClientConnect() {
|
||||
c.clients.Add(1)
|
||||
select {
|
||||
case c.wake <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// ClientDisconnect decrements the active client count.
|
||||
func (c *DesktopCapturer) ClientDisconnect() {
|
||||
c.clients.Add(-1)
|
||||
}
|
||||
|
||||
// Close stops the capture loop and releases resources.
|
||||
func (c *DesktopCapturer) Close() {
|
||||
select {
|
||||
case <-c.done:
|
||||
default:
|
||||
close(c.done)
|
||||
}
|
||||
}
|
||||
|
||||
// Width returns the current screen width, triggering a capture if the
|
||||
// worker hasn't initialised yet. validateCapturer depends on Width/Height
|
||||
// becoming non-zero promptly after ClientConnect so it doesn't reject
|
||||
// brand-new sessions.
|
||||
func (c *DesktopCapturer) Width() int {
|
||||
c.mu.Lock()
|
||||
w := c.w
|
||||
c.mu.Unlock()
|
||||
if w == 0 && c.clients.Load() > 0 {
|
||||
_, _ = c.Capture()
|
||||
c.mu.Lock()
|
||||
w = c.w
|
||||
c.mu.Unlock()
|
||||
}
|
||||
return w
|
||||
}
|
||||
|
||||
// Height returns the current screen height, triggering a capture if the
|
||||
// worker hasn't initialised yet (see Width). Returns 0 while no client is
|
||||
// connected so callers don't deadlock against a parked worker.
|
||||
func (c *DesktopCapturer) Height() int {
|
||||
c.mu.Lock()
|
||||
h := c.h
|
||||
c.mu.Unlock()
|
||||
if h == 0 && c.clients.Load() > 0 {
|
||||
_, _ = c.Capture()
|
||||
c.mu.Lock()
|
||||
h = c.h
|
||||
c.mu.Unlock()
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
// Capture returns a freshly captured frame, serving from a short staleness
|
||||
// cache when multiple sessions ask within freshWindow of each other. All
|
||||
// real DXGI/GDI work happens on the OS-locked worker goroutine.
|
||||
func (c *DesktopCapturer) Capture() (*image.RGBA, error) {
|
||||
c.mu.Lock()
|
||||
if c.lastFrame != nil && time.Since(c.lastAt) < freshWindow {
|
||||
img := c.lastFrame
|
||||
c.mu.Unlock()
|
||||
return img, nil
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
reply := make(chan captureReply, 1)
|
||||
select {
|
||||
case c.reqCh <- captureReq{reply: reply}:
|
||||
case <-c.done:
|
||||
return nil, fmt.Errorf("capturer closed")
|
||||
}
|
||||
select {
|
||||
case r := <-reply:
|
||||
if r.err != nil {
|
||||
return nil, r.err
|
||||
}
|
||||
c.mu.Lock()
|
||||
c.lastFrame = r.img
|
||||
c.lastAt = time.Now()
|
||||
c.mu.Unlock()
|
||||
return r.img, nil
|
||||
case <-c.done:
|
||||
return nil, fmt.Errorf("capturer closed")
|
||||
}
|
||||
}
|
||||
|
||||
// waitForClient blocks until a client connects or the capturer is closed.
|
||||
func (c *DesktopCapturer) waitForClient() bool {
|
||||
if c.clients.Load() > 0 {
|
||||
return true
|
||||
}
|
||||
select {
|
||||
case <-c.wake:
|
||||
return true
|
||||
case <-c.done:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// worker owns DXGI/GDI state on its OS-locked thread and services capture
|
||||
// requests from sessions. No background ticker: a capture happens only when
|
||||
// a session asks for one (throttled by Capture()'s staleness cache).
|
||||
func (c *DesktopCapturer) worker() {
|
||||
runtime.LockOSThread()
|
||||
|
||||
// When running as a Windows service (Session 0), we need to attach to the
|
||||
// interactive window station before OpenInputDesktop will succeed.
|
||||
if err := setupInteractiveWindowStation(); err != nil {
|
||||
log.Warnf("attach to interactive window station: %v", err)
|
||||
}
|
||||
|
||||
w := &captureWorker{c: c}
|
||||
defer w.closeCapturer()
|
||||
|
||||
for {
|
||||
if !c.waitForClient() {
|
||||
return
|
||||
}
|
||||
// Drop the capturer when all clients have disconnected so we don't
|
||||
// hold the DXGI duplication or GDI DC on an idle peer.
|
||||
if c.clients.Load() <= 0 {
|
||||
w.closeCapturer()
|
||||
continue
|
||||
}
|
||||
if !w.handleNextRequest() {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// frameCapturer is the per-backend interface used by the worker. DXGI and
|
||||
// GDI implementations both satisfy it.
|
||||
type frameCapturer interface {
|
||||
capture() (*image.RGBA, error)
|
||||
close()
|
||||
}
|
||||
|
||||
// captureWorker owns the worker goroutine's mutable state. Extracted into a
|
||||
// struct so the request/desktop/init logic can live on small methods and the
|
||||
// outer worker() stays a thin loop.
|
||||
type captureWorker struct {
|
||||
c *DesktopCapturer
|
||||
cap frameCapturer
|
||||
desktopFails int
|
||||
lastDesktop string
|
||||
nextInitRetry time.Time
|
||||
cursor cursorSampler
|
||||
}
|
||||
|
||||
// handleNextRequest waits for either shutdown or a capture request and runs
|
||||
// the request through prepCapturer/capture. Returns false when the worker
|
||||
// should exit.
|
||||
func (w *captureWorker) handleNextRequest() bool {
|
||||
select {
|
||||
case <-w.c.done:
|
||||
return false
|
||||
case req := <-w.c.reqCh:
|
||||
w.serveRequest(req)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (w *captureWorker) serveRequest(req captureReq) {
|
||||
fc, err := w.prepCapturer()
|
||||
if err != nil {
|
||||
req.reply <- captureReply{err: err}
|
||||
return
|
||||
}
|
||||
img, err := fc.capture()
|
||||
if err != nil {
|
||||
log.Debugf("capture: %v", err)
|
||||
w.closeCapturer()
|
||||
w.nextInitRetry = time.Now().Add(100 * time.Millisecond)
|
||||
req.reply <- captureReply{err: err}
|
||||
return
|
||||
}
|
||||
if snap, err := w.cursor.sample(); err != nil {
|
||||
w.c.cursorState.store(&cursorSnapshot{err: err})
|
||||
} else {
|
||||
w.c.cursorState.store(snap)
|
||||
}
|
||||
req.reply <- captureReply{img: img}
|
||||
}
|
||||
|
||||
// prepCapturer switches to the input desktop, handles desktop-change
|
||||
// teardown, and creates the underlying capturer on demand. Backoff state is
|
||||
// tracked across calls via w.nextInitRetry.
|
||||
func (w *captureWorker) prepCapturer() (frameCapturer, error) {
|
||||
if err := w.refreshDesktop(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if w.cap != nil {
|
||||
return w.cap, nil
|
||||
}
|
||||
if time.Now().Before(w.nextInitRetry) {
|
||||
return nil, fmt.Errorf("capturer init backing off")
|
||||
}
|
||||
fc, err := w.createCapturer()
|
||||
if err != nil {
|
||||
w.nextInitRetry = time.Now().Add(500 * time.Millisecond)
|
||||
return nil, err
|
||||
}
|
||||
w.cap = fc
|
||||
sw, sh := screenSize()
|
||||
w.c.mu.Lock()
|
||||
w.c.w, w.c.h = sw, sh
|
||||
w.c.mu.Unlock()
|
||||
log.Infof("screen capturer ready: %dx%d", sw, sh)
|
||||
return w.cap, nil
|
||||
}
|
||||
|
||||
// refreshDesktop tracks the active input desktop. When it changes (lock
|
||||
// screen, fast-user-switch) the existing capturer is dropped so the next
|
||||
// call rebuilds one against the new desktop.
|
||||
func (w *captureWorker) refreshDesktop() error {
|
||||
ok, desk := switchToInputDesktop()
|
||||
if !ok {
|
||||
w.desktopFails++
|
||||
if w.desktopFails == 1 || w.desktopFails%100 == 0 {
|
||||
log.Warnf("switchToInputDesktop failed (count=%d), no interactive desktop session?", w.desktopFails)
|
||||
}
|
||||
return fmt.Errorf("no interactive desktop")
|
||||
}
|
||||
if w.desktopFails > 0 {
|
||||
log.Infof("switchToInputDesktop recovered after %d failures, desktop=%q", w.desktopFails, desk)
|
||||
w.desktopFails = 0
|
||||
}
|
||||
if desk != w.lastDesktop {
|
||||
log.Infof("desktop changed: %q -> %q", w.lastDesktop, desk)
|
||||
w.lastDesktop = desk
|
||||
w.closeCapturer()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *captureWorker) createCapturer() (frameCapturer, error) {
|
||||
dc, err := newDXGICapturer()
|
||||
if err == nil {
|
||||
log.Info("using DXGI Desktop Duplication for capture")
|
||||
return dc, nil
|
||||
}
|
||||
log.Warnf("DXGI Desktop Duplication unavailable, falling back to slower GDI BitBlt: %v", err)
|
||||
gc, err := newGDICapturer()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.Info("using GDI BitBlt for capture")
|
||||
return gc, nil
|
||||
}
|
||||
|
||||
func (w *captureWorker) closeCapturer() {
|
||||
if w.cap != nil {
|
||||
w.cap.close()
|
||||
w.cap = nil
|
||||
}
|
||||
}
|
||||
533
client/vnc/server/capture_x11.go
Normal file
533
client/vnc/server/capture_x11.go
Normal file
@@ -0,0 +1,533 @@
|
||||
//go:build unix && !darwin && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/jezek/xgb"
|
||||
"github.com/jezek/xgb/xproto"
|
||||
)
|
||||
|
||||
const (
|
||||
// x11SocketDir is the well-known directory where X servers create
|
||||
// their abstract UNIX-domain sockets, named "X<display>". Used both
|
||||
// for auto-detecting an existing display and for placing/probing
|
||||
// sockets of virtual sessions we spawn.
|
||||
x11SocketDir = "/tmp/.X11-unix"
|
||||
|
||||
// envDisplay is the X11 display selector environment variable.
|
||||
envDisplay = "DISPLAY"
|
||||
// envXAuthority points X clients at the cookie file used to
|
||||
// authenticate against the running X server.
|
||||
envXAuthority = "XAUTHORITY"
|
||||
)
|
||||
|
||||
// X11Capturer captures the screen from an X11 display using the MIT-SHM extension.
|
||||
type X11Capturer struct {
|
||||
mu sync.Mutex
|
||||
conn *xgb.Conn
|
||||
screen *xproto.ScreenInfo
|
||||
w, h int
|
||||
shmID int
|
||||
shmAddr []byte
|
||||
shmSeg uint32
|
||||
useSHM bool
|
||||
// bufs double-buffers output images so the X11Poller's capture loop can
|
||||
// overwrite one while the session is still encoding the other. Before
|
||||
// this, a single reused buffer would race with the reader. Allocation
|
||||
// happens on first use and on geometry change.
|
||||
bufs [2]*image.RGBA
|
||||
cur int
|
||||
// cursor is the XFixes binding used to report the current sprite.
|
||||
// Allocated lazily on the first Cursor call. cursorInitErr latches
|
||||
// a permanent init failure so we stop retrying every frame.
|
||||
cursor *xfixesCursor
|
||||
cursorInitErr error
|
||||
}
|
||||
|
||||
// detectX11Display finds the active X11 display and sets DISPLAY/XAUTHORITY
|
||||
// environment variables if needed. This is required when running as a system
|
||||
// service where these vars aren't set.
|
||||
func detectX11Display() {
|
||||
if os.Getenv(envDisplay) != "" {
|
||||
return
|
||||
}
|
||||
|
||||
// Try /proc first (Linux), then ps fallback (FreeBSD and others).
|
||||
if detectX11FromProc() {
|
||||
return
|
||||
}
|
||||
if detectX11FromSockets() {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// detectX11FromProc scans /proc/*/cmdline for Xorg (Linux).
|
||||
func detectX11FromProc() bool {
|
||||
entries, err := os.ReadDir("/proc")
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
for _, e := range entries {
|
||||
if !e.IsDir() {
|
||||
continue
|
||||
}
|
||||
cmdline, err := os.ReadFile("/proc/" + e.Name() + "/cmdline")
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if display, auth := parseXorgArgs(splitCmdline(cmdline)); display != "" {
|
||||
setDisplayEnv(display, auth)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// detectX11FromSockets checks /tmp/.X11-unix/ for X sockets and uses ps
|
||||
// to find the auth file. Works on FreeBSD and other systems without /proc.
|
||||
func detectX11FromSockets() bool {
|
||||
entries, err := os.ReadDir(x11SocketDir)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Pick the lowest numeric display rather than the lexically first
|
||||
// entry, so X10 doesn't win over X2.
|
||||
minDisplay := -1
|
||||
for _, e := range entries {
|
||||
name := e.Name()
|
||||
if len(name) < 2 || name[0] != 'X' {
|
||||
continue
|
||||
}
|
||||
n, err := strconv.Atoi(name[1:])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if minDisplay < 0 || n < minDisplay {
|
||||
minDisplay = n
|
||||
}
|
||||
}
|
||||
if minDisplay < 0 {
|
||||
return false
|
||||
}
|
||||
display := ":" + strconv.Itoa(minDisplay)
|
||||
os.Setenv(envDisplay, display)
|
||||
auth := findXorgAuthFromPS()
|
||||
if auth != "" {
|
||||
os.Setenv(envXAuthority, auth)
|
||||
log.Infof("auto-detected DISPLAY=%s (from socket) XAUTHORITY=%s (from ps)", display, auth)
|
||||
} else {
|
||||
log.Infof("auto-detected DISPLAY=%s (from socket)", display)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// findXorgAuthFromPS runs ps to find Xorg and extract its -auth argument.
|
||||
func findXorgAuthFromPS() string {
|
||||
out, err := exec.Command("ps", "auxww").Output()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
for _, line := range strings.Split(string(out), "\n") {
|
||||
if !strings.Contains(line, "Xorg") && !strings.Contains(line, "/X ") {
|
||||
continue
|
||||
}
|
||||
fields := strings.Fields(line)
|
||||
for i, f := range fields {
|
||||
if f == "-auth" && i+1 < len(fields) {
|
||||
return fields[i+1]
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func parseXorgArgs(args []string) (display, auth string) {
|
||||
if len(args) == 0 {
|
||||
return "", ""
|
||||
}
|
||||
base := args[0]
|
||||
if !(base == "Xorg" || base == "X" || len(base) > 0 && base[len(base)-1] == 'X' ||
|
||||
strings.Contains(base, "/Xorg") || strings.Contains(base, "/X")) {
|
||||
return "", ""
|
||||
}
|
||||
for i, arg := range args[1:] {
|
||||
if len(arg) > 0 && arg[0] == ':' {
|
||||
display = arg
|
||||
}
|
||||
if arg == "-auth" && i+2 < len(args) {
|
||||
auth = args[i+2]
|
||||
}
|
||||
}
|
||||
return display, auth
|
||||
}
|
||||
|
||||
func setDisplayEnv(display, auth string) {
|
||||
os.Setenv(envDisplay, display)
|
||||
if auth != "" {
|
||||
os.Setenv(envXAuthority, auth)
|
||||
log.Infof("auto-detected DISPLAY=%s XAUTHORITY=%s", display, auth)
|
||||
return
|
||||
}
|
||||
log.Infof("auto-detected DISPLAY=%s", display)
|
||||
}
|
||||
|
||||
func splitCmdline(data []byte) []string {
|
||||
var args []string
|
||||
for _, b := range splitNull(data) {
|
||||
if len(b) > 0 {
|
||||
args = append(args, string(b))
|
||||
}
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
func splitNull(data []byte) [][]byte {
|
||||
var parts [][]byte
|
||||
start := 0
|
||||
for i, b := range data {
|
||||
if b == 0 {
|
||||
parts = append(parts, data[start:i])
|
||||
start = i + 1
|
||||
}
|
||||
}
|
||||
if start < len(data) {
|
||||
parts = append(parts, data[start:])
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
// NewX11Capturer connects to the X11 display and sets up shared memory capture.
|
||||
func NewX11Capturer(display string) (*X11Capturer, error) {
|
||||
if display == "" {
|
||||
detectX11Display()
|
||||
display = os.Getenv(envDisplay)
|
||||
}
|
||||
if display == "" {
|
||||
return nil, fmt.Errorf("DISPLAY not set and no Xorg process found")
|
||||
}
|
||||
|
||||
conn, err := xgb.NewConnDisplay(display)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connect to X11 display %s: %w", display, err)
|
||||
}
|
||||
|
||||
setup := xproto.Setup(conn)
|
||||
if len(setup.Roots) == 0 {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("no X11 screens")
|
||||
}
|
||||
screen := setup.Roots[0]
|
||||
|
||||
c := &X11Capturer{
|
||||
conn: conn,
|
||||
screen: &screen,
|
||||
w: int(screen.WidthInPixels),
|
||||
h: int(screen.HeightInPixels),
|
||||
}
|
||||
|
||||
if err := c.initSHM(); err != nil {
|
||||
log.Debugf("X11 SHM not available, using slow GetImage: %v", err)
|
||||
}
|
||||
|
||||
log.Infof("X11 capturer ready: %dx%d (display=%s, shm=%v)", c.w, c.h, display, c.useSHM)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// initSHM is implemented in capture_x11_shm_linux.go (requires SysV SHM).
|
||||
// On platforms without SysV SHM (FreeBSD), a stub returns an error and
|
||||
// the capturer falls back to GetImage.
|
||||
|
||||
// Width returns the screen width.
|
||||
func (c *X11Capturer) Width() int { return c.w }
|
||||
|
||||
// Height returns the screen height.
|
||||
func (c *X11Capturer) Height() int { return c.h }
|
||||
|
||||
// Capture returns the current screen as an RGBA image.
|
||||
func (c *X11Capturer) Capture() (*image.RGBA, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.useSHM {
|
||||
return c.captureSHM()
|
||||
}
|
||||
return c.captureGetImage()
|
||||
}
|
||||
|
||||
// CaptureInto fills the caller's destination buffer in one pass. The
|
||||
// source path (SHM or fallback GetImage) writes directly into dst.Pix
|
||||
// instead of going through the X11Capturer's internal double-buffer,
|
||||
// saving one full-frame memcpy per capture.
|
||||
func (c *X11Capturer) CaptureInto(dst *image.RGBA) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if dst.Rect.Dx() != c.w || dst.Rect.Dy() != c.h {
|
||||
return fmt.Errorf("dst size mismatch: dst=%dx%d capturer=%dx%d",
|
||||
dst.Rect.Dx(), dst.Rect.Dy(), c.w, c.h)
|
||||
}
|
||||
if c.useSHM {
|
||||
return c.captureSHMInto(dst)
|
||||
}
|
||||
return c.captureGetImageInto(dst)
|
||||
}
|
||||
|
||||
func (c *X11Capturer) captureGetImageInto(dst *image.RGBA) error {
|
||||
cookie := xproto.GetImage(c.conn, xproto.ImageFormatZPixmap,
|
||||
xproto.Drawable(c.screen.Root),
|
||||
0, 0, uint16(c.w), uint16(c.h), 0xFFFFFFFF)
|
||||
reply, err := cookie.Reply()
|
||||
if err != nil {
|
||||
return fmt.Errorf("GetImage: %w", err)
|
||||
}
|
||||
n := c.w * c.h * 4
|
||||
if len(reply.Data) < n {
|
||||
return fmt.Errorf("GetImage returned %d bytes, expected %d", len(reply.Data), n)
|
||||
}
|
||||
swizzleBGRAtoRGBA(dst.Pix, reply.Data)
|
||||
return nil
|
||||
}
|
||||
|
||||
// captureSHM is implemented in capture_x11_shm_linux.go.
|
||||
|
||||
func (c *X11Capturer) captureGetImage() (*image.RGBA, error) {
|
||||
cookie := xproto.GetImage(c.conn, xproto.ImageFormatZPixmap,
|
||||
xproto.Drawable(c.screen.Root),
|
||||
0, 0, uint16(c.w), uint16(c.h), 0xFFFFFFFF)
|
||||
|
||||
reply, err := cookie.Reply()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GetImage: %w", err)
|
||||
}
|
||||
|
||||
data := reply.Data
|
||||
n := c.w * c.h * 4
|
||||
if len(data) < n {
|
||||
return nil, fmt.Errorf("GetImage returned %d bytes, expected %d", len(data), n)
|
||||
}
|
||||
|
||||
img := c.nextBuffer()
|
||||
swizzleBGRAtoRGBA(img.Pix, data)
|
||||
return img, nil
|
||||
}
|
||||
|
||||
// nextBuffer returns the *image.RGBA the next capture should fill, advancing
|
||||
// the double-buffer index. Reallocates on geometry change.
|
||||
func (c *X11Capturer) nextBuffer() *image.RGBA {
|
||||
c.cur ^= 1
|
||||
b := c.bufs[c.cur]
|
||||
if b == nil || b.Rect.Dx() != c.w || b.Rect.Dy() != c.h {
|
||||
b = image.NewRGBA(image.Rect(0, 0, c.w, c.h))
|
||||
c.bufs[c.cur] = b
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// Close releases X11 resources.
|
||||
func (c *X11Capturer) Close() {
|
||||
c.closeSHM()
|
||||
c.conn.Close()
|
||||
}
|
||||
|
||||
// closeSHM is implemented in capture_x11_shm_linux.go.
|
||||
|
||||
// X11Poller wraps X11Capturer with a staleness-cached on-demand Capture:
|
||||
// sessions drive captures themselves through the encoder goroutine, so we
|
||||
// don't need a background ticker. The last result is cached for a short
|
||||
// window so concurrent sessions coalesce into one capture.
|
||||
//
|
||||
// The capturer is allocated lazily on first use and released when all
|
||||
// clients disconnect, so an idle peer holds no X connection or SHM segment.
|
||||
type X11Poller struct {
|
||||
mu sync.Mutex
|
||||
|
||||
capturer *X11Capturer
|
||||
w, h int
|
||||
// closed at Close so callers can stop waiting on retry backoff.
|
||||
done chan struct{}
|
||||
|
||||
// lastFrame/lastAt implement a small cache: multiple near-simultaneous
|
||||
// Capture calls (multi-client, or input-coalesced) return the same
|
||||
// frame instead of hammering the X server.
|
||||
lastFrame *image.RGBA
|
||||
lastAt time.Time
|
||||
|
||||
// initBackoffUntil throttles capturer re-init when the X server is
|
||||
// unavailable or flapping.
|
||||
initBackoffUntil time.Time
|
||||
|
||||
clients atomic.Int32
|
||||
display string
|
||||
}
|
||||
|
||||
// initRetryBackoff gates capturer re-init attempts after a failure so we
|
||||
// don't spin on X server errors.
|
||||
const initRetryBackoff = 2 * time.Second
|
||||
|
||||
// NewX11Poller creates a lazy on-demand capturer for the given X display.
|
||||
func NewX11Poller(display string) *X11Poller {
|
||||
return &X11Poller{
|
||||
display: display,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// ClientConnect increments the active client count. The first client triggers
|
||||
// eager capturer initialisation so that the first FBUpdateRequest doesn't
|
||||
// pay the X11 connect + SHM attach latency.
|
||||
func (p *X11Poller) ClientConnect() {
|
||||
if p.clients.Add(1) == 1 {
|
||||
p.mu.Lock()
|
||||
_ = p.ensureCapturerLocked()
|
||||
p.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// ClientDisconnect decrements the active client count. On the last
|
||||
// disconnect we close the underlying capturer so idle peers cost nothing.
|
||||
func (p *X11Poller) ClientDisconnect() {
|
||||
if p.clients.Add(-1) == 0 {
|
||||
p.mu.Lock()
|
||||
if p.capturer != nil {
|
||||
p.capturer.Close()
|
||||
p.capturer = nil
|
||||
p.lastFrame = nil
|
||||
}
|
||||
p.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Close releases all resources. Subsequent Capture calls will fail.
|
||||
func (p *X11Poller) Close() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
select {
|
||||
case <-p.done:
|
||||
default:
|
||||
close(p.done)
|
||||
}
|
||||
if p.capturer != nil {
|
||||
p.capturer.Close()
|
||||
p.capturer = nil
|
||||
}
|
||||
}
|
||||
|
||||
// Width returns the screen width. Triggers lazy init if needed.
|
||||
func (p *X11Poller) Width() int {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
_ = p.ensureCapturerLocked()
|
||||
return p.w
|
||||
}
|
||||
|
||||
// Height returns the screen height. Triggers lazy init if needed.
|
||||
func (p *X11Poller) Height() int {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
_ = p.ensureCapturerLocked()
|
||||
return p.h
|
||||
}
|
||||
|
||||
// Cursor satisfies cursorSource by forwarding to the lazily-initialised
|
||||
// X11Capturer. Asking for the cursor on an idle poller triggers the same
|
||||
// lazy X11 connection setup as a capture would.
|
||||
func (p *X11Poller) Cursor() (*image.RGBA, int, int, uint64, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if err := p.ensureCapturerLocked(); err != nil {
|
||||
return nil, 0, 0, 0, err
|
||||
}
|
||||
return p.capturer.Cursor()
|
||||
}
|
||||
|
||||
// CursorPos satisfies cursorPositionSource by forwarding to the X11Capturer.
|
||||
func (p *X11Poller) CursorPos() (int, int, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if err := p.ensureCapturerLocked(); err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
return p.capturer.CursorPos()
|
||||
}
|
||||
|
||||
// Capture returns a fresh frame, serving from the short-lived cache if a
|
||||
// previous caller captured within freshWindow.
|
||||
func (p *X11Poller) Capture() (*image.RGBA, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.lastFrame != nil && time.Since(p.lastAt) < freshWindow {
|
||||
return p.lastFrame, nil
|
||||
}
|
||||
if err := p.ensureCapturerLocked(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
img, err := p.capturer.Capture()
|
||||
if err != nil {
|
||||
// Drop the capturer so the next call re-inits; the X connection may
|
||||
// have died (e.g. Xorg restart).
|
||||
p.capturer.Close()
|
||||
p.capturer = nil
|
||||
p.initBackoffUntil = time.Now().Add(initRetryBackoff)
|
||||
return nil, fmt.Errorf("x11 capture: %w", err)
|
||||
}
|
||||
p.lastFrame = img
|
||||
p.lastAt = time.Now()
|
||||
return img, nil
|
||||
}
|
||||
|
||||
// CaptureInto fills dst directly via the underlying capturer, bypassing
|
||||
// the freshness cache. The session's prevFrame/curFrame swap means each
|
||||
// session needs its own buffer anyway, so caching wouldn't help.
|
||||
func (p *X11Poller) CaptureInto(dst *image.RGBA) error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if err := p.ensureCapturerLocked(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := p.capturer.CaptureInto(dst); err != nil {
|
||||
p.capturer.Close()
|
||||
p.capturer = nil
|
||||
p.initBackoffUntil = time.Now().Add(initRetryBackoff)
|
||||
return fmt.Errorf("x11 capture: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureCapturerLocked initialises the underlying X11Capturer if not
|
||||
// already open. Caller must hold p.mu.
|
||||
func (p *X11Poller) ensureCapturerLocked() error {
|
||||
if p.capturer != nil {
|
||||
return nil
|
||||
}
|
||||
select {
|
||||
case <-p.done:
|
||||
return fmt.Errorf("x11 capturer closed")
|
||||
default:
|
||||
}
|
||||
if time.Now().Before(p.initBackoffUntil) {
|
||||
return fmt.Errorf("x11 capturer unavailable (retry scheduled)")
|
||||
}
|
||||
c, err := NewX11Capturer(p.display)
|
||||
if err != nil {
|
||||
p.initBackoffUntil = time.Now().Add(initRetryBackoff)
|
||||
log.Debugf("X11 capturer: %v", err)
|
||||
return err
|
||||
}
|
||||
p.capturer = c
|
||||
p.w, p.h = c.Width(), c.Height()
|
||||
return nil
|
||||
}
|
||||
96
client/vnc/server/capture_x11_shm_linux.go
Normal file
96
client/vnc/server/capture_x11_shm_linux.go
Normal file
@@ -0,0 +1,96 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
|
||||
"github.com/jezek/xgb/shm"
|
||||
"github.com/jezek/xgb/xproto"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func (c *X11Capturer) initSHM() error {
|
||||
if err := shm.Init(c.conn); err != nil {
|
||||
return fmt.Errorf("init SHM extension: %w", err)
|
||||
}
|
||||
|
||||
size := c.w * c.h * 4
|
||||
id, err := unix.SysvShmGet(unix.IPC_PRIVATE, size, unix.IPC_CREAT|0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("shmget: %w", err)
|
||||
}
|
||||
|
||||
addr, err := unix.SysvShmAttach(id, 0, 0)
|
||||
if err != nil {
|
||||
if _, ctlErr := unix.SysvShmCtl(id, unix.IPC_RMID, nil); ctlErr != nil {
|
||||
log.Debugf("shmctl IPC_RMID on attach failure: %v", ctlErr)
|
||||
}
|
||||
return fmt.Errorf("shmat: %w", err)
|
||||
}
|
||||
|
||||
if _, err := unix.SysvShmCtl(id, unix.IPC_RMID, nil); err != nil {
|
||||
log.Debugf("shmctl IPC_RMID: %v", err)
|
||||
}
|
||||
|
||||
seg, err := shm.NewSegId(c.conn)
|
||||
if err != nil {
|
||||
if detachErr := unix.SysvShmDetach(addr); detachErr != nil {
|
||||
log.Debugf("shmdt on new-seg failure: %v", detachErr)
|
||||
}
|
||||
return fmt.Errorf("new SHM seg: %w", err)
|
||||
}
|
||||
|
||||
if err := shm.AttachChecked(c.conn, seg, uint32(id), false).Check(); err != nil {
|
||||
if detachErr := unix.SysvShmDetach(addr); detachErr != nil {
|
||||
log.Debugf("shmdt on attach-checked failure: %v", detachErr)
|
||||
}
|
||||
return fmt.Errorf("SHM attach to X: %w", err)
|
||||
}
|
||||
|
||||
c.shmID = id
|
||||
c.shmAddr = addr
|
||||
c.shmSeg = uint32(seg)
|
||||
c.useSHM = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *X11Capturer) captureSHM() (*image.RGBA, error) {
|
||||
if err := c.fillSHM(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
img := c.nextBuffer()
|
||||
swizzleBGRAtoRGBA(img.Pix, c.shmAddr[:c.w*c.h*4])
|
||||
return img, nil
|
||||
}
|
||||
|
||||
// captureSHMInto runs a single SHM GetImage and swizzles directly into the
|
||||
// caller-provided destination, skipping the internal double-buffer.
|
||||
func (c *X11Capturer) captureSHMInto(dst *image.RGBA) error {
|
||||
if err := c.fillSHM(); err != nil {
|
||||
return err
|
||||
}
|
||||
swizzleBGRAtoRGBA(dst.Pix, c.shmAddr[:c.w*c.h*4])
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *X11Capturer) fillSHM() error {
|
||||
cookie := shm.GetImage(c.conn, xproto.Drawable(c.screen.Root),
|
||||
0, 0, uint16(c.w), uint16(c.h), 0xFFFFFFFF,
|
||||
xproto.ImageFormatZPixmap, shm.Seg(c.shmSeg), 0)
|
||||
if _, err := cookie.Reply(); err != nil {
|
||||
return fmt.Errorf("SHM GetImage: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *X11Capturer) closeSHM() {
|
||||
if c.useSHM {
|
||||
shm.Detach(c.conn, shm.Seg(c.shmSeg))
|
||||
if err := unix.SysvShmDetach(c.shmAddr); err != nil {
|
||||
log.Debugf("shmdt on close: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
24
client/vnc/server/capture_x11_shm_stub.go
Normal file
24
client/vnc/server/capture_x11_shm_stub.go
Normal file
@@ -0,0 +1,24 @@
|
||||
//go:build freebsd
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
)
|
||||
|
||||
func (c *X11Capturer) initSHM() error {
|
||||
return fmt.Errorf("SysV SHM not available on this platform")
|
||||
}
|
||||
|
||||
func (c *X11Capturer) captureSHM() (*image.RGBA, error) {
|
||||
return nil, fmt.Errorf("SHM capture not available on this platform")
|
||||
}
|
||||
|
||||
func (c *X11Capturer) captureSHMInto(_ *image.RGBA) error {
|
||||
return fmt.Errorf("SHM capture not available on this platform")
|
||||
}
|
||||
|
||||
func (c *X11Capturer) closeSHM() {
|
||||
// no SHM to close on this platform
|
||||
}
|
||||
77
client/vnc/server/coalesce_test.go
Normal file
77
client/vnc/server/coalesce_test.go
Normal file
@@ -0,0 +1,77 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCoalesceRects(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
in [][4]int
|
||||
want [][4]int
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
in: nil,
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "single",
|
||||
in: [][4]int{{0, 0, 64, 64}},
|
||||
want: [][4]int{{0, 0, 64, 64}},
|
||||
},
|
||||
{
|
||||
name: "horizontal_run",
|
||||
in: [][4]int{{0, 0, 64, 64}, {64, 0, 64, 64}, {128, 0, 64, 64}},
|
||||
want: [][4]int{{0, 0, 192, 64}},
|
||||
},
|
||||
{
|
||||
name: "vertical_run",
|
||||
in: [][4]int{{0, 0, 64, 64}, {0, 64, 64, 64}, {0, 128, 64, 64}},
|
||||
want: [][4]int{{0, 0, 64, 192}},
|
||||
},
|
||||
{
|
||||
name: "block_2x2",
|
||||
in: [][4]int{
|
||||
{0, 0, 64, 64}, {64, 0, 64, 64},
|
||||
{0, 64, 64, 64}, {64, 64, 64, 64},
|
||||
},
|
||||
want: [][4]int{{0, 0, 128, 128}},
|
||||
},
|
||||
{
|
||||
name: "no_merge_gap",
|
||||
in: [][4]int{{0, 0, 64, 64}, {192, 0, 64, 64}},
|
||||
want: [][4]int{{0, 0, 64, 64}, {192, 0, 64, 64}},
|
||||
},
|
||||
{
|
||||
name: "two_disjoint_columns",
|
||||
in: [][4]int{
|
||||
{0, 0, 64, 64}, {192, 0, 64, 64},
|
||||
{0, 64, 64, 64}, {192, 64, 64, 64},
|
||||
},
|
||||
want: [][4]int{{0, 0, 64, 128}, {192, 0, 64, 128}},
|
||||
},
|
||||
{
|
||||
name: "misaligned_widths_no_vertical_merge",
|
||||
in: [][4]int{
|
||||
{0, 0, 128, 64},
|
||||
{0, 64, 64, 64},
|
||||
},
|
||||
want: [][4]int{{0, 0, 128, 64}, {0, 64, 64, 64}},
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := coalesceRects(tc.in)
|
||||
if len(got) == 0 && len(tc.want) == 0 {
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tc.want) {
|
||||
t.Fatalf("got %v want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
191
client/vnc/server/copyrect.go
Normal file
191
client/vnc/server/copyrect.go
Normal file
@@ -0,0 +1,191 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"hash/maphash"
|
||||
"image"
|
||||
)
|
||||
|
||||
// copyRectDetector finds tiles in the current frame that match the content
|
||||
// of some tile-aligned region of the previous frame, so we can emit them as
|
||||
// CopyRect rectangles (16 wire bytes) instead of re-encoding the pixels.
|
||||
//
|
||||
// The detector keeps two structures:
|
||||
// - tileHash, a flat slice of one hash per tile-aligned position, used as
|
||||
// the source of truth for the previous frame's tile content.
|
||||
// - prevTiles, a hash → position lookup used during findTileMatch.
|
||||
//
|
||||
// updateDirty rehashes only the tiles that changed this frame, so the
|
||||
// steady-state cost is proportional to the dirty set, not the framebuffer.
|
||||
// A full rebuild from scratch is only done on the first frame or when the
|
||||
// detector has not yet been initialized for the current resolution.
|
||||
//
|
||||
// Limitations:
|
||||
// - Only tile-aligned source positions are considered. Sub-tile-aligned
|
||||
// moves (e.g. window dragged by 7 pixels) are not detected. This still
|
||||
// covers the common case of vertical/horizontal scrolling, which always
|
||||
// produces tile-aligned matches at the tile granularity.
|
||||
// - 64-bit maphash collisions are assumed not to happen. The probability
|
||||
// for any single frame's hash universe is ~2^-32 * tileCount² which is
|
||||
// vanishingly small at typical resolutions; if we ever observe one we
|
||||
// can fall back to a full memcmp verification.
|
||||
type copyRectDetector struct {
|
||||
seed maphash.Seed
|
||||
tileSize int
|
||||
w, h int
|
||||
cols, rows int
|
||||
// tileHash[ty*cols + tx] is the current hash of the tile at (tx, ty)
|
||||
// in the previous frame. Lookup uses this to detect stale prevTiles
|
||||
// entries: incremental updates may leave hash→pos entries pointing
|
||||
// at a tile whose content has since changed.
|
||||
tileHash []uint64
|
||||
// prevTiles maps a tile hash to a (x, y) origin in the previous frame.
|
||||
prevTiles map[uint64][2]int
|
||||
// hash is reused across hash computations to keep the per-tile lookup
|
||||
// path allocation-free.
|
||||
hash maphash.Hash
|
||||
}
|
||||
|
||||
func newCopyRectDetector(tileSize int) *copyRectDetector {
|
||||
d := ©RectDetector{
|
||||
seed: maphash.MakeSeed(),
|
||||
tileSize: tileSize,
|
||||
prevTiles: make(map[uint64][2]int),
|
||||
}
|
||||
d.hash.SetSeed(d.seed)
|
||||
return d
|
||||
}
|
||||
|
||||
// resize ensures the per-tile tables match the given framebuffer size.
|
||||
// Called from rebuild before each full hash sweep.
|
||||
func (d *copyRectDetector) resize(w, h int) {
|
||||
if d.w == w && d.h == h && d.tileHash != nil {
|
||||
return
|
||||
}
|
||||
d.w, d.h = w, h
|
||||
d.cols = w / d.tileSize
|
||||
d.rows = h / d.tileSize
|
||||
d.tileHash = make([]uint64, d.cols*d.rows)
|
||||
}
|
||||
|
||||
// hashTile computes the 64-bit maphash of one tile-aligned tile of frame.
|
||||
func (d *copyRectDetector) hashTile(frame *image.RGBA, tx, ty int) uint64 {
|
||||
d.hash.Reset()
|
||||
ts := d.tileSize
|
||||
stride := frame.Stride
|
||||
rowBytes := ts * 4
|
||||
base := ty*stride + tx*4
|
||||
for row := 0; row < ts; row++ {
|
||||
off := base + row*stride
|
||||
_, _ = d.hash.Write(frame.Pix[off : off+rowBytes])
|
||||
}
|
||||
return d.hash.Sum64()
|
||||
}
|
||||
|
||||
// rebuild discards everything and rehashes the whole frame. O(w*h). Use
|
||||
// for the first frame or after the detector has been resized. Steady-state
|
||||
// updates should go through updateDirty instead.
|
||||
func (d *copyRectDetector) rebuild(frame *image.RGBA, w, h int) {
|
||||
d.resize(w, h)
|
||||
if d.prevTiles == nil {
|
||||
d.prevTiles = make(map[uint64][2]int)
|
||||
} else {
|
||||
clear(d.prevTiles)
|
||||
}
|
||||
ts := d.tileSize
|
||||
for ty := 0; ty+ts <= h; ty += ts {
|
||||
for tx := 0; tx+ts <= w; tx += ts {
|
||||
sum := d.hashTile(frame, tx, ty)
|
||||
d.tileHash[(ty/ts)*d.cols+(tx/ts)] = sum
|
||||
if _, exists := d.prevTiles[sum]; !exists {
|
||||
d.prevTiles[sum] = [2]int{tx, ty}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// updateDirty rehashes only the tiles named in dirty (each entry is
|
||||
// [x, y, w, h] with w and h equal to tileSize). O(len(dirty)) work, which
|
||||
// in the common case is a tiny fraction of the whole framebuffer.
|
||||
//
|
||||
// The prevTiles map is replaced on collision rather than first-wins so a
|
||||
// newly-hashed tile claims the slot. Old, stale entries pointing at tiles
|
||||
// that no longer carry that hash are filtered at lookup time via tileHash.
|
||||
func (d *copyRectDetector) updateDirty(frame *image.RGBA, w, h int, dirty [][4]int) {
|
||||
if d.w != w || d.h != h || d.tileHash == nil {
|
||||
d.rebuild(frame, w, h)
|
||||
return
|
||||
}
|
||||
ts := d.tileSize
|
||||
for _, r := range dirty {
|
||||
if r[2] != ts || r[3] != ts {
|
||||
continue
|
||||
}
|
||||
tx, ty := r[0], r[1]
|
||||
if tx+ts > w || ty+ts > h {
|
||||
continue
|
||||
}
|
||||
sum := d.hashTile(frame, tx, ty)
|
||||
d.tileHash[(ty/ts)*d.cols+(tx/ts)] = sum
|
||||
// Latest-wins on collision: ensures the most recent owner of this
|
||||
// hash is the one we'll return on lookup. The previous owner's
|
||||
// entry, if any, gets shadowed; if its content has changed it's
|
||||
// stale anyway and findTileMatch's verification will skip it.
|
||||
d.prevTiles[sum] = [2]int{tx, ty}
|
||||
}
|
||||
}
|
||||
|
||||
// findTileMatch hashes the current-frame tile at (dstX, dstY) and looks up
|
||||
// its hash in the previous-frame map. Returns (srcX, srcY, true) when a
|
||||
// matching tile-aligned tile exists at a different position whose stored
|
||||
// hash still equals the requested hash (so the result is not stale).
|
||||
func (d *copyRectDetector) findTileMatch(cur *image.RGBA, dstX, dstY int) (int, int, bool) {
|
||||
if len(d.prevTiles) == 0 || d.tileHash == nil {
|
||||
return 0, 0, false
|
||||
}
|
||||
ts := d.tileSize
|
||||
if dstX+ts > cur.Rect.Dx() || dstY+ts > cur.Rect.Dy() {
|
||||
return 0, 0, false
|
||||
}
|
||||
sum := d.hashTile(cur, dstX, dstY)
|
||||
pos, ok := d.prevTiles[sum]
|
||||
if !ok {
|
||||
return 0, 0, false
|
||||
}
|
||||
if pos[0] == dstX && pos[1] == dstY {
|
||||
return 0, 0, false
|
||||
}
|
||||
// Reject stale entries: the position the map points at must still
|
||||
// carry the same hash according to our per-tile array.
|
||||
if d.tileHash[(pos[1]/ts)*d.cols+(pos[0]/ts)] != sum {
|
||||
return 0, 0, false
|
||||
}
|
||||
return pos[0], pos[1], true
|
||||
}
|
||||
|
||||
// extractCopyRectTiles examines the diff-produced (per-tile) dirty list and
|
||||
// pulls out any tiles whose current-frame content matches a prev-frame tile
|
||||
// at a different position. Returns the CopyRect candidates and the residual
|
||||
// dirty tiles that still need pixel encoding.
|
||||
type copyRectMove struct {
|
||||
srcX, srcY int
|
||||
dstX, dstY int
|
||||
}
|
||||
|
||||
func (d *copyRectDetector) extractCopyRectTiles(cur *image.RGBA, dirtyTiles [][4]int) (moves []copyRectMove, remaining [][4]int) {
|
||||
ts := d.tileSize
|
||||
remaining = dirtyTiles[:0:cap(dirtyTiles)]
|
||||
for _, r := range dirtyTiles {
|
||||
if r[2] == ts && r[3] == ts {
|
||||
if sx, sy, ok := d.findTileMatch(cur, r[0], r[1]); ok {
|
||||
moves = append(moves, copyRectMove{
|
||||
srcX: sx, srcY: sy, dstX: r[0], dstY: r[1],
|
||||
})
|
||||
continue
|
||||
}
|
||||
}
|
||||
remaining = append(remaining, r)
|
||||
}
|
||||
return moves, remaining
|
||||
}
|
||||
162
client/vnc/server/copyrect_test.go
Normal file
162
client/vnc/server/copyrect_test.go
Normal file
@@ -0,0 +1,162 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"image"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// fillTile paints a tileSize×tileSize block of img at (x,y) with the colour
|
||||
// derived from (r,g,b) so the test can construct distinct-content tiles.
|
||||
func fillTile(img *image.RGBA, x, y, ts int, r, g, b byte) {
|
||||
for row := 0; row < ts; row++ {
|
||||
off := (y+row)*img.Stride + x*4
|
||||
for col := 0; col < ts; col++ {
|
||||
img.Pix[off+col*4+0] = r
|
||||
img.Pix[off+col*4+1] = g
|
||||
img.Pix[off+col*4+2] = b
|
||||
img.Pix[off+col*4+3] = 0xff
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// copyTile copies a tileSize×tileSize block from src(sx,sy) to dst(dx,dy).
|
||||
func copyTile(dst, src *image.RGBA, sx, sy, dx, dy, ts int) {
|
||||
for row := 0; row < ts; row++ {
|
||||
srcOff := (sy+row)*src.Stride + sx*4
|
||||
dstOff := (dy+row)*dst.Stride + dx*4
|
||||
copy(dst.Pix[dstOff:dstOff+ts*4], src.Pix[srcOff:srcOff+ts*4])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyRectDetector_DetectsVerticalScroll(t *testing.T) {
|
||||
const w, h = 256, 192 // 4×3 tiles at 64px
|
||||
const ts = 64
|
||||
|
||||
prev := image.NewRGBA(image.Rect(0, 0, w, h))
|
||||
cur := image.NewRGBA(image.Rect(0, 0, w, h))
|
||||
|
||||
// prev: 12 tiles each with a unique colour.
|
||||
for ty := 0; ty < 3; ty++ {
|
||||
for tx := 0; tx < 4; tx++ {
|
||||
fillTile(prev, tx*ts, ty*ts, ts, byte(tx*40), byte(ty*60), 0x80)
|
||||
}
|
||||
}
|
||||
// cur: simulate a single-tile-row scroll upward, every tile copied from
|
||||
// the row below in prev, top row is new content.
|
||||
for ty := 0; ty < 2; ty++ {
|
||||
for tx := 0; tx < 4; tx++ {
|
||||
copyTile(cur, prev, tx*ts, (ty+1)*ts, tx*ts, ty*ts, ts)
|
||||
}
|
||||
}
|
||||
// Bottom row of cur: new colour, not a match.
|
||||
for tx := 0; tx < 4; tx++ {
|
||||
fillTile(cur, tx*ts, 2*ts, ts, 0xff, 0xff, 0xff)
|
||||
}
|
||||
|
||||
d := newCopyRectDetector(ts)
|
||||
d.rebuild(prev, w, h)
|
||||
|
||||
tiles := diffTiles(prev, cur, w, h, ts)
|
||||
moves, remaining := d.extractCopyRectTiles(cur, tiles)
|
||||
|
||||
// Expect 8 CopyRect moves (top two rows) and 4 residual tiles (bottom row).
|
||||
if len(moves) != 8 {
|
||||
t.Fatalf("moves: want 8, got %d", len(moves))
|
||||
}
|
||||
if len(remaining) != 4 {
|
||||
t.Fatalf("remaining: want 4, got %d", len(remaining))
|
||||
}
|
||||
// Spot-check one move: cur (0, 0) should map to prev (0, 64).
|
||||
var found bool
|
||||
for _, m := range moves {
|
||||
if m.dstX == 0 && m.dstY == 0 {
|
||||
if m.srcX != 0 || m.srcY != ts {
|
||||
t.Fatalf("move at (0,0): src=(%d,%d), want (0,%d)", m.srcX, m.srcY, ts)
|
||||
}
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatalf("no move for dst (0,0)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyRectDetector_RejectsSelfMatch(t *testing.T) {
|
||||
const w, h = 128, 128
|
||||
const ts = 64
|
||||
|
||||
prev := image.NewRGBA(image.Rect(0, 0, w, h))
|
||||
cur := image.NewRGBA(image.Rect(0, 0, w, h))
|
||||
|
||||
// prev: 4 tiles, all unique
|
||||
fillTile(prev, 0, 0, ts, 0x10, 0x20, 0x30)
|
||||
fillTile(prev, ts, 0, ts, 0x40, 0x50, 0x60)
|
||||
fillTile(prev, 0, ts, ts, 0x70, 0x80, 0x90)
|
||||
fillTile(prev, ts, ts, ts, 0xa0, 0xb0, 0xc0)
|
||||
|
||||
// cur: tile (0,0) unchanged, others changed but content same as prev's (0,0).
|
||||
fillTile(cur, 0, 0, ts, 0x10, 0x20, 0x30) // self-match
|
||||
fillTile(cur, ts, 0, ts, 0xff, 0xff, 0xff)
|
||||
fillTile(cur, 0, ts, ts, 0xff, 0xff, 0xff)
|
||||
fillTile(cur, ts, ts, ts, 0xff, 0xff, 0xff)
|
||||
|
||||
d := newCopyRectDetector(ts)
|
||||
d.rebuild(prev, w, h)
|
||||
|
||||
// Tile (0,0) is not in the dirty list (it's unchanged) so it should not
|
||||
// produce a move even though its hash matches prev (0,0).
|
||||
tiles := diffTiles(prev, cur, w, h, ts)
|
||||
moves, _ := d.extractCopyRectTiles(cur, tiles)
|
||||
for _, m := range moves {
|
||||
if m.dstX == 0 && m.dstY == 0 {
|
||||
t.Fatalf("unexpected move at (0,0)")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyRectDetector_PassThroughWhenNoMatch(t *testing.T) {
|
||||
const w, h = 64, 64
|
||||
const ts = 64
|
||||
|
||||
prev := image.NewRGBA(image.Rect(0, 0, w, h))
|
||||
cur := image.NewRGBA(image.Rect(0, 0, w, h))
|
||||
fillTile(prev, 0, 0, ts, 0x11, 0x22, 0x33)
|
||||
fillTile(cur, 0, 0, ts, 0xaa, 0xbb, 0xcc) // wholly different
|
||||
|
||||
d := newCopyRectDetector(ts)
|
||||
d.rebuild(prev, w, h)
|
||||
tiles := diffTiles(prev, cur, w, h, ts)
|
||||
moves, remaining := d.extractCopyRectTiles(cur, tiles)
|
||||
|
||||
if len(moves) != 0 {
|
||||
t.Fatalf("expected 0 moves, got %d", len(moves))
|
||||
}
|
||||
if len(remaining) != 1 {
|
||||
t.Fatalf("expected 1 residual tile, got %d", len(remaining))
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeCopyRectBody_Layout(t *testing.T) {
|
||||
got := encodeCopyRectBody(100, 200, 300, 400, 64, 48)
|
||||
if len(got) != 16 {
|
||||
t.Fatalf("CopyRect body length: want 16, got %d", len(got))
|
||||
}
|
||||
// Dest position
|
||||
if got[0] != 0x01 || got[1] != 0x2c || got[2] != 0x01 || got[3] != 0x90 {
|
||||
t.Fatalf("bad dest bytes: % x", got[0:4])
|
||||
}
|
||||
// Width, height
|
||||
if got[4] != 0 || got[5] != 64 || got[6] != 0 || got[7] != 48 {
|
||||
t.Fatalf("bad size bytes: % x", got[4:8])
|
||||
}
|
||||
// Encoding = 1
|
||||
if got[11] != 0x01 {
|
||||
t.Fatalf("bad encoding byte: 0x%02x", got[11])
|
||||
}
|
||||
// Source position
|
||||
if got[12] != 0 || got[13] != 100 || got[14] != 0 || got[15] != 200 {
|
||||
t.Fatalf("bad src bytes: % x", got[12:16])
|
||||
}
|
||||
}
|
||||
194
client/vnc/server/cursor_darwin.go
Normal file
194
client/vnc/server/cursor_darwin.go
Normal file
@@ -0,0 +1,194 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"hash/maphash"
|
||||
"image"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
darwinCursorOnce sync.Once
|
||||
cgsCreateCursor func() uintptr
|
||||
darwinCursorErr error
|
||||
)
|
||||
|
||||
// initDarwinCursor binds a private symbol that returns the current
|
||||
// system cursor image. The classic CGSCreateCurrentCursorImage moved
|
||||
// from CoreGraphics to SkyLight around macOS 13 and is gone entirely
|
||||
// in Sequoia; we probe both frameworks for any of the historical
|
||||
// names so this keeps working on whichever release the binding still
|
||||
// exists. Without a hit the remote-cursor compositing path becomes a
|
||||
// no-op and we log the candidates we tried.
|
||||
func initDarwinCursor() {
|
||||
darwinCursorOnce.Do(func() {
|
||||
libs := []string{
|
||||
"/System/Library/PrivateFrameworks/SkyLight.framework/SkyLight",
|
||||
"/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics",
|
||||
}
|
||||
names := []string{
|
||||
"CGSCreateCurrentCursorImage",
|
||||
"CGSCopyCurrentCursorImage",
|
||||
"CGSCurrentCursorImage",
|
||||
"CGSHardwareCursorActiveImage",
|
||||
}
|
||||
var tried []string
|
||||
for _, path := range libs {
|
||||
h, err := purego.Dlopen(path, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
tried = append(tried, fmt.Sprintf("dlopen %s: %v", path, err))
|
||||
continue
|
||||
}
|
||||
for _, name := range names {
|
||||
sym, err := purego.Dlsym(h, name)
|
||||
if err != nil {
|
||||
tried = append(tried, fmt.Sprintf("%s!%s missing", path, name))
|
||||
continue
|
||||
}
|
||||
purego.RegisterFunc(&cgsCreateCursor, sym)
|
||||
log.Infof("macOS cursor: bound %s from %s", name, path)
|
||||
return
|
||||
}
|
||||
}
|
||||
darwinCursorErr = fmt.Errorf("no cursor image symbol available; tried: %v", tried)
|
||||
})
|
||||
}
|
||||
|
||||
// cgCursor holds the cached macOS cursor sprite and bumps a serial when
|
||||
// the bytes change. Hotspot is left at (0, 0): the public Cocoa hot-spot
|
||||
// query lives on NSCursor which is process-local and not reachable from
|
||||
// our purego-based bindings; the visual cost is a small misalignment for
|
||||
// non-arrow cursors (I-beam, crosshair, etc.).
|
||||
type cgCursor struct {
|
||||
mu sync.Mutex
|
||||
hashSeed maphash.Seed
|
||||
lastSum uint64
|
||||
cached *image.RGBA
|
||||
serial uint64
|
||||
}
|
||||
|
||||
func newCGCursor() *cgCursor {
|
||||
initDarwinCursor()
|
||||
return &cgCursor{hashSeed: maphash.MakeSeed()}
|
||||
}
|
||||
|
||||
// Cursor returns the current cursor sprite as RGBA. Errors that come from
|
||||
// missing private symbols are sticky; transient empty-image responses are
|
||||
// reported as such so the encoder skips this cycle.
|
||||
func (c *cgCursor) Cursor() (*image.RGBA, int, int, uint64, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if darwinCursorErr != nil {
|
||||
return nil, 0, 0, 0, darwinCursorErr
|
||||
}
|
||||
if cgsCreateCursor == nil {
|
||||
return nil, 0, 0, 0, fmt.Errorf("CGSCreateCurrentCursorImage unavailable")
|
||||
}
|
||||
cgImage := cgsCreateCursor()
|
||||
if cgImage == 0 {
|
||||
return nil, 0, 0, 0, fmt.Errorf("no cursor image available")
|
||||
}
|
||||
defer cgImageRelease(cgImage)
|
||||
|
||||
w := int(cgImageGetWidth(cgImage))
|
||||
h := int(cgImageGetHeight(cgImage))
|
||||
if w <= 0 || h <= 0 {
|
||||
return nil, 0, 0, 0, fmt.Errorf("cursor has zero extent")
|
||||
}
|
||||
bytesPerRow := int(cgImageGetBytesPerRow(cgImage))
|
||||
bpp := int(cgImageGetBitsPerPixel(cgImage))
|
||||
if bpp != 32 {
|
||||
return nil, 0, 0, 0, fmt.Errorf("unsupported cursor bpp: %d", bpp)
|
||||
}
|
||||
provider := cgImageGetDataProvider(cgImage)
|
||||
if provider == 0 {
|
||||
return nil, 0, 0, 0, fmt.Errorf("cursor data provider missing")
|
||||
}
|
||||
cfData := cgDataProviderCopyData(provider)
|
||||
if cfData == 0 {
|
||||
return nil, 0, 0, 0, fmt.Errorf("cursor data copy failed")
|
||||
}
|
||||
defer cfRelease(cfData)
|
||||
dataLen := int(cfDataGetLength(cfData))
|
||||
dataPtr := cfDataGetBytePtr(cfData)
|
||||
if dataPtr == 0 || dataLen == 0 {
|
||||
return nil, 0, 0, 0, fmt.Errorf("cursor data empty")
|
||||
}
|
||||
src := unsafe.Slice((*byte)(unsafe.Pointer(dataPtr)), dataLen)
|
||||
|
||||
sum := maphash.Bytes(c.hashSeed, src)
|
||||
if c.cached != nil && sum == c.lastSum {
|
||||
return c.cached, 0, 0, c.serial, nil
|
||||
}
|
||||
|
||||
img := image.NewRGBA(image.Rect(0, 0, w, h))
|
||||
for y := 0; y < h; y++ {
|
||||
srcOff := y * bytesPerRow
|
||||
dstOff := y * w * 4
|
||||
for x := 0; x < w; x++ {
|
||||
si := srcOff + x*4
|
||||
di := dstOff + x*4
|
||||
img.Pix[di+0] = src[si+2]
|
||||
img.Pix[di+1] = src[si+1]
|
||||
img.Pix[di+2] = src[si+0]
|
||||
img.Pix[di+3] = src[si+3]
|
||||
}
|
||||
}
|
||||
|
||||
c.lastSum = sum
|
||||
c.cached = img
|
||||
c.serial++
|
||||
return img, 0, 0, c.serial, nil
|
||||
}
|
||||
|
||||
// Cursor on CGCapturer satisfies cursorSource. The cgCursor wrapper is
|
||||
// allocated lazily so a build that never asks for the cursor pays no cost.
|
||||
func (c *CGCapturer) Cursor() (*image.RGBA, int, int, uint64, error) {
|
||||
c.cursorOnce.Do(func() {
|
||||
c.cursor = newCGCursor()
|
||||
})
|
||||
return c.cursor.Cursor()
|
||||
}
|
||||
|
||||
// CursorPos returns the current global mouse location via CGEventCreate /
|
||||
// CGEventGetLocation. Coordinates are screen pixels in the main display.
|
||||
func (c *CGCapturer) CursorPos() (int, int, error) {
|
||||
if cgEventCreate == nil || cgEventGetLocation == nil {
|
||||
return 0, 0, fmt.Errorf("CGEvent location APIs unavailable")
|
||||
}
|
||||
ev := cgEventCreate(0)
|
||||
if ev == 0 {
|
||||
return 0, 0, fmt.Errorf("CGEventCreate returned nil")
|
||||
}
|
||||
defer cfRelease(ev)
|
||||
pt := cgEventGetLocation(ev)
|
||||
return int(pt.X), int(pt.Y), nil
|
||||
}
|
||||
|
||||
// Cursor on MacPoller forwards to the lazy CGCapturer. ensureCapturerLocked
|
||||
// returns an error when Screen Recording permission has not been granted;
|
||||
// in that case there is no usable cursor source either.
|
||||
func (p *MacPoller) Cursor() (*image.RGBA, int, int, uint64, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if err := p.ensureCapturerLocked(); err != nil {
|
||||
return nil, 0, 0, 0, err
|
||||
}
|
||||
return p.capturer.Cursor()
|
||||
}
|
||||
|
||||
// CursorPos forwards to the lazy CGCapturer.
|
||||
func (p *MacPoller) CursorPos() (int, int, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if err := p.ensureCapturerLocked(); err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
return p.capturer.CursorPos()
|
||||
}
|
||||
407
client/vnc/server/cursor_windows.go
Normal file
407
client/vnc/server/cursor_windows.go
Normal file
@@ -0,0 +1,407 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var (
|
||||
procGetCursorInfo = user32.NewProc("GetCursorInfo")
|
||||
procGetIconInfo = user32.NewProc("GetIconInfo")
|
||||
procGetObjectW = gdi32.NewProc("GetObjectW")
|
||||
procGetDIBits = gdi32.NewProc("GetDIBits")
|
||||
)
|
||||
|
||||
const (
|
||||
cursorShowing = 0x00000001
|
||||
diRgbColors = 0
|
||||
biRgb = 0
|
||||
dibSectionBytes = 40 // sizeof(BITMAPINFOHEADER)
|
||||
)
|
||||
|
||||
// hiddenHandle is a sentinel stored in cursorSampler.lastHandle while
|
||||
// Windows reports the cursor as hidden. It is not a valid HCURSOR value;
|
||||
// real handles never collide with this constant.
|
||||
const hiddenHandle = windows.Handle(^uintptr(0))
|
||||
|
||||
// transparentCursorImage returns a 1x1 fully transparent sprite. The
|
||||
// client renders this as "no cursor"; emitting it explicitly lets us
|
||||
// recover when an app un-hides the cursor a moment later.
|
||||
func transparentCursorImage() *image.RGBA {
|
||||
return image.NewRGBA(image.Rect(0, 0, 1, 1))
|
||||
}
|
||||
|
||||
type winPoint struct {
|
||||
X, Y int32
|
||||
}
|
||||
|
||||
type winCursorInfo struct {
|
||||
Size uint32
|
||||
Flags uint32
|
||||
Cursor windows.Handle
|
||||
PtPos winPoint
|
||||
}
|
||||
|
||||
type winIconInfo struct {
|
||||
FIcon int32
|
||||
XHotspot uint32
|
||||
YHotspot uint32
|
||||
HbmMask windows.Handle
|
||||
HbmColor windows.Handle
|
||||
}
|
||||
|
||||
type winBitmap struct {
|
||||
BmType int32
|
||||
BmWidth int32
|
||||
BmHeight int32
|
||||
BmWidthBytes int32
|
||||
BmPlanes uint16
|
||||
BmBitsPixel uint16
|
||||
BmBits uintptr
|
||||
}
|
||||
|
||||
type winBitmapInfoHeader struct {
|
||||
BiSize uint32
|
||||
BiWidth int32
|
||||
BiHeight int32
|
||||
BiPlanes uint16
|
||||
BiBitCount uint16
|
||||
BiCompression uint32
|
||||
BiSizeImage uint32
|
||||
BiXPelsPerMeter int32
|
||||
BiYPelsPerMeter int32
|
||||
BiClrUsed uint32
|
||||
BiClrImportant uint32
|
||||
}
|
||||
|
||||
// cursorSnapshot is the captured cursor state shared between the worker
|
||||
// (which polls the OS) and the session encoder (which reads it).
|
||||
type cursorSnapshot struct {
|
||||
img *image.RGBA
|
||||
hotX int
|
||||
hotY int
|
||||
posX int
|
||||
posY int
|
||||
hasPos bool
|
||||
serial uint64
|
||||
err error
|
||||
}
|
||||
|
||||
// cursorSampler captures the foreground process's cursor sprite via Win32
|
||||
// APIs. It must be called from a goroutine attached to the same window
|
||||
// station and desktop as the user session (the capture worker does this
|
||||
// via switchToInputDesktop). lastHandle dedupes per-shape work so we only
|
||||
// touch GDI when Windows hands us a new cursor.
|
||||
type cursorSampler struct {
|
||||
lastHandle windows.Handle
|
||||
serial uint64
|
||||
snapshot *cursorSnapshot
|
||||
}
|
||||
|
||||
// sample queries the current cursor and decodes a new sprite when Windows
|
||||
// reports a different HCURSOR than last time. Returns the current snapshot
|
||||
// regardless of whether anything changed; callers diff by serial.
|
||||
func (s *cursorSampler) sample() (*cursorSnapshot, error) {
|
||||
var ci winCursorInfo
|
||||
ci.Size = uint32(unsafe.Sizeof(ci))
|
||||
r, _, err := procGetCursorInfo.Call(uintptr(unsafe.Pointer(&ci)))
|
||||
if r == 0 {
|
||||
return nil, fmt.Errorf("GetCursorInfo: %w", err)
|
||||
}
|
||||
if ci.Flags&cursorShowing == 0 || ci.Cursor == 0 {
|
||||
// Cursor temporarily hidden by an app (text fields toggle it on
|
||||
// focus). Emit a 1x1 transparent sprite so the client renders no
|
||||
// cursor and stay armed for the next handle change rather than
|
||||
// treating this as a hard failure that would latch us off for
|
||||
// the session.
|
||||
if s.lastHandle == hiddenHandle {
|
||||
s.snapshot.posX = int(ci.PtPos.X)
|
||||
s.snapshot.posY = int(ci.PtPos.Y)
|
||||
s.snapshot.hasPos = true
|
||||
return s.snapshot, nil
|
||||
}
|
||||
s.lastHandle = hiddenHandle
|
||||
s.serial++
|
||||
s.snapshot = &cursorSnapshot{
|
||||
img: transparentCursorImage(),
|
||||
posX: int(ci.PtPos.X),
|
||||
posY: int(ci.PtPos.Y),
|
||||
hasPos: true,
|
||||
serial: s.serial,
|
||||
}
|
||||
return s.snapshot, nil
|
||||
}
|
||||
if ci.Cursor == s.lastHandle && s.snapshot != nil {
|
||||
s.snapshot.posX = int(ci.PtPos.X)
|
||||
s.snapshot.posY = int(ci.PtPos.Y)
|
||||
s.snapshot.hasPos = true
|
||||
return s.snapshot, nil
|
||||
}
|
||||
img, hotX, hotY, err := decodeCursor(ci.Cursor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.lastHandle = ci.Cursor
|
||||
s.serial++
|
||||
s.snapshot = &cursorSnapshot{
|
||||
img: img,
|
||||
hotX: hotX,
|
||||
hotY: hotY,
|
||||
posX: int(ci.PtPos.X),
|
||||
posY: int(ci.PtPos.Y),
|
||||
hasPos: true,
|
||||
serial: s.serial,
|
||||
}
|
||||
return s.snapshot, nil
|
||||
}
|
||||
|
||||
// decodeCursor extracts the sprite at hCur as RGBA along with the hotspot.
|
||||
// Color cursors are read from the colour bitmap with the AND mask combined
|
||||
// in for alpha. Monochrome cursors collapse the two halves of the mask
|
||||
// bitmap into a single visible sprite where the AND bit drives alpha.
|
||||
func decodeCursor(hCur windows.Handle) (*image.RGBA, int, int, error) {
|
||||
var info winIconInfo
|
||||
r, _, err := procGetIconInfo.Call(uintptr(hCur), uintptr(unsafe.Pointer(&info)))
|
||||
if r == 0 {
|
||||
return nil, 0, 0, fmt.Errorf("GetIconInfo: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if info.HbmMask != 0 {
|
||||
_, _, _ = procDeleteObject.Call(uintptr(info.HbmMask))
|
||||
}
|
||||
if info.HbmColor != 0 {
|
||||
_, _, _ = procDeleteObject.Call(uintptr(info.HbmColor))
|
||||
}
|
||||
}()
|
||||
hotX, hotY := int(info.XHotspot), int(info.YHotspot)
|
||||
if info.HbmColor != 0 {
|
||||
img, err := decodeColorCursor(info.HbmColor, info.HbmMask)
|
||||
if err != nil {
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
return img, hotX, hotY, nil
|
||||
}
|
||||
img, err := decodeMonoCursor(info.HbmMask)
|
||||
if err != nil {
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
return img, hotX, hotY, nil
|
||||
}
|
||||
|
||||
// readBitmap returns the BITMAP descriptor for hbm.
|
||||
func readBitmap(hbm windows.Handle) (winBitmap, error) {
|
||||
var bm winBitmap
|
||||
r, _, err := procGetObjectW.Call(uintptr(hbm), unsafe.Sizeof(bm), uintptr(unsafe.Pointer(&bm)))
|
||||
if r == 0 {
|
||||
return winBitmap{}, fmt.Errorf("GetObject: %w", err)
|
||||
}
|
||||
return bm, nil
|
||||
}
|
||||
|
||||
// dibCopy reads hbm as 32bpp top-down BGRA into a freshly allocated slice
|
||||
// matching w*h*4 bytes. The bitmap may be selected into the screen DC so
|
||||
// we use a memory DC to keep the call cheap.
|
||||
func dibCopy(hbm windows.Handle, w, h int32) ([]byte, error) {
|
||||
hdcScreen, _, _ := procGetDC.Call(0)
|
||||
if hdcScreen == 0 {
|
||||
return nil, fmt.Errorf("GetDC: failed")
|
||||
}
|
||||
defer func() { _, _, _ = procReleaseDC.Call(0, hdcScreen) }()
|
||||
hdcMem, _, _ := procCreateCompatDC.Call(hdcScreen)
|
||||
if hdcMem == 0 {
|
||||
return nil, fmt.Errorf("CreateCompatibleDC: failed")
|
||||
}
|
||||
defer func() { _, _, _ = procDeleteDC.Call(hdcMem) }()
|
||||
|
||||
var bih winBitmapInfoHeader
|
||||
bih.BiSize = dibSectionBytes
|
||||
bih.BiWidth = w
|
||||
bih.BiHeight = -h // top-down
|
||||
bih.BiPlanes = 1
|
||||
bih.BiBitCount = 32
|
||||
bih.BiCompression = biRgb
|
||||
|
||||
buf := make([]byte, int(w)*int(h)*4)
|
||||
r, _, err := procGetDIBits.Call(
|
||||
hdcMem,
|
||||
uintptr(hbm),
|
||||
0,
|
||||
uintptr(h),
|
||||
uintptr(unsafe.Pointer(&buf[0])),
|
||||
uintptr(unsafe.Pointer(&bih)),
|
||||
diRgbColors,
|
||||
)
|
||||
if r == 0 {
|
||||
return nil, fmt.Errorf("GetDIBits: %w", err)
|
||||
}
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
// decodeColorCursor reads a 32bpp colour cursor and folds the AND mask into
|
||||
// the alpha channel when the colour bitmap leaves it zero.
|
||||
func decodeColorCursor(hbmColor, hbmMask windows.Handle) (*image.RGBA, error) {
|
||||
bm, err := readBitmap(hbmColor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
w, h := bm.BmWidth, bm.BmHeight
|
||||
color, err := dibCopy(hbmColor, w, h)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var mask []byte
|
||||
if hbmMask != 0 {
|
||||
mask, _ = dibCopy(hbmMask, w, h)
|
||||
}
|
||||
hasAlpha := colorHasAlpha(color)
|
||||
img := image.NewRGBA(image.Rect(0, 0, int(w), int(h)))
|
||||
for y := int32(0); y < h; y++ {
|
||||
for x := int32(0); x < w; x++ {
|
||||
si := (y*w + x) * 4
|
||||
b := color[si]
|
||||
g := color[si+1]
|
||||
r := color[si+2]
|
||||
a := pixelAlpha(color[si+3], si, mask, hasAlpha)
|
||||
// Premultiply so the shared compositor can use the same
|
||||
// formula on every platform (X11 XFixes and macOS CG return
|
||||
// premultiplied bytes natively).
|
||||
if a != 255 && a != 0 {
|
||||
r = byte(uint32(r) * uint32(a) / 255)
|
||||
g = byte(uint32(g) * uint32(a) / 255)
|
||||
b = byte(uint32(b) * uint32(a) / 255)
|
||||
} else if a == 0 {
|
||||
r, g, b = 0, 0, 0
|
||||
}
|
||||
img.Pix[si+0] = r
|
||||
img.Pix[si+1] = g
|
||||
img.Pix[si+2] = b
|
||||
img.Pix[si+3] = a
|
||||
}
|
||||
}
|
||||
return img, nil
|
||||
}
|
||||
|
||||
// colorHasAlpha reports whether any pixel of a 32bpp BGRA buffer has a
|
||||
// non-zero alpha. Cursors authored without alpha leave the channel at 0
|
||||
// and rely on hbmMask for transparency.
|
||||
func colorHasAlpha(color []byte) bool {
|
||||
for i := 0; i < len(color); i += 4 {
|
||||
if color[i+3] != 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// pixelAlpha returns the effective alpha for a colour-cursor pixel. When
|
||||
// the source bitmap already has alpha we trust it; otherwise the AND mask
|
||||
// decides (1 = transparent, 0 = opaque). The 32bpp DIB stores each AND
|
||||
// bit as a 4-byte entry; the first byte carries the effective value.
|
||||
func pixelAlpha(colorA byte, si int32, mask []byte, hasAlpha bool) byte {
|
||||
if hasAlpha {
|
||||
return colorA
|
||||
}
|
||||
if mask != nil && mask[si] != 0 {
|
||||
return 0
|
||||
}
|
||||
return 255
|
||||
}
|
||||
|
||||
// decodeMonoCursor handles legacy 1bpp cursors where hbmMask is twice as
|
||||
// tall as the visible sprite: rows [0..h) are the AND mask and rows [h..2h)
|
||||
// are the XOR mask. We render the visible half into RGBA, treating
|
||||
// AND-mask=1 as transparent and the XOR bit as a black/white pixel.
|
||||
func decodeMonoCursor(hbmMask windows.Handle) (*image.RGBA, error) {
|
||||
bm, err := readBitmap(hbmMask)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
w, fullH := bm.BmWidth, bm.BmHeight
|
||||
if fullH%2 != 0 {
|
||||
return nil, fmt.Errorf("unexpected mono cursor shape: %dx%d", w, fullH)
|
||||
}
|
||||
h := fullH / 2
|
||||
data, err := dibCopy(hbmMask, w, fullH)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
img := image.NewRGBA(image.Rect(0, 0, int(w), int(h)))
|
||||
for y := int32(0); y < h; y++ {
|
||||
for x := int32(0); x < w; x++ {
|
||||
and := data[(y*w+x)*4]
|
||||
xor := data[((y+h)*w+x)*4]
|
||||
di := (y*w + x) * 4
|
||||
if and != 0 {
|
||||
img.Pix[di+3] = 0
|
||||
continue
|
||||
}
|
||||
c := byte(0)
|
||||
if xor != 0 {
|
||||
c = 255
|
||||
}
|
||||
img.Pix[di+0] = c
|
||||
img.Pix[di+1] = c
|
||||
img.Pix[di+2] = c
|
||||
img.Pix[di+3] = 255
|
||||
}
|
||||
}
|
||||
return img, nil
|
||||
}
|
||||
|
||||
// cursorState is the latest snapshot shared between the worker and
|
||||
// session readers.
|
||||
type cursorState struct {
|
||||
mu sync.Mutex
|
||||
snapshot *cursorSnapshot
|
||||
}
|
||||
|
||||
func (s *cursorState) store(snap *cursorSnapshot) {
|
||||
s.mu.Lock()
|
||||
s.snapshot = snap
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
func (s *cursorState) load() *cursorSnapshot {
|
||||
s.mu.Lock()
|
||||
snap := s.snapshot
|
||||
s.mu.Unlock()
|
||||
return snap
|
||||
}
|
||||
|
||||
// Cursor satisfies cursorSource by returning the latest snapshot the
|
||||
// capture worker decoded. The "no sample yet" and "cursor hidden" cases
|
||||
// return img=nil with no error so callers skip emission this cycle
|
||||
// without latching the source off for the rest of the session.
|
||||
func (c *DesktopCapturer) Cursor() (*image.RGBA, int, int, uint64, error) {
|
||||
snap := c.cursorState.load()
|
||||
if snap == nil {
|
||||
return nil, 0, 0, 0, nil
|
||||
}
|
||||
if snap.err != nil {
|
||||
return nil, 0, 0, 0, snap.err
|
||||
}
|
||||
return snap.img, snap.hotX, snap.hotY, snap.serial, nil
|
||||
}
|
||||
|
||||
// CursorPos returns the cursor screen position observed by the worker on
|
||||
// its last sample. Errors out if the worker hasn't yet captured a frame
|
||||
// or the most recent sample failed.
|
||||
func (c *DesktopCapturer) CursorPos() (int, int, error) {
|
||||
snap := c.cursorState.load()
|
||||
if snap == nil {
|
||||
return 0, 0, fmt.Errorf("cursor position not sampled yet")
|
||||
}
|
||||
if snap.err != nil {
|
||||
return 0, 0, snap.err
|
||||
}
|
||||
if !snap.hasPos {
|
||||
return 0, 0, fmt.Errorf("cursor position unavailable")
|
||||
}
|
||||
return snap.posX, snap.posY, nil
|
||||
}
|
||||
127
client/vnc/server/cursor_x11.go
Normal file
127
client/vnc/server/cursor_x11.go
Normal file
@@ -0,0 +1,127 @@
|
||||
//go:build unix && !darwin && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
"sync"
|
||||
|
||||
"github.com/jezek/xgb"
|
||||
"github.com/jezek/xgb/xfixes"
|
||||
)
|
||||
|
||||
// xfixesCursor reports the current X cursor sprite via the XFixes extension.
|
||||
// CursorSerial changes whenever the server picks a different cursor, so
|
||||
// callers can cache by serial without comparing pixels.
|
||||
type xfixesCursor struct {
|
||||
mu sync.Mutex
|
||||
conn *xgb.Conn
|
||||
// lastPosX/lastPosY hold the cursor screen position observed on the
|
||||
// most recent successful GetCursorImage. cursorPositionSource readers
|
||||
// share this value so we do not pay a second X round-trip per frame.
|
||||
lastPosX, lastPosY int
|
||||
hasPos bool
|
||||
// lastImg, lastHotX, lastHotY, lastSerial cache the most recent good
|
||||
// GetCursorImage result so transient failures (cursor hidden, server
|
||||
// briefly unresponsive) reuse the previous sprite instead of going
|
||||
// dark. Without this the encoder's compositing path drops to no-op as
|
||||
// soon as the cursor becomes momentarily unavailable.
|
||||
lastImg *image.RGBA
|
||||
lastHotX int
|
||||
lastHotY int
|
||||
lastSerial uint64
|
||||
}
|
||||
|
||||
// newXFixesCursor initialises the XFixes extension on conn. Returns an
|
||||
// error if the extension is unavailable; callers can fall back to no
|
||||
// cursor emission instead of asking on every frame.
|
||||
func newXFixesCursor(conn *xgb.Conn) (*xfixesCursor, error) {
|
||||
if err := xfixes.Init(conn); err != nil {
|
||||
return nil, fmt.Errorf("xfixes init: %w", err)
|
||||
}
|
||||
if _, err := xfixes.QueryVersion(conn, 4, 0).Reply(); err != nil {
|
||||
return nil, fmt.Errorf("xfixes query version: %w", err)
|
||||
}
|
||||
return &xfixesCursor{conn: conn}, nil
|
||||
}
|
||||
|
||||
// Cursor returns the current cursor sprite as RGBA along with its hotspot
|
||||
// and serial. Callers should treat an unchanged serial as "no update". On
|
||||
// a transient GetCursorImage failure the last cached sprite is returned
|
||||
// so compositing keeps painting the cursor instead of disappearing.
|
||||
func (c *xfixesCursor) Cursor() (*image.RGBA, int, int, uint64, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
reply, err := xfixes.GetCursorImage(c.conn).Reply()
|
||||
if err != nil {
|
||||
if c.lastImg != nil {
|
||||
return c.lastImg, c.lastHotX, c.lastHotY, c.lastSerial, nil
|
||||
}
|
||||
return nil, 0, 0, 0, fmt.Errorf("xfixes GetCursorImage: %w", err)
|
||||
}
|
||||
c.lastPosX, c.lastPosY, c.hasPos = int(reply.X), int(reply.Y), true
|
||||
w, h := int(reply.Width), int(reply.Height)
|
||||
if w <= 0 || h <= 0 {
|
||||
if c.lastImg != nil {
|
||||
return c.lastImg, c.lastHotX, c.lastHotY, c.lastSerial, nil
|
||||
}
|
||||
return nil, 0, 0, 0, fmt.Errorf("cursor has zero extent")
|
||||
}
|
||||
if len(reply.CursorImage) < w*h {
|
||||
if c.lastImg != nil {
|
||||
return c.lastImg, c.lastHotX, c.lastHotY, c.lastSerial, nil
|
||||
}
|
||||
return nil, 0, 0, 0, fmt.Errorf("cursor pixel buffer truncated: %d < %d", len(reply.CursorImage), w*h)
|
||||
}
|
||||
img := image.NewRGBA(image.Rect(0, 0, w, h))
|
||||
// XFixes packs each pixel as a uint32 in ARGB order with premultiplied
|
||||
// alpha. Unpack into the standard RGBA byte layout.
|
||||
for i, p := range reply.CursorImage[:w*h] {
|
||||
o := i * 4
|
||||
img.Pix[o+0] = byte(p >> 16)
|
||||
img.Pix[o+1] = byte(p >> 8)
|
||||
img.Pix[o+2] = byte(p)
|
||||
img.Pix[o+3] = byte(p >> 24)
|
||||
}
|
||||
c.lastImg = img
|
||||
c.lastHotX = int(reply.Xhot)
|
||||
c.lastHotY = int(reply.Yhot)
|
||||
c.lastSerial = uint64(reply.CursorSerial)
|
||||
return img, c.lastHotX, c.lastHotY, c.lastSerial, nil
|
||||
}
|
||||
|
||||
// Cursor on X11Capturer satisfies cursorSource. The XFixes binding is
|
||||
// created lazily on the same X connection used for screen capture; the
|
||||
// first init failure is latched so we stop asking on every frame.
|
||||
func (x *X11Capturer) Cursor() (*image.RGBA, int, int, uint64, error) {
|
||||
x.mu.Lock()
|
||||
if x.cursor == nil && x.cursorInitErr == nil {
|
||||
x.cursor, x.cursorInitErr = newXFixesCursor(x.conn)
|
||||
}
|
||||
cur := x.cursor
|
||||
initErr := x.cursorInitErr
|
||||
x.mu.Unlock()
|
||||
if initErr != nil {
|
||||
return nil, 0, 0, 0, initErr
|
||||
}
|
||||
return cur.Cursor()
|
||||
}
|
||||
|
||||
// CursorPos on X11Capturer returns the screen position from the most
|
||||
// recent successful Cursor() call. Sessions call Cursor() once per encode
|
||||
// cycle, so this stays current without a second X round-trip.
|
||||
func (x *X11Capturer) CursorPos() (int, int, error) {
|
||||
x.mu.Lock()
|
||||
cur := x.cursor
|
||||
x.mu.Unlock()
|
||||
if cur == nil {
|
||||
return 0, 0, fmt.Errorf("cursor source not initialised")
|
||||
}
|
||||
cur.mu.Lock()
|
||||
defer cur.mu.Unlock()
|
||||
if !cur.hasPos {
|
||||
return 0, 0, fmt.Errorf("cursor position not sampled yet")
|
||||
}
|
||||
return cur.lastPosX, cur.lastPosY, nil
|
||||
}
|
||||
159
client/vnc/server/extclipboard.go
Normal file
159
client/vnc/server/extclipboard.go
Normal file
@@ -0,0 +1,159 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/zlib"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// ExtendedClipboard is an RFB community extension (pseudo-encoding
|
||||
// 0xC0A1E5CE) that replaces legacy CutText with a Caps/Notify/Request/
|
||||
// Provide/Peek handshake. Wins versus legacy CutText:
|
||||
// - UTF-8 text format (legacy is Latin-1).
|
||||
// - Pull-based: a Notify announces "I have new content", the peer fetches
|
||||
// via Request only when it actually needs the data. Saves bandwidth on
|
||||
// high-latency transports versus pushing every change.
|
||||
// - zlib-compressed payloads.
|
||||
// - Caps negotiation so each side knows the other's per-format max size.
|
||||
//
|
||||
// The extension reuses message opcodes 3 (ServerCutText) and 6 (ClientCutText)
|
||||
// and signals "extended" by encoding the length field as a negative int32;
|
||||
// the absolute value is the payload size in bytes. The first 4 bytes of
|
||||
// payload are a flags word: top byte is the action, low 16 bits are the
|
||||
// format mask.
|
||||
const pseudoEncExtendedClipboard = -1063131698 // 0xC0A1E5CE as int32
|
||||
|
||||
const (
|
||||
extClipActionCaps uint32 = 0x01000000
|
||||
extClipActionRequest uint32 = 0x02000000
|
||||
extClipActionPeek uint32 = 0x04000000
|
||||
extClipActionNotify uint32 = 0x08000000
|
||||
extClipActionProvide uint32 = 0x10000000
|
||||
extClipActionMask uint32 = 0x1F000000
|
||||
|
||||
extClipFormatText uint32 = 0x00000001
|
||||
extClipFormatRTF uint32 = 0x00000002
|
||||
extClipFormatHTML uint32 = 0x00000004
|
||||
extClipFormatDIB uint32 = 0x00000008
|
||||
extClipFormatFiles uint32 = 0x00000010
|
||||
extClipFormatMask uint32 = 0x0000FFFF
|
||||
|
||||
// extClipMaxText caps our accepted text payload. Mirrors the legacy
|
||||
// maxCutTextBytes (1 MiB); advertised in Caps and enforced on Provide.
|
||||
extClipMaxText = maxCutTextBytes
|
||||
|
||||
// extClipMaxPayload bounds the raw on-wire payload we will read for an
|
||||
// extended CutText message. Includes flags header, length prefixes, NUL,
|
||||
// and zlib framing overhead on top of the text body.
|
||||
extClipMaxPayload = extClipMaxText + 1024
|
||||
)
|
||||
|
||||
// buildExtClipCaps emits the Caps payload. The flags word advertises every
|
||||
// action we support in the high byte (Caps + Request + Peek + Notify +
|
||||
// Provide) and every format we accept in the low 16 bits. Clients use
|
||||
// these action bits to decide whether to auto-Request on Notify; without
|
||||
// Request in our Caps a conforming client silently drops our Notify
|
||||
// messages. After the flags word we emit one uint32 max size per format
|
||||
// bit set, in ascending bit order.
|
||||
func buildExtClipCaps() []byte {
|
||||
flags := extClipActionCaps | extClipActionRequest | extClipActionPeek |
|
||||
extClipActionNotify | extClipActionProvide | extClipFormatText
|
||||
payload := make([]byte, 4+4)
|
||||
binary.BigEndian.PutUint32(payload[0:4], flags)
|
||||
binary.BigEndian.PutUint32(payload[4:8], uint32(extClipMaxText))
|
||||
return payload
|
||||
}
|
||||
|
||||
// buildExtClipNotify emits a Notify announcing that we have new clipboard
|
||||
// content available in the given format mask. No data is shipped; the peer
|
||||
// pulls via Request when it actually needs to paste.
|
||||
func buildExtClipNotify(formats uint32) []byte {
|
||||
payload := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(payload, extClipActionNotify|formats)
|
||||
return payload
|
||||
}
|
||||
|
||||
// buildExtClipRequest emits a Request asking the peer to send Provide for
|
||||
// the given format mask. Sent in response to an inbound Notify.
|
||||
func buildExtClipRequest(formats uint32) []byte {
|
||||
payload := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(payload, extClipActionRequest|formats)
|
||||
return payload
|
||||
}
|
||||
|
||||
// buildExtClipProvideText emits a Provide carrying UTF-8 text. The inner
|
||||
// stream (4-byte length including the trailing NUL, then UTF-8 bytes, then
|
||||
// NUL) is zlib-compressed; each Provide uses an independent zlib context
|
||||
// per the extension spec. Rejects oversized input so a caller bug can't
|
||||
// produce a payload larger than the size advertised in our Caps.
|
||||
func buildExtClipProvideText(text string) ([]byte, error) {
|
||||
if len(text)+1 > extClipMaxText {
|
||||
return nil, fmt.Errorf("clipboard text exceeds extClipMaxText (%d > %d)", len(text)+1, extClipMaxText)
|
||||
}
|
||||
body := make([]byte, 0, 4+len(text)+1)
|
||||
var lenBuf [4]byte
|
||||
binary.BigEndian.PutUint32(lenBuf[:], uint32(len(text)+1))
|
||||
body = append(body, lenBuf[:]...)
|
||||
body = append(body, text...)
|
||||
body = append(body, 0)
|
||||
|
||||
var compressed bytes.Buffer
|
||||
zw := zlib.NewWriter(&compressed)
|
||||
if _, err := zw.Write(body); err != nil {
|
||||
return nil, fmt.Errorf("zlib write: %w", err)
|
||||
}
|
||||
if err := zw.Close(); err != nil {
|
||||
return nil, fmt.Errorf("zlib close: %w", err)
|
||||
}
|
||||
|
||||
payload := make([]byte, 4+compressed.Len())
|
||||
binary.BigEndian.PutUint32(payload[0:4], extClipActionProvide|extClipFormatText)
|
||||
copy(payload[4:], compressed.Bytes())
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
// parseExtClipProvideText decompresses a Provide payload (the bytes after
|
||||
// the 4-byte flags header) and returns the UTF-8 text record if the text
|
||||
// format bit is set. Records for other formats are skipped. The trailing
|
||||
// NUL byte the spec appends to text records is stripped.
|
||||
func parseExtClipProvideText(flags uint32, payload []byte) (string, error) {
|
||||
zr, err := zlib.NewReader(bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("zlib reader: %w", err)
|
||||
}
|
||||
defer zr.Close()
|
||||
|
||||
limited := io.LimitReader(zr, int64(extClipMaxText)+16)
|
||||
var text string
|
||||
for bit := uint32(1); bit <= extClipFormatFiles; bit <<= 1 {
|
||||
if flags&bit == 0 {
|
||||
continue
|
||||
}
|
||||
var sizeBuf [4]byte
|
||||
if _, err := io.ReadFull(limited, sizeBuf[:]); err != nil {
|
||||
if bit == extClipFormatText && err == io.EOF {
|
||||
return "", nil
|
||||
}
|
||||
return "", fmt.Errorf("read record size: %w", err)
|
||||
}
|
||||
size := binary.BigEndian.Uint32(sizeBuf[:])
|
||||
if size > uint32(extClipMaxText) {
|
||||
return "", fmt.Errorf("record too large: %d", size)
|
||||
}
|
||||
rec := make([]byte, size)
|
||||
if _, err := io.ReadFull(limited, rec); err != nil {
|
||||
return "", fmt.Errorf("read record: %w", err)
|
||||
}
|
||||
if bit == extClipFormatText {
|
||||
if len(rec) > 0 && rec[len(rec)-1] == 0 {
|
||||
rec = rec[:len(rec)-1]
|
||||
}
|
||||
text = string(rec)
|
||||
}
|
||||
}
|
||||
return text, nil
|
||||
}
|
||||
102
client/vnc/server/extclipboard_test.go
Normal file
102
client/vnc/server/extclipboard_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBuildExtClipCaps(t *testing.T) {
|
||||
payload := buildExtClipCaps()
|
||||
require.Len(t, payload, 8, "Caps with one format should be 4 bytes flags + 4 bytes size")
|
||||
|
||||
flags := binary.BigEndian.Uint32(payload[0:4])
|
||||
// Clients check individual action bits in our Caps to decide whether to
|
||||
// auto-Request on Notify, so all supported actions must be advertised.
|
||||
assert.NotZero(t, flags&extClipActionCaps, "Caps action bit must be set")
|
||||
assert.NotZero(t, flags&extClipActionRequest, "Request action bit must be set")
|
||||
assert.NotZero(t, flags&extClipActionPeek, "Peek action bit must be set")
|
||||
assert.NotZero(t, flags&extClipActionNotify, "Notify action bit must be set")
|
||||
assert.NotZero(t, flags&extClipActionProvide, "Provide action bit must be set")
|
||||
assert.Equal(t, extClipFormatText, flags&extClipFormatMask, "should advertise text format")
|
||||
|
||||
maxSize := binary.BigEndian.Uint32(payload[4:8])
|
||||
assert.Equal(t, uint32(extClipMaxText), maxSize, "should advertise extClipMaxText")
|
||||
}
|
||||
|
||||
func TestBuildExtClipNotify(t *testing.T) {
|
||||
payload := buildExtClipNotify(extClipFormatText)
|
||||
require.Len(t, payload, 4)
|
||||
flags := binary.BigEndian.Uint32(payload)
|
||||
assert.Equal(t, extClipActionNotify, flags&extClipActionMask)
|
||||
assert.Equal(t, extClipFormatText, flags&extClipFormatMask)
|
||||
}
|
||||
|
||||
func TestBuildExtClipRequest(t *testing.T) {
|
||||
payload := buildExtClipRequest(extClipFormatText)
|
||||
require.Len(t, payload, 4)
|
||||
flags := binary.BigEndian.Uint32(payload)
|
||||
assert.Equal(t, extClipActionRequest, flags&extClipActionMask)
|
||||
assert.Equal(t, extClipFormatText, flags&extClipFormatMask)
|
||||
}
|
||||
|
||||
func TestExtClipProvideRoundTripASCII(t *testing.T) {
|
||||
const original = "hello world"
|
||||
payload, err := buildExtClipProvideText(original)
|
||||
require.NoError(t, err)
|
||||
|
||||
flags := binary.BigEndian.Uint32(payload[0:4])
|
||||
require.Equal(t, extClipActionProvide, flags&extClipActionMask)
|
||||
require.Equal(t, extClipFormatText, flags&extClipFormatMask)
|
||||
|
||||
text, err := parseExtClipProvideText(flags, payload[4:])
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, original, text)
|
||||
}
|
||||
|
||||
func TestExtClipProvideRoundTripUTF8(t *testing.T) {
|
||||
original := "héllo 🦀 世界"
|
||||
payload, err := buildExtClipProvideText(original)
|
||||
require.NoError(t, err)
|
||||
|
||||
flags := binary.BigEndian.Uint32(payload[0:4])
|
||||
text, err := parseExtClipProvideText(flags, payload[4:])
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, original, text, "UTF-8 should round-trip without mangling")
|
||||
}
|
||||
|
||||
func TestExtClipProvideRoundTripEmpty(t *testing.T) {
|
||||
payload, err := buildExtClipProvideText("")
|
||||
require.NoError(t, err)
|
||||
|
||||
flags := binary.BigEndian.Uint32(payload[0:4])
|
||||
text, err := parseExtClipProvideText(flags, payload[4:])
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, text)
|
||||
}
|
||||
|
||||
func TestExtClipProvideRoundTripLarge(t *testing.T) {
|
||||
original := strings.Repeat("abcd", 200000) // 800 KiB, below cap
|
||||
payload, err := buildExtClipProvideText(original)
|
||||
require.NoError(t, err)
|
||||
assert.Less(t, len(payload), len(original)/2,
|
||||
"highly repetitive text should compress significantly")
|
||||
|
||||
flags := binary.BigEndian.Uint32(payload[0:4])
|
||||
text, err := parseExtClipProvideText(flags, payload[4:])
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, original, text)
|
||||
}
|
||||
|
||||
func TestParseExtClipProvideTextRejectsOversized(t *testing.T) {
|
||||
var fakePayload [4]byte
|
||||
// 4 bytes of zlib-compressed garbage won't decode; we want to ensure we
|
||||
// don't panic, not that we accept it.
|
||||
_, err := parseExtClipProvideText(extClipActionProvide|extClipFormatText, fakePayload[:])
|
||||
assert.Error(t, err)
|
||||
}
|
||||
865
client/vnc/server/input_darwin.go
Normal file
865
client/vnc/server/input_darwin.go
Normal file
@@ -0,0 +1,865 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Core Graphics event constants.
|
||||
const (
|
||||
kCGEventSourceStateCombinedSessionState int32 = 0
|
||||
|
||||
kCGEventLeftMouseDown int32 = 1
|
||||
kCGEventLeftMouseUp int32 = 2
|
||||
kCGEventRightMouseDown int32 = 3
|
||||
kCGEventRightMouseUp int32 = 4
|
||||
kCGEventMouseMoved int32 = 5
|
||||
kCGEventLeftMouseDragged int32 = 6
|
||||
kCGEventRightMouseDragged int32 = 7
|
||||
kCGEventKeyDown int32 = 10
|
||||
kCGEventKeyUp int32 = 11
|
||||
kCGEventFlagsChanged int32 = 12
|
||||
kCGEventOtherMouseDown int32 = 25
|
||||
kCGEventOtherMouseUp int32 = 26
|
||||
|
||||
kCGMouseButtonLeft int32 = 0
|
||||
kCGMouseButtonRight int32 = 1
|
||||
kCGMouseButtonCenter int32 = 2
|
||||
|
||||
kCGHIDEventTap int32 = 0
|
||||
|
||||
// kCGEventFlagMaskSecondaryFn is the CGEventFlags bit Apple sets when
|
||||
// a key was activated via the Fn modifier on internal keyboards. The
|
||||
// navigation cluster (ForwardDelete, Home, End, PageUp, PageDown,
|
||||
// Help/Insert, arrows) lives in the Fn-shifted region of an Apple
|
||||
// keyboard, so synthesising those keycodes without this bit leaves the
|
||||
// system in a confused "Fn implied" state where the next plain
|
||||
// letter is treated as a menu accelerator.
|
||||
kCGEventFlagMaskSecondaryFn uint64 = 0x00800000
|
||||
|
||||
// kCGMouseEventClickState (event field 1) tells macOS how many
|
||||
// consecutive clicks of this button have happened. Without it, a
|
||||
// double click looks like two independent single clicks and apps
|
||||
// never see the dblclick (window-bar maximize, text word-select, ...).
|
||||
kCGMouseEventClickState int32 = 1
|
||||
|
||||
// doubleClickWindow is the upper bound on the gap between two
|
||||
// down events that still counts as a multi-click. macOS reads the
|
||||
// user's setting from CGEventSourceGetDoubleClickInterval; 500ms is
|
||||
// the default and works as a safe injection-side ceiling.
|
||||
doubleClickWindow = 500 * time.Millisecond
|
||||
|
||||
// IOKit power management constants.
|
||||
kIOPMUserActiveLocal int32 = 0
|
||||
kIOPMAssertionLevelOn uint32 = 255
|
||||
kCFStringEncodingUTF8 uint32 = 0x08000100
|
||||
)
|
||||
|
||||
var darwinInputOnce sync.Once
|
||||
|
||||
var (
|
||||
cgEventSourceCreate func(int32) uintptr
|
||||
cgEventCreateKeyboardEvent func(uintptr, uint16, bool) uintptr
|
||||
// CGEventCreateMouseEvent takes CGPoint as two separate float64 args.
|
||||
// purego can't handle array/struct types but individual float64s work.
|
||||
cgEventCreateMouseEvent func(uintptr, int32, float64, float64, int32) uintptr
|
||||
cgEventPost func(int32, uintptr)
|
||||
cgEventSetIntegerValueField func(uintptr, int32, int64)
|
||||
cgEventSetFlags func(uintptr, uint64)
|
||||
cgEventSetType func(uintptr, int32)
|
||||
cgEventCreateForInput func(uintptr) uintptr
|
||||
|
||||
// CGEventCreateScrollWheelEvent is variadic, call via SyscallN.
|
||||
cgEventCreateScrollWheelEventAddr uintptr
|
||||
|
||||
axIsProcessTrusted func() bool
|
||||
// axIsProcessTrustedWithOptions takes a CFDictionary; when the dict's
|
||||
// kAXTrustedCheckOptionPrompt key is true, macOS shows the native
|
||||
// Accessibility prompt with an "Open System Settings" button the
|
||||
// first time the process asks. The bare AXIsProcessTrusted variant is
|
||||
// a silent check that never prompts.
|
||||
axIsProcessTrustedWithOptions func(uintptr) bool
|
||||
// cfDictionaryCreate builds the options dictionary above.
|
||||
cfDictionaryCreate func(uintptr, *uintptr, *uintptr, int64, uintptr, uintptr) uintptr
|
||||
// cfBooleanTrue is the global CF boolean we cache from a Dlsym lookup.
|
||||
cfBooleanTrue uintptr
|
||||
// axTrustedCheckOptionPromptCFStr is the option key for the dict.
|
||||
axTrustedCheckOptionPromptCFStr uintptr
|
||||
// kCFTypeDictionaryKey/Value CallBacks: standard CF retain/release
|
||||
// callback tables. Required so the dict properly manages refcounts on
|
||||
// the CFString key and CFBoolean value.
|
||||
kCFTypeDictionaryKeyCallBacksAddr uintptr
|
||||
kCFTypeDictionaryValueCallBacksAddr uintptr
|
||||
|
||||
// IOKit power-management bindings used to wake the display and inhibit
|
||||
// idle sleep while a VNC client is driving input.
|
||||
iopmAssertionDeclareUserActivity func(uintptr, int32, *uint32) int32
|
||||
iopmAssertionCreateWithName func(uintptr, uint32, uintptr, *uint32) int32
|
||||
iopmAssertionRelease func(uint32) int32
|
||||
cfStringCreateWithCString func(uintptr, string, uint32) uintptr
|
||||
|
||||
// Cached CFStrings for assertion name and idle-sleep type.
|
||||
pmAssertionNameCFStr uintptr
|
||||
pmPreventIdleDisplayCFStr uintptr
|
||||
|
||||
// Assertion IDs. userActivityID is reused across input events so repeated
|
||||
// calls refresh the same assertion rather than create new ones.
|
||||
pmMu sync.Mutex
|
||||
userActivityID uint32
|
||||
preventSleepID uint32
|
||||
preventSleepHeld bool
|
||||
// preventSleepRef tracks the refcount of held assertions across
|
||||
// concurrent injectors and sessions.
|
||||
preventSleepRef int
|
||||
|
||||
darwinInputReady bool
|
||||
darwinEventSource uintptr
|
||||
)
|
||||
|
||||
func initDarwinInput() {
|
||||
darwinInputOnce.Do(func() {
|
||||
cg, err := purego.Dlopen("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
log.Debugf("load CoreGraphics for input: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
purego.RegisterLibFunc(&cgEventSourceCreate, cg, "CGEventSourceCreate")
|
||||
purego.RegisterLibFunc(&cgEventCreateKeyboardEvent, cg, "CGEventCreateKeyboardEvent")
|
||||
purego.RegisterLibFunc(&cgEventCreateMouseEvent, cg, "CGEventCreateMouseEvent")
|
||||
purego.RegisterLibFunc(&cgEventPost, cg, "CGEventPost")
|
||||
purego.RegisterLibFunc(&cgEventSetIntegerValueField, cg, "CGEventSetIntegerValueField")
|
||||
purego.RegisterLibFunc(&cgEventSetFlags, cg, "CGEventSetFlags")
|
||||
purego.RegisterLibFunc(&cgEventSetType, cg, "CGEventSetType")
|
||||
purego.RegisterLibFunc(&cgEventCreateForInput, cg, "CGEventCreate")
|
||||
|
||||
sym, err := purego.Dlsym(cg, "CGEventCreateScrollWheelEvent")
|
||||
if err == nil {
|
||||
cgEventCreateScrollWheelEventAddr = sym
|
||||
}
|
||||
|
||||
if ax, err := purego.Dlopen("/System/Library/Frameworks/ApplicationServices.framework/ApplicationServices", purego.RTLD_NOW|purego.RTLD_GLOBAL); err == nil {
|
||||
if sym, err := purego.Dlsym(ax, "AXIsProcessTrusted"); err == nil {
|
||||
purego.RegisterFunc(&axIsProcessTrusted, sym)
|
||||
}
|
||||
if sym, err := purego.Dlsym(ax, "AXIsProcessTrustedWithOptions"); err == nil {
|
||||
purego.RegisterFunc(&axIsProcessTrustedWithOptions, sym)
|
||||
}
|
||||
}
|
||||
|
||||
// initPowerAssertions registers cfStringCreateWithCString, which
|
||||
// initCFDictionarySymbols then uses to build the AX prompt key.
|
||||
initPowerAssertions()
|
||||
initCFDictionarySymbols()
|
||||
|
||||
darwinInputReady = true
|
||||
})
|
||||
}
|
||||
|
||||
// initCFDictionarySymbols loads the CF symbols needed to build the
|
||||
// options dictionary for AXIsProcessTrustedWithOptions. Best-effort:
|
||||
// failure here just leaves axIsProcessTrustedWithOptions unusable and we
|
||||
// fall back to the silent check.
|
||||
func initCFDictionarySymbols() {
|
||||
cf, err := purego.Dlopen("/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
log.Debugf("load CoreFoundation for AX prompt dict: %v", err)
|
||||
return
|
||||
}
|
||||
if sym, err := purego.Dlsym(cf, "CFDictionaryCreate"); err == nil {
|
||||
purego.RegisterFunc(&cfDictionaryCreate, sym)
|
||||
}
|
||||
if sym, err := purego.Dlsym(cf, "kCFTypeDictionaryKeyCallBacks"); err == nil {
|
||||
kCFTypeDictionaryKeyCallBacksAddr = sym
|
||||
}
|
||||
if sym, err := purego.Dlsym(cf, "kCFTypeDictionaryValueCallBacks"); err == nil {
|
||||
kCFTypeDictionaryValueCallBacksAddr = sym
|
||||
}
|
||||
if sym, err := purego.Dlsym(cf, "kCFBooleanTrue"); err == nil {
|
||||
// kCFBooleanTrue is a pointer-to-pointer (CFBooleanRef stored at the
|
||||
// symbol address). Dereference once to get the actual CFBoolean.
|
||||
cfBooleanTrue = *(*uintptr)(unsafe.Pointer(sym))
|
||||
}
|
||||
if cfStringCreateWithCString != nil {
|
||||
axTrustedCheckOptionPromptCFStr = cfStringCreateWithCString(0, "AXTrustedCheckOptionPrompt", kCFStringEncodingUTF8)
|
||||
}
|
||||
}
|
||||
|
||||
func initPowerAssertions() {
|
||||
iokit, err := purego.Dlopen("/System/Library/Frameworks/IOKit.framework/IOKit", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
log.Debugf("load IOKit: %v", err)
|
||||
return
|
||||
}
|
||||
cf, err := purego.Dlopen("/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
log.Debugf("load CoreFoundation for power assertions: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
purego.RegisterLibFunc(&cfStringCreateWithCString, cf, "CFStringCreateWithCString")
|
||||
purego.RegisterLibFunc(&iopmAssertionDeclareUserActivity, iokit, "IOPMAssertionDeclareUserActivity")
|
||||
purego.RegisterLibFunc(&iopmAssertionCreateWithName, iokit, "IOPMAssertionCreateWithName")
|
||||
purego.RegisterLibFunc(&iopmAssertionRelease, iokit, "IOPMAssertionRelease")
|
||||
|
||||
pmAssertionNameCFStr = cfStringCreateWithCString(0, "NetBird VNC input", kCFStringEncodingUTF8)
|
||||
pmPreventIdleDisplayCFStr = cfStringCreateWithCString(0, "PreventUserIdleDisplaySleep", kCFStringEncodingUTF8)
|
||||
}
|
||||
|
||||
// wakeDisplay declares user activity so macOS treats the synthesized input as
|
||||
// real HID activity, waking the display if it is asleep. Called on every key
|
||||
// and pointer event; the kernel coalesces repeated calls cheaply.
|
||||
func wakeDisplay() {
|
||||
if iopmAssertionDeclareUserActivity == nil || pmAssertionNameCFStr == 0 {
|
||||
return
|
||||
}
|
||||
pmMu.Lock()
|
||||
defer pmMu.Unlock()
|
||||
id := userActivityID
|
||||
r := iopmAssertionDeclareUserActivity(pmAssertionNameCFStr, kIOPMUserActiveLocal, &id)
|
||||
if r != 0 {
|
||||
log.Tracef("IOPMAssertionDeclareUserActivity returned %d", r)
|
||||
return
|
||||
}
|
||||
userActivityID = id
|
||||
}
|
||||
|
||||
// holdPreventIdleSleep creates an assertion that keeps the display from going
|
||||
// idle-to-sleep while a VNC session is active. Reference-counted so multiple
|
||||
// concurrent sessions don't yank the assertion when one of them releases.
|
||||
func holdPreventIdleSleep() {
|
||||
if iopmAssertionCreateWithName == nil || pmPreventIdleDisplayCFStr == 0 || pmAssertionNameCFStr == 0 {
|
||||
return
|
||||
}
|
||||
pmMu.Lock()
|
||||
defer pmMu.Unlock()
|
||||
preventSleepRef++
|
||||
if preventSleepRef > 1 {
|
||||
return
|
||||
}
|
||||
var id uint32
|
||||
r := iopmAssertionCreateWithName(pmPreventIdleDisplayCFStr, kIOPMAssertionLevelOn, pmAssertionNameCFStr, &id)
|
||||
if r != 0 {
|
||||
log.Debugf("IOPMAssertionCreateWithName returned %d", r)
|
||||
// Reset the refcount on failure so a later successful hold can take it.
|
||||
preventSleepRef = 0
|
||||
return
|
||||
}
|
||||
preventSleepID = id
|
||||
preventSleepHeld = true
|
||||
}
|
||||
|
||||
// releasePreventIdleSleep decrements the assertion refcount and only drops
|
||||
// the actual IOKit assertion on the final release.
|
||||
func releasePreventIdleSleep() {
|
||||
if iopmAssertionRelease == nil {
|
||||
return
|
||||
}
|
||||
pmMu.Lock()
|
||||
defer pmMu.Unlock()
|
||||
if !preventSleepHeld || preventSleepRef == 0 {
|
||||
return
|
||||
}
|
||||
preventSleepRef--
|
||||
if preventSleepRef > 0 {
|
||||
return
|
||||
}
|
||||
if r := iopmAssertionRelease(preventSleepID); r != 0 {
|
||||
log.Debugf("IOPMAssertionRelease returned %d", r)
|
||||
}
|
||||
preventSleepHeld = false
|
||||
preventSleepID = 0
|
||||
}
|
||||
|
||||
func ensureEventSource() uintptr {
|
||||
if darwinEventSource != 0 {
|
||||
return darwinEventSource
|
||||
}
|
||||
darwinEventSource = cgEventSourceCreate(kCGEventSourceStateCombinedSessionState)
|
||||
return darwinEventSource
|
||||
}
|
||||
|
||||
// MacInputInjector injects keyboard and mouse events via Core Graphics.
|
||||
type MacInputInjector struct {
|
||||
lastButtons uint16
|
||||
pbcopyPath string
|
||||
pbpastePath string
|
||||
// clickCount[i] / clickAt[i] track the multi-click sequence for
|
||||
// button i (0=left, 1=right, 2=middle). macOS apps reconstruct
|
||||
// double/triple click semantics from the kCGMouseEventClickState
|
||||
// field on each posted event, not from event timing.
|
||||
clickCount [5]int64
|
||||
clickAt [5]time.Time
|
||||
}
|
||||
|
||||
// NewMacInputInjector creates a macOS input injector.
|
||||
func NewMacInputInjector() (*MacInputInjector, error) {
|
||||
initDarwinInput()
|
||||
if !darwinInputReady {
|
||||
return nil, fmt.Errorf("CoreGraphics not available for input injection")
|
||||
}
|
||||
checkMacPermissions()
|
||||
|
||||
m := &MacInputInjector{}
|
||||
if path, err := exec.LookPath("pbcopy"); err == nil {
|
||||
m.pbcopyPath = path
|
||||
}
|
||||
if path, err := exec.LookPath("pbpaste"); err == nil {
|
||||
m.pbpastePath = path
|
||||
}
|
||||
if m.pbcopyPath == "" || m.pbpastePath == "" {
|
||||
log.Debugf("clipboard tools not found (pbcopy=%q, pbpaste=%q)", m.pbcopyPath, m.pbpastePath)
|
||||
}
|
||||
|
||||
holdPreventIdleSleep()
|
||||
|
||||
log.Info("macOS input injector ready")
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// checkMacPermissions probes Accessibility access. Prefers the prompting
|
||||
// variant of AXIsProcessTrusted: when the process is not yet trusted,
|
||||
// macOS shows its native "would like to control your computer" dialog
|
||||
// with an "Open System Settings" button. The silent variant is the
|
||||
// fallback when the prompting symbol or its CF dictionary plumbing
|
||||
// couldn't be loaded.
|
||||
func checkMacPermissions() {
|
||||
if !axProcessIsTrusted() {
|
||||
log.Warn("Accessibility permission not granted. Input injection will not work. " +
|
||||
"Approve the prompt or grant in System Settings > Privacy & Security > Accessibility.")
|
||||
openPrivacyPane("Privacy_Accessibility")
|
||||
}
|
||||
}
|
||||
|
||||
// axProcessIsTrusted asks macOS whether netbird has Accessibility access,
|
||||
// and triggers the native prompt the first time when not trusted. Returns
|
||||
// the current trust status either way.
|
||||
func axProcessIsTrusted() bool {
|
||||
if axIsProcessTrustedWithOptions != nil &&
|
||||
cfDictionaryCreate != nil &&
|
||||
axTrustedCheckOptionPromptCFStr != 0 &&
|
||||
cfBooleanTrue != 0 &&
|
||||
kCFTypeDictionaryKeyCallBacksAddr != 0 &&
|
||||
kCFTypeDictionaryValueCallBacksAddr != 0 {
|
||||
keys := [1]uintptr{axTrustedCheckOptionPromptCFStr}
|
||||
values := [1]uintptr{cfBooleanTrue}
|
||||
dict := cfDictionaryCreate(0, &keys[0], &values[0], 1,
|
||||
kCFTypeDictionaryKeyCallBacksAddr,
|
||||
kCFTypeDictionaryValueCallBacksAddr)
|
||||
if dict != 0 {
|
||||
return axIsProcessTrustedWithOptions(dict)
|
||||
}
|
||||
}
|
||||
if axIsProcessTrusted != nil {
|
||||
return axIsProcessTrusted()
|
||||
}
|
||||
// Symbol load failed entirely. Assume trusted so we don't spam the
|
||||
// log every cycle; capture/inject calls will report concrete errors
|
||||
// if access really is missing.
|
||||
return true
|
||||
}
|
||||
|
||||
// openPrivacyPane opens the relevant pane of System Settings so the user
|
||||
// can toggle the permission without navigating manually. The
|
||||
// x-apple.systempreferences URL scheme works on every macOS release from
|
||||
// 10.10 onward; the per-pane anchor (Privacy_Accessibility, Privacy_ScreenCapture)
|
||||
// is what System Settings/Preferences uses to land on the right row.
|
||||
func openPrivacyPane(pane string) {
|
||||
url := "x-apple.systempreferences:com.apple.preference.security?" + pane
|
||||
if err := exec.Command("open", url).Start(); err != nil {
|
||||
log.Debugf("open privacy pane %s: %v", pane, err)
|
||||
}
|
||||
}
|
||||
|
||||
// InjectKey simulates a key press or release.
|
||||
func (m *MacInputInjector) InjectKey(keysym uint32, down bool) {
|
||||
wakeDisplay()
|
||||
src := ensureEventSource()
|
||||
if src == 0 {
|
||||
return
|
||||
}
|
||||
keycode := keysymToMacKeycode(keysym)
|
||||
if keycode == 0xFFFF {
|
||||
return
|
||||
}
|
||||
m.postMacKey(src, keycode, down)
|
||||
}
|
||||
|
||||
// InjectKeyScancode injects using the QEMU scancode, mapped via the
|
||||
// qemuToMacVK table to Apple's virtual-keycode space. Apple uses an
|
||||
// entirely different scheme from PC AT scancodes, so the table is the
|
||||
// authoritative bridge. On miss we fall back to the keysym path.
|
||||
func (m *MacInputInjector) InjectKeyScancode(scancode, keysym uint32, down bool) {
|
||||
wakeDisplay()
|
||||
src := ensureEventSource()
|
||||
if src == 0 {
|
||||
return
|
||||
}
|
||||
vk, ok := qemuToMacVK[scancode]
|
||||
if !ok {
|
||||
// Fall back to the keysym path so unmapped keys still work.
|
||||
m.InjectKey(keysym, down)
|
||||
return
|
||||
}
|
||||
m.postMacKey(src, vk, down)
|
||||
}
|
||||
|
||||
// postMacKey emits a single key down/up event via Core Graphics. For
|
||||
// keycodes that live in the Fn-shifted region of an Apple keyboard we
|
||||
// also emit explicit flagsChanged events around the keypress: posting
|
||||
// the Fn flag on the key event alone leaves macOS's modifier state
|
||||
// machine without a matching transition, which manifests as "Fn stays
|
||||
// active" for the next key (e.g. the next letter activates a menu
|
||||
// accelerator).
|
||||
func (m *MacInputInjector) postMacKey(src uintptr, keycode uint16, down bool) {
|
||||
fnShifted := isFnShiftedKeycode(keycode)
|
||||
if fnShifted && down {
|
||||
postFnFlagsChanged(src, true)
|
||||
}
|
||||
event := cgEventCreateKeyboardEvent(src, keycode, down)
|
||||
if event == 0 {
|
||||
if fnShifted && !down {
|
||||
postFnFlagsChanged(src, false)
|
||||
}
|
||||
return
|
||||
}
|
||||
if fnShifted && cgEventSetFlags != nil {
|
||||
cgEventSetFlags(event, kCGEventFlagMaskSecondaryFn)
|
||||
}
|
||||
cgEventPost(kCGHIDEventTap, event)
|
||||
cfRelease(event)
|
||||
if fnShifted && !down {
|
||||
postFnFlagsChanged(src, false)
|
||||
}
|
||||
}
|
||||
|
||||
// postFnFlagsChanged emits a synthetic Fn modifier transition so the
|
||||
// system updates its global modifier state to match the key events we
|
||||
// post for the navigation cluster. Without this, posting a Fn-flagged
|
||||
// key event leaves macOS thinking Fn is still held after the key is
|
||||
// released.
|
||||
func postFnFlagsChanged(src uintptr, fnOn bool) {
|
||||
if cgEventCreateForInput == nil || cgEventSetType == nil || cgEventSetFlags == nil {
|
||||
return
|
||||
}
|
||||
event := cgEventCreateForInput(src)
|
||||
if event == 0 {
|
||||
return
|
||||
}
|
||||
cgEventSetType(event, kCGEventFlagsChanged)
|
||||
var flags uint64
|
||||
if fnOn {
|
||||
flags = kCGEventFlagMaskSecondaryFn
|
||||
}
|
||||
cgEventSetFlags(event, flags)
|
||||
cgEventPost(kCGHIDEventTap, event)
|
||||
cfRelease(event)
|
||||
}
|
||||
|
||||
// fnShiftedKeycodes are the Apple navigation/edit keys that hardware produces
|
||||
// with the Fn modifier held.
|
||||
var fnShiftedKeycodes = map[uint16]struct{}{
|
||||
0x72: {}, // Help / Insert
|
||||
0x73: {}, // Home
|
||||
0x74: {}, // PageUp
|
||||
0x75: {}, // ForwardDelete
|
||||
0x77: {}, // End
|
||||
0x79: {}, // PageDown
|
||||
0x7B: {}, // Left
|
||||
0x7C: {}, // Right
|
||||
0x7D: {}, // Down
|
||||
0x7E: {}, // Up
|
||||
}
|
||||
|
||||
// isFnShiftedKeycode reports whether keycode is one of the Apple
|
||||
// navigation/edit keys that hardware produces with the Fn modifier held.
|
||||
func isFnShiftedKeycode(keycode uint16) bool {
|
||||
_, ok := fnShiftedKeycodes[keycode]
|
||||
return ok
|
||||
}
|
||||
|
||||
// InjectPointer simulates mouse movement and button events.
|
||||
func (m *MacInputInjector) InjectPointer(buttonMask uint16, px, py, serverW, serverH int) {
|
||||
wakeDisplay()
|
||||
if serverW == 0 || serverH == 0 {
|
||||
return
|
||||
}
|
||||
src := ensureEventSource()
|
||||
if src == 0 {
|
||||
return
|
||||
}
|
||||
x, y := scalePxToLogical(px, py, serverW, serverH)
|
||||
m.dispatchPointer(src, buttonMask, x, y)
|
||||
m.lastButtons = buttonMask
|
||||
}
|
||||
|
||||
// scalePxToLogical converts framebuffer coordinates (physical pixels) into
|
||||
// the logical points CGEventCreateMouseEvent expects. Falls back to a 1:1
|
||||
// mapping if the display API is unavailable.
|
||||
func scalePxToLogical(px, py, serverW, serverH int) (float64, float64) {
|
||||
x, y := float64(px), float64(py)
|
||||
if cgDisplayPixelsWide == nil || cgMainDisplayID == nil {
|
||||
return x, y
|
||||
}
|
||||
displayID := cgMainDisplayID()
|
||||
logicalW := int(cgDisplayPixelsWide(displayID))
|
||||
logicalH := int(cgDisplayPixelsHigh(displayID))
|
||||
if logicalW <= 0 || logicalH <= 0 {
|
||||
return x, y
|
||||
}
|
||||
return float64(px) * float64(logicalW) / float64(serverW),
|
||||
float64(py) * float64(logicalH) / float64(serverH)
|
||||
}
|
||||
|
||||
func (m *MacInputInjector) dispatchPointer(src uintptr, buttonMask uint16, x, y float64) {
|
||||
leftDown := buttonMask&0x01 != 0
|
||||
rightDown := buttonMask&0x04 != 0
|
||||
middleDown := buttonMask&0x02 != 0
|
||||
m.postMoveOrDrag(src, leftDown, rightDown, x, y)
|
||||
m.postButtonTransitions(src, buttonMask, x, y)
|
||||
m.postScrollWheel(src, buttonMask)
|
||||
_ = middleDown
|
||||
}
|
||||
|
||||
func (m *MacInputInjector) postMoveOrDrag(src uintptr, leftDown, rightDown bool, x, y float64) {
|
||||
switch {
|
||||
case leftDown:
|
||||
m.postMouse(src, kCGEventLeftMouseDragged, x, y, kCGMouseButtonLeft)
|
||||
case rightDown:
|
||||
m.postMouse(src, kCGEventRightMouseDragged, x, y, kCGMouseButtonRight)
|
||||
default:
|
||||
m.postMouse(src, kCGEventMouseMoved, x, y, kCGMouseButtonLeft)
|
||||
}
|
||||
}
|
||||
|
||||
// postButtonTransitions emits the up/down events for each button whose
|
||||
// state changed against m.lastButtons, computing the click count so
|
||||
// macOS recognises double / triple clicks.
|
||||
func (m *MacInputInjector) postButtonTransitions(src uintptr, buttonMask uint16, x, y float64) {
|
||||
emit := func(curBit, prevBit uint16, down, up int32, button int32, idx int) {
|
||||
cur := buttonMask&curBit != 0
|
||||
prev := m.lastButtons&prevBit != 0
|
||||
if cur && !prev {
|
||||
now := time.Now()
|
||||
if !m.clickAt[idx].IsZero() && now.Sub(m.clickAt[idx]) <= doubleClickWindow {
|
||||
m.clickCount[idx]++
|
||||
} else {
|
||||
m.clickCount[idx] = 1
|
||||
}
|
||||
m.clickAt[idx] = now
|
||||
m.postMouseClick(src, down, x, y, button, m.clickCount[idx])
|
||||
} else if !cur && prev {
|
||||
count := m.clickCount[idx]
|
||||
if count == 0 {
|
||||
count = 1
|
||||
}
|
||||
m.postMouseClick(src, up, x, y, button, count)
|
||||
}
|
||||
}
|
||||
emit(0x01, 0x01, kCGEventLeftMouseDown, kCGEventLeftMouseUp, kCGMouseButtonLeft, 0)
|
||||
emit(0x04, 0x04, kCGEventRightMouseDown, kCGEventRightMouseUp, kCGMouseButtonRight, 1)
|
||||
emit(0x02, 0x02, kCGEventOtherMouseDown, kCGEventOtherMouseUp, kCGMouseButtonCenter, 2)
|
||||
// CG mouse-button numbers 3 (back) and 4 (forward) are emitted as
|
||||
// "other" events; macOS apps that swallow Browser nav (Finder, web
|
||||
// views) react to these directly.
|
||||
emit(1<<7, 1<<7, kCGEventOtherMouseDown, kCGEventOtherMouseUp, 3, 3)
|
||||
emit(1<<8, 1<<8, kCGEventOtherMouseDown, kCGEventOtherMouseUp, 4, 4)
|
||||
}
|
||||
|
||||
func (m *MacInputInjector) postScrollWheel(src uintptr, buttonMask uint16) {
|
||||
if buttonMask&0x08 != 0 {
|
||||
m.postScroll(src, scrollPixelsPerWheelTick)
|
||||
}
|
||||
if buttonMask&0x10 != 0 {
|
||||
m.postScroll(src, -scrollPixelsPerWheelTick)
|
||||
}
|
||||
}
|
||||
|
||||
// scrollPixelsPerWheelTick is the pixel delta we post for one VNC wheel
|
||||
// button event. Browser-based RFB clients typically emit one press+release
|
||||
// per ~10 px of host wheel/trackpad motion, so a real gesture arrives as
|
||||
// many small events; ~20 px per event keeps the resulting macOS scroll
|
||||
// fluid without overshooting on a single notch.
|
||||
const scrollPixelsPerWheelTick int32 = 22
|
||||
|
||||
func (m *MacInputInjector) postMouse(src uintptr, eventType int32, x, y float64, button int32) {
|
||||
if cgEventCreateMouseEvent == nil {
|
||||
return
|
||||
}
|
||||
event := cgEventCreateMouseEvent(src, eventType, x, y, button)
|
||||
if event == 0 {
|
||||
return
|
||||
}
|
||||
cgEventPost(kCGHIDEventTap, event)
|
||||
cfRelease(event)
|
||||
}
|
||||
|
||||
// postMouseClick stamps the click count on the event before posting it.
|
||||
// Without this stamp macOS treats every press as a fresh single click.
|
||||
func (m *MacInputInjector) postMouseClick(src uintptr, eventType int32, x, y float64, button int32, clickCount int64) {
|
||||
if cgEventCreateMouseEvent == nil {
|
||||
return
|
||||
}
|
||||
event := cgEventCreateMouseEvent(src, eventType, x, y, button)
|
||||
if event == 0 {
|
||||
return
|
||||
}
|
||||
if cgEventSetIntegerValueField != nil && clickCount > 1 {
|
||||
cgEventSetIntegerValueField(event, kCGMouseEventClickState, clickCount)
|
||||
}
|
||||
cgEventPost(kCGHIDEventTap, event)
|
||||
cfRelease(event)
|
||||
}
|
||||
|
||||
func (m *MacInputInjector) postScroll(src uintptr, deltaY int32) {
|
||||
if cgEventCreateScrollWheelEventAddr == 0 {
|
||||
return
|
||||
}
|
||||
// CGEventCreateScrollWheelEvent(source, units, wheelCount, wheel1delta).
|
||||
// Pixel units (0) feel smoother given the small per-event deltas typical
|
||||
// of RFB wheel events than line units (1) where each event jumps a
|
||||
// whole line. Variadic C function, pass via SyscallN.
|
||||
r1, _, _ := purego.SyscallN(cgEventCreateScrollWheelEventAddr,
|
||||
src, 0, 1, uintptr(uint32(deltaY)))
|
||||
if r1 == 0 {
|
||||
return
|
||||
}
|
||||
cgEventPost(kCGHIDEventTap, r1)
|
||||
cfRelease(r1)
|
||||
}
|
||||
|
||||
// SetClipboard sets the macOS clipboard using pbcopy.
|
||||
func (m *MacInputInjector) SetClipboard(text string) {
|
||||
if m.pbcopyPath == "" {
|
||||
return
|
||||
}
|
||||
cmd := exec.Command(m.pbcopyPath)
|
||||
cmd.Stdin = strings.NewReader(text)
|
||||
if err := cmd.Run(); err != nil {
|
||||
log.Tracef("set clipboard via pbcopy: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TypeText synthesizes the given text as keystrokes via Core Graphics.
|
||||
// Lets a client push host clipboard content to the focused remote app
|
||||
// even when the app doesn't honor pbpaste-style clipboard sync (e.g.
|
||||
// login screens, locked-down apps). ASCII printable runes only; others
|
||||
// are skipped.
|
||||
func (m *MacInputInjector) TypeText(text string) {
|
||||
wakeDisplay()
|
||||
src := ensureEventSource()
|
||||
if src == 0 {
|
||||
return
|
||||
}
|
||||
const maxChars = 4096
|
||||
count := 0
|
||||
for _, r := range text {
|
||||
if count >= maxChars {
|
||||
break
|
||||
}
|
||||
count++
|
||||
typeRune(src, r)
|
||||
}
|
||||
}
|
||||
|
||||
// typeRune emits the press/release events for a single ASCII rune, framing
|
||||
// the keystroke with Shift-down/up when required by the keysym.
|
||||
func typeRune(src uintptr, r rune) {
|
||||
const shiftKey = uint16(0x38) // kVK_Shift
|
||||
keysym, shift, ok := keysymForASCIIRune(r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
keycode := keysymToMacKeycode(keysym)
|
||||
if keycode == 0xFFFF {
|
||||
return
|
||||
}
|
||||
if shift {
|
||||
postKey(src, shiftKey, true)
|
||||
}
|
||||
postKey(src, keycode, true)
|
||||
postKey(src, keycode, false)
|
||||
if shift {
|
||||
postKey(src, shiftKey, false)
|
||||
}
|
||||
}
|
||||
|
||||
func postKey(src uintptr, keycode uint16, down bool) {
|
||||
e := cgEventCreateKeyboardEvent(src, keycode, down)
|
||||
if e == 0 {
|
||||
return
|
||||
}
|
||||
cgEventPost(kCGHIDEventTap, e)
|
||||
cfRelease(e)
|
||||
}
|
||||
|
||||
// GetClipboard reads the macOS clipboard using pbpaste.
|
||||
func (m *MacInputInjector) GetClipboard() string {
|
||||
if m.pbpastePath == "" {
|
||||
return ""
|
||||
}
|
||||
out, err := exec.Command(m.pbpastePath).Output()
|
||||
if err != nil {
|
||||
// pbpaste exits 1 when the pasteboard has no string flavour.
|
||||
return ""
|
||||
}
|
||||
return string(out)
|
||||
}
|
||||
|
||||
// Close releases the idle-sleep assertion held for the injector's lifetime.
|
||||
func (m *MacInputInjector) Close() {
|
||||
releasePreventIdleSleep()
|
||||
}
|
||||
|
||||
func keysymToMacKeycode(keysym uint32) uint16 {
|
||||
if keysym >= 0x61 && keysym <= 0x7a {
|
||||
return asciiToMacKey[keysym-0x61]
|
||||
}
|
||||
if keysym >= 0x41 && keysym <= 0x5a {
|
||||
return asciiToMacKey[keysym-0x41]
|
||||
}
|
||||
if keysym >= 0x30 && keysym <= 0x39 {
|
||||
return digitToMacKey[keysym-0x30]
|
||||
}
|
||||
if code, ok := specialKeyMap[keysym]; ok {
|
||||
return code
|
||||
}
|
||||
return 0xFFFF
|
||||
}
|
||||
|
||||
var asciiToMacKey = [26]uint16{
|
||||
0x00, 0x0B, 0x08, 0x02, 0x0E, 0x03, 0x05, 0x04,
|
||||
0x22, 0x26, 0x28, 0x25, 0x2E, 0x2D, 0x1F, 0x23,
|
||||
0x0C, 0x0F, 0x01, 0x11, 0x20, 0x09, 0x0D, 0x07,
|
||||
0x10, 0x06,
|
||||
}
|
||||
|
||||
var digitToMacKey = [10]uint16{
|
||||
0x1D, 0x12, 0x13, 0x14, 0x15, 0x17, 0x16, 0x1A, 0x1C, 0x19,
|
||||
}
|
||||
|
||||
var specialKeyMap = map[uint32]uint16{
|
||||
// Whitespace and editing
|
||||
0x0020: 0x31, // space
|
||||
0xff08: 0x33, // BackSpace
|
||||
0xff09: 0x30, // Tab
|
||||
0xff0d: 0x24, // Return
|
||||
0xff1b: 0x35, // Escape
|
||||
0xffff: 0x75, // Delete (forward)
|
||||
|
||||
// Navigation
|
||||
0xff50: 0x73, // Home
|
||||
0xff51: 0x7B, // Left
|
||||
0xff52: 0x7E, // Up
|
||||
0xff53: 0x7C, // Right
|
||||
0xff54: 0x7D, // Down
|
||||
0xff55: 0x74, // Page_Up
|
||||
0xff56: 0x79, // Page_Down
|
||||
0xff57: 0x77, // End
|
||||
0xff63: 0x72, // Insert (Help on Mac)
|
||||
|
||||
// Modifiers
|
||||
0xffe1: 0x38, // Shift_L
|
||||
0xffe2: 0x3C, // Shift_R
|
||||
0xffe3: 0x3B, // Control_L
|
||||
0xffe4: 0x3E, // Control_R
|
||||
0xffe5: 0x39, // Caps_Lock
|
||||
0xffe9: 0x3A, // Alt_L (Option)
|
||||
0xffea: 0x3D, // Alt_R (Option)
|
||||
0xffe7: 0x37, // Meta_L (Command)
|
||||
0xffe8: 0x36, // Meta_R (Command)
|
||||
0xffeb: 0x37, // Super_L (Command)
|
||||
0xffec: 0x36, // Super_R (Command)
|
||||
|
||||
// Mode_switch / ISO_Level3_Shift (for macOS Option remap on layouts)
|
||||
0xff7e: 0x3A, // Mode_switch -> Option
|
||||
0xfe03: 0x3D, // ISO_Level3_Shift -> Right Option
|
||||
|
||||
// Function keys
|
||||
0xffbe: 0x7A, // F1
|
||||
0xffbf: 0x78, // F2
|
||||
0xffc0: 0x63, // F3
|
||||
0xffc1: 0x76, // F4
|
||||
0xffc2: 0x60, // F5
|
||||
0xffc3: 0x61, // F6
|
||||
0xffc4: 0x62, // F7
|
||||
0xffc5: 0x64, // F8
|
||||
0xffc6: 0x65, // F9
|
||||
0xffc7: 0x6D, // F10
|
||||
0xffc8: 0x67, // F11
|
||||
0xffc9: 0x6F, // F12
|
||||
0xffca: 0x69, // F13
|
||||
0xffcb: 0x6B, // F14
|
||||
0xffcc: 0x71, // F15
|
||||
0xffcd: 0x6A, // F16
|
||||
0xffce: 0x40, // F17
|
||||
0xffcf: 0x4F, // F18
|
||||
0xffd0: 0x50, // F19
|
||||
0xffd1: 0x5A, // F20
|
||||
|
||||
// Punctuation (US keyboard layout, keysym = ASCII code)
|
||||
0x002d: 0x1B, // minus -
|
||||
0x003d: 0x18, // equal =
|
||||
0x005b: 0x21, // bracketleft [
|
||||
0x005d: 0x1E, // bracketright ]
|
||||
0x005c: 0x2A, // backslash
|
||||
0x003b: 0x29, // semicolon ;
|
||||
0x0027: 0x27, // apostrophe '
|
||||
0x0060: 0x32, // grave `
|
||||
0x002c: 0x2B, // comma ,
|
||||
0x002e: 0x2F, // period .
|
||||
0x002f: 0x2C, // slash /
|
||||
|
||||
// Shifted punctuation (clients sometimes send these as separate keysyms)
|
||||
0x005f: 0x1B, // underscore _ (shift+minus)
|
||||
0x002b: 0x18, // plus + (shift+equal)
|
||||
0x007b: 0x21, // braceleft { (shift+[)
|
||||
0x007d: 0x1E, // braceright } (shift+])
|
||||
0x007c: 0x2A, // bar | (shift+\)
|
||||
0x003a: 0x29, // colon : (shift+;)
|
||||
0x0022: 0x27, // quotedbl " (shift+')
|
||||
0x007e: 0x32, // tilde ~ (shift+`)
|
||||
0x003c: 0x2B, // less < (shift+,)
|
||||
0x003e: 0x2F, // greater > (shift+.)
|
||||
0x003f: 0x2C, // question ? (shift+/)
|
||||
0x0021: 0x12, // exclam ! (shift+1)
|
||||
0x0040: 0x13, // at @ (shift+2)
|
||||
0x0023: 0x14, // numbersign # (shift+3)
|
||||
0x0024: 0x15, // dollar $ (shift+4)
|
||||
0x0025: 0x17, // percent % (shift+5)
|
||||
0x005e: 0x16, // asciicircum ^ (shift+6)
|
||||
0x0026: 0x1A, // ampersand & (shift+7)
|
||||
0x002a: 0x1C, // asterisk * (shift+8)
|
||||
0x0028: 0x19, // parenleft ( (shift+9)
|
||||
0x0029: 0x1D, // parenright ) (shift+0)
|
||||
|
||||
// Numpad
|
||||
0xffb0: 0x52, // KP_0
|
||||
0xffb1: 0x53, // KP_1
|
||||
0xffb2: 0x54, // KP_2
|
||||
0xffb3: 0x55, // KP_3
|
||||
0xffb4: 0x56, // KP_4
|
||||
0xffb5: 0x57, // KP_5
|
||||
0xffb6: 0x58, // KP_6
|
||||
0xffb7: 0x59, // KP_7
|
||||
0xffb8: 0x5B, // KP_8
|
||||
0xffb9: 0x5C, // KP_9
|
||||
0xffae: 0x41, // KP_Decimal
|
||||
0xffaa: 0x43, // KP_Multiply
|
||||
0xffab: 0x45, // KP_Add
|
||||
0xffad: 0x4E, // KP_Subtract
|
||||
0xffaf: 0x4B, // KP_Divide
|
||||
0xff8d: 0x4C, // KP_Enter
|
||||
0xffbd: 0x51, // KP_Equal
|
||||
}
|
||||
|
||||
var _ InputInjector = (*MacInputInjector)(nil)
|
||||
17
client/vnc/server/input_uinput_freebsd.go
Normal file
17
client/vnc/server/input_uinput_freebsd.go
Normal file
@@ -0,0 +1,17 @@
|
||||
//go:build freebsd
|
||||
|
||||
package server
|
||||
|
||||
import "fmt"
|
||||
|
||||
// UInputInjector is a freebsd placeholder; the linux uinput implementation
|
||||
// uses Linux-only ioctls (UI_DEV_CREATE etc.) and is not portable.
|
||||
type UInputInjector struct {
|
||||
StubInputInjector
|
||||
}
|
||||
|
||||
// NewUInputInjector always returns an error on freebsd so callers fall back
|
||||
// to a stub or platform-appropriate injector.
|
||||
func NewUInputInjector(_, _ int) (*UInputInjector, error) {
|
||||
return nil, fmt.Errorf("uinput not implemented on freebsd")
|
||||
}
|
||||
488
client/vnc/server/input_uinput_linux.go
Normal file
488
client/vnc/server/input_uinput_linux.go
Normal file
@@ -0,0 +1,488 @@
|
||||
//go:build linux
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
"unicode"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// /dev/uinput ioctl numbers. Computed from the kernel _IO/_IOW macros so
|
||||
// we don't depend on cgo. UINPUT_IOCTL_BASE = 'U' = 0x55.
|
||||
const (
|
||||
uiDevCreate = 0x5501
|
||||
uiDevDestroy = 0x5502
|
||||
// _IOW('U', 3, struct uinput_setup); uinput_setup is 92 bytes on amd64.
|
||||
uiDevSetup = (1 << 30) | (92 << 16) | (0x55 << 8) | 3
|
||||
uiSetEvBit = (1 << 30) | (4 << 16) | (0x55 << 8) | 100
|
||||
uiSetKeyBit = (1 << 30) | (4 << 16) | (0x55 << 8) | 101
|
||||
uiSetAbsBit = (1 << 30) | (4 << 16) | (0x55 << 8) | 103
|
||||
uinputAbsSize = 64 // legacy struct uses absmin/absmax/absfuzz/absflat[64].
|
||||
)
|
||||
|
||||
// Linux input event types and key codes (linux/input-event-codes.h).
|
||||
const (
|
||||
evSyn = 0x00
|
||||
evKey = 0x01
|
||||
evAbs = 0x03
|
||||
evRep = 0x14
|
||||
|
||||
synReport = 0
|
||||
|
||||
absX = 0x00
|
||||
absY = 0x01
|
||||
|
||||
btnLeft = 0x110
|
||||
btnRight = 0x111
|
||||
btnMiddle = 0x112
|
||||
btnSide = 0x113 // mouse-back (X1)
|
||||
btnExtra = 0x114 // mouse-forward (X2)
|
||||
)
|
||||
|
||||
// inputEvent matches struct input_event for x86_64 (timeval is 16 bytes).
|
||||
// Total size 24 bytes; Go's natural alignment matches the kernel layout.
|
||||
type inputEvent struct {
|
||||
TvSec int64
|
||||
TvUsec int64
|
||||
Type uint16
|
||||
Code uint16
|
||||
Value int32
|
||||
}
|
||||
|
||||
// UInputInjector synthesizes keyboard and mouse events via /dev/uinput.
|
||||
// Used as a fallback when X11 isn't running, e.g. at the kernel console
|
||||
// or pre-login screen on a server without X. Requires root or
|
||||
// CAP_SYS_ADMIN, which the netbird service has.
|
||||
type UInputInjector struct {
|
||||
mu sync.Mutex
|
||||
fd int
|
||||
closeOnce sync.Once
|
||||
keysymToKey map[uint32]uint16
|
||||
prevButtons uint16
|
||||
screenW int
|
||||
screenH int
|
||||
}
|
||||
|
||||
// NewUInputInjector opens /dev/uinput and registers a virtual keyboard +
|
||||
// absolute pointer device sized to (w, h). The dimensions are needed
|
||||
// because uinput's ABS axes don't autoscale; we always send absolute
|
||||
// coordinates and let the kernel route them to the right monitor.
|
||||
func NewUInputInjector(w, h int) (*UInputInjector, error) {
|
||||
if w <= 0 || h <= 0 {
|
||||
return nil, fmt.Errorf("invalid screen size: %dx%d", w, h)
|
||||
}
|
||||
fd, err := unix.Open("/dev/uinput", unix.O_WRONLY|unix.O_NONBLOCK, 0)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open /dev/uinput: %w", err)
|
||||
}
|
||||
|
||||
if err := setBit(fd, uiSetEvBit, evKey); err != nil {
|
||||
unix.Close(fd)
|
||||
return nil, err
|
||||
}
|
||||
if err := setBit(fd, uiSetEvBit, evAbs); err != nil {
|
||||
unix.Close(fd)
|
||||
return nil, err
|
||||
}
|
||||
if err := setBit(fd, uiSetEvBit, evSyn); err != nil {
|
||||
unix.Close(fd)
|
||||
return nil, err
|
||||
}
|
||||
// Advertise key auto-repeat so the kernel input core repeats held
|
||||
// keys at the configured rate (default ~250 ms delay, ~33 ms period).
|
||||
// Without this, holding Backspace etc. only deletes one character.
|
||||
if err := setBit(fd, uiSetEvBit, evRep); err != nil {
|
||||
unix.Close(fd)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keymap := buildUInputKeymap()
|
||||
for _, key := range keymap {
|
||||
if err := setBit(fd, uiSetKeyBit, uint32(key)); err != nil {
|
||||
unix.Close(fd)
|
||||
return nil, fmt.Errorf("UI_SET_KEYBIT %d: %w", key, err)
|
||||
}
|
||||
}
|
||||
for _, btn := range []uint16{btnLeft, btnRight, btnMiddle, btnSide, btnExtra} {
|
||||
if err := setBit(fd, uiSetKeyBit, uint32(btn)); err != nil {
|
||||
unix.Close(fd)
|
||||
return nil, fmt.Errorf("UI_SET_KEYBIT btn %d: %w", btn, err)
|
||||
}
|
||||
}
|
||||
if err := setBit(fd, uiSetAbsBit, absX); err != nil {
|
||||
unix.Close(fd)
|
||||
return nil, err
|
||||
}
|
||||
if err := setBit(fd, uiSetAbsBit, absY); err != nil {
|
||||
unix.Close(fd)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := writeUInputUserDev(fd, w, h); err != nil {
|
||||
unix.Close(fd)
|
||||
return nil, err
|
||||
}
|
||||
if _, _, e := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uiDevCreate, 0); e != 0 {
|
||||
unix.Close(fd)
|
||||
return nil, fmt.Errorf("UI_DEV_CREATE: %v", e)
|
||||
}
|
||||
// Give udev a moment to settle before sending events.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
inj := &UInputInjector{
|
||||
fd: fd,
|
||||
keysymToKey: keymapByKeysym(keymap),
|
||||
screenW: w,
|
||||
screenH: h,
|
||||
}
|
||||
log.Infof("uinput injector ready: %dx%d, %d keys", w, h, len(inj.keysymToKey))
|
||||
return inj, nil
|
||||
}
|
||||
|
||||
func setBit(fd int, op uintptr, code uint32) error {
|
||||
if _, _, e := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), op, uintptr(code)); e != 0 {
|
||||
return fmt.Errorf("ioctl 0x%x %d: %v", op, code, e)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeUInputUserDev uses the legacy uinput_user_dev path (write the
|
||||
// whole struct then UI_DEV_CREATE) which is universally supported on
|
||||
// older and current kernels alike. uinput_user_dev is name(80) + id(8) +
|
||||
// ff_effects_max(4) + absmax/absmin/absfuzz/absflat[64] = 92 + 4*64*4 =
|
||||
// 1116 bytes total.
|
||||
func writeUInputUserDev(fd, w, h int) error {
|
||||
const sz = 80 + 8 + 4 + uinputAbsSize*4*4
|
||||
buf := make([]byte, sz)
|
||||
copy(buf[0:80], []byte("netbird-vnc-uinput"))
|
||||
// id: BUS_VIRTUAL=0x06, vendor=0x0001, product=0x0001, version=1.
|
||||
binary.LittleEndian.PutUint16(buf[80:82], 0x06)
|
||||
binary.LittleEndian.PutUint16(buf[82:84], 0x0001)
|
||||
binary.LittleEndian.PutUint16(buf[84:86], 0x0001)
|
||||
binary.LittleEndian.PutUint16(buf[86:88], 0x0001)
|
||||
// ff_effects_max(4) at 88..92 stays zero.
|
||||
// absmax[64] at 92..348: set absX/absY.
|
||||
absmaxOff := 80 + 8 + 4
|
||||
absminOff := absmaxOff + uinputAbsSize*4
|
||||
binary.LittleEndian.PutUint32(buf[absmaxOff+absX*4:], uint32(w-1))
|
||||
binary.LittleEndian.PutUint32(buf[absmaxOff+absY*4:], uint32(h-1))
|
||||
binary.LittleEndian.PutUint32(buf[absminOff+absX*4:], 0)
|
||||
binary.LittleEndian.PutUint32(buf[absminOff+absY*4:], 0)
|
||||
if _, err := unix.Write(fd, buf); err != nil {
|
||||
return fmt.Errorf("write uinput_user_dev: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// emit writes a single input_event to the device. Caller-locked.
|
||||
func (u *UInputInjector) emit(typ, code uint16, value int32) error {
|
||||
ev := inputEvent{Type: typ, Code: code, Value: value}
|
||||
buf := (*[unsafe.Sizeof(inputEvent{})]byte)(unsafe.Pointer(&ev))[:]
|
||||
_, err := unix.Write(u.fd, buf)
|
||||
return err
|
||||
}
|
||||
|
||||
func (u *UInputInjector) sync() {
|
||||
_ = u.emit(evSyn, synReport, 0)
|
||||
}
|
||||
|
||||
// InjectKey synthesizes a press or release for the given X11 keysym.
|
||||
func (u *UInputInjector) InjectKey(keysym uint32, down bool) {
|
||||
u.mu.Lock()
|
||||
defer u.mu.Unlock()
|
||||
code, ok := u.keysymToKey[keysym]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
u.emitKeyCode(code, down)
|
||||
}
|
||||
|
||||
// InjectKeyScancode injects a press or release using the QEMU scancode.
|
||||
// uinput speaks Linux KEY_* codes natively, so we map QEMU scancode →
|
||||
// KEY_* via qemuToLinuxKey. On miss (scancode we don't have a mapping
|
||||
// for) we fall back to the keysym path, which is exactly the legacy
|
||||
// behaviour.
|
||||
func (u *UInputInjector) InjectKeyScancode(scancode, keysym uint32, down bool) {
|
||||
code := qemuScancodeToLinuxKey(scancode)
|
||||
if code == 0 {
|
||||
u.InjectKey(keysym, down)
|
||||
return
|
||||
}
|
||||
u.mu.Lock()
|
||||
defer u.mu.Unlock()
|
||||
u.emitKeyCode(uint16(code), down)
|
||||
}
|
||||
|
||||
// emitKeyCode emits one key down/up event plus a sync. Caller holds u.mu.
|
||||
func (u *UInputInjector) emitKeyCode(code uint16, down bool) {
|
||||
value := int32(0)
|
||||
if down {
|
||||
value = 1
|
||||
}
|
||||
if err := u.emit(evKey, code, value); err != nil {
|
||||
log.Tracef("uinput emit key: %v", err)
|
||||
return
|
||||
}
|
||||
u.sync()
|
||||
}
|
||||
|
||||
// InjectPointer moves the absolute pointer and presses/releases buttons
|
||||
// based on the RFB button mask delta against the previous mask.
|
||||
func (u *UInputInjector) InjectPointer(buttonMask uint16, x, y, serverW, serverH int) {
|
||||
u.mu.Lock()
|
||||
defer u.mu.Unlock()
|
||||
if serverW <= 1 || serverH <= 1 {
|
||||
return
|
||||
}
|
||||
absXVal := int32(x * (u.screenW - 1) / (serverW - 1))
|
||||
absYVal := int32(y * (u.screenH - 1) / (serverH - 1))
|
||||
_ = u.emit(evAbs, absX, absXVal)
|
||||
_ = u.emit(evAbs, absY, absYVal)
|
||||
|
||||
type btnMap struct {
|
||||
bit uint16
|
||||
key uint16
|
||||
}
|
||||
for _, b := range []btnMap{
|
||||
{0x01, btnLeft},
|
||||
{0x02, btnMiddle},
|
||||
{0x04, btnRight},
|
||||
{1 << 7, btnSide},
|
||||
{1 << 8, btnExtra},
|
||||
} {
|
||||
pressed := buttonMask&b.bit != 0
|
||||
was := u.prevButtons&b.bit != 0
|
||||
if pressed && !was {
|
||||
_ = u.emit(evKey, b.key, 1)
|
||||
} else if !pressed && was {
|
||||
_ = u.emit(evKey, b.key, 0)
|
||||
}
|
||||
}
|
||||
u.prevButtons = buttonMask
|
||||
u.sync()
|
||||
}
|
||||
|
||||
// SetClipboard is a no-op on the framebuffer console: there is no system
|
||||
// clipboard daemon. Use TypeText (Paste button) to deliver host text.
|
||||
func (u *UInputInjector) SetClipboard(_ string) {
|
||||
// no system clipboard daemon on framebuffer console
|
||||
}
|
||||
|
||||
// GetClipboard returns empty: no clipboard outside X11/Wayland.
|
||||
func (u *UInputInjector) GetClipboard() string { return "" }
|
||||
|
||||
// TypeText synthesizes the given UTF-8 text as keystrokes. Only ASCII
|
||||
// printable characters and newline are typed; other runes are skipped.
|
||||
// This drives the "paste" button: with no console clipboard available,
|
||||
// keystroke-by-keystroke entry is the only way to deliver a password to
|
||||
// a TTY login prompt.
|
||||
func (u *UInputInjector) TypeText(text string) {
|
||||
u.mu.Lock()
|
||||
defer u.mu.Unlock()
|
||||
const maxChars = 4096
|
||||
count := 0
|
||||
for _, r := range text {
|
||||
if count >= maxChars {
|
||||
break
|
||||
}
|
||||
count++
|
||||
code, shift, ok := keyForRune(r)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if shift {
|
||||
_ = u.emit(evKey, keyLeftShift, 1)
|
||||
}
|
||||
_ = u.emit(evKey, code, 1)
|
||||
_ = u.emit(evKey, code, 0)
|
||||
if shift {
|
||||
_ = u.emit(evKey, keyLeftShift, 0)
|
||||
}
|
||||
u.sync()
|
||||
}
|
||||
}
|
||||
|
||||
// Close destroys the virtual uinput device and closes the file descriptor.
|
||||
func (u *UInputInjector) Close() {
|
||||
u.closeOnce.Do(func() {
|
||||
u.mu.Lock()
|
||||
defer u.mu.Unlock()
|
||||
if u.fd >= 0 {
|
||||
_, _, _ = unix.Syscall(unix.SYS_IOCTL, uintptr(u.fd), uiDevDestroy, 0)
|
||||
_ = unix.Close(u.fd)
|
||||
u.fd = -1
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Linux KEY_* codes live in scancodes.go (shared with the QEMU scancode
|
||||
// path). Don't duplicate them here.
|
||||
|
||||
// buildUInputKeymap returns every linux KEY_ code we want the virtual
|
||||
// device to advertise during UI_SET_KEYBIT. Order doesn't matter.
|
||||
func buildUInputKeymap() []uint16 {
|
||||
out := make([]uint16, 0, 128)
|
||||
// Letters: KEY_A=30, KEY_B=48, etc; not a clean range. The kernel's
|
||||
// row-by-row layout is qwertyuiop / asdfghjkl / zxcvbnm.
|
||||
letters := []uint16{
|
||||
30, 48, 46, 32, 18, 33, 34, 35, 23, 36, 37, 38, 50, // a..m
|
||||
49, 24, 25, 16, 19, 31, 20, 22, 47, 17, 45, 21, 44, // n..z
|
||||
}
|
||||
out = append(out, letters...)
|
||||
// Top-row digits: KEY_1..KEY_0 = 2..11.
|
||||
for i := uint16(2); i <= 11; i++ {
|
||||
out = append(out, i)
|
||||
}
|
||||
// Function keys F1..F12 = 59..68 + 87, 88. We only register F1..F12
|
||||
// which the kernel header enumerates as a contiguous block.
|
||||
for i := uint16(59); i <= 68; i++ {
|
||||
out = append(out, i)
|
||||
}
|
||||
out = append(out, 87, 88)
|
||||
out = append(out, []uint16{
|
||||
keyEsc, keyMinus, keyEqual, keyBackspace, keyTab, keyEnter,
|
||||
keyLeftCtrl, keyRightCtrl, keyLeftShift, keyRightShift,
|
||||
keyLeftAlt, keyRightAlt, keyLeftMeta, keyRightMeta,
|
||||
keySpace, keyCapsLock,
|
||||
keyLeftBracket, keyRightBracket, keyBackslash,
|
||||
keySemicolon, keyApostrophe, keyGrave,
|
||||
keyComma, keyDot, keySlash,
|
||||
keyHome, keyEnd, keyPageUp, keyPageDown,
|
||||
keyUp, keyDown, keyLeft, keyRight,
|
||||
keyInsert, keyDelete,
|
||||
}...)
|
||||
return out
|
||||
}
|
||||
|
||||
// keymapByKeysym maps X11 keysyms (the values our session receives over
|
||||
// RFB) onto Linux KEY_ codes. Shifted ASCII keysyms (uppercase letters,
|
||||
// "!@#..." etc.) map to the same scan code as their unshifted twin: the
|
||||
// client also sends a separate Shift keysym (0xffe1), so the kernel
|
||||
// composes the final character from the held modifier + scan code.
|
||||
func keymapByKeysym(_ []uint16) map[uint32]uint16 {
|
||||
letters := map[rune]uint16{
|
||||
'a': 30, 'b': 48, 'c': 46, 'd': 32, 'e': 18, 'f': 33, 'g': 34,
|
||||
'h': 35, 'i': 23, 'j': 36, 'k': 37, 'l': 38, 'm': 50,
|
||||
'n': 49, 'o': 24, 'p': 25, 'q': 16, 'r': 19, 's': 31, 't': 20,
|
||||
'u': 22, 'v': 47, 'w': 17, 'x': 45, 'y': 21, 'z': 44,
|
||||
}
|
||||
m := map[uint32]uint16{
|
||||
// Digits.
|
||||
'0': 11, '1': 2, '2': 3, '3': 4, '4': 5, '5': 6, '6': 7,
|
||||
'7': 8, '8': 9, '9': 10,
|
||||
// Shifted digits (US layout).
|
||||
')': 11, '!': 2, '@': 3, '#': 4, '$': 5, '%': 6, '^': 7,
|
||||
'&': 8, '*': 9, '(': 10,
|
||||
// Punctuation (US layout) and shifted twins.
|
||||
' ': keySpace,
|
||||
'-': keyMinus, '_': keyMinus,
|
||||
'=': keyEqual, '+': keyEqual,
|
||||
'[': keyLeftBracket, '{': keyLeftBracket,
|
||||
']': keyRightBracket, '}': keyRightBracket,
|
||||
'\\': keyBackslash, '|': keyBackslash,
|
||||
';': keySemicolon, ':': keySemicolon,
|
||||
'\'': keyApostrophe, '"': keyApostrophe,
|
||||
'`': keyGrave, '~': keyGrave,
|
||||
',': keyComma, '<': keyComma,
|
||||
'.': keyDot, '>': keyDot,
|
||||
'/': keySlash, '?': keySlash,
|
||||
// Special keys (X11 keysyms).
|
||||
0xff08: keyBackspace, 0xff09: keyTab, 0xff0d: keyEnter,
|
||||
0xff1b: keyEsc, 0xffff: keyDelete,
|
||||
0xff50: keyHome, 0xff57: keyEnd,
|
||||
0xff51: keyLeft, 0xff52: keyUp, 0xff53: keyRight, 0xff54: keyDown,
|
||||
0xff55: keyPageUp, 0xff56: keyPageDown, 0xff63: keyInsert,
|
||||
0xffe1: keyLeftShift, 0xffe2: keyRightShift,
|
||||
0xffe3: keyLeftCtrl, 0xffe4: keyRightCtrl,
|
||||
0xffe9: keyLeftAlt, 0xffea: keyRightAlt,
|
||||
0xffeb: keyLeftMeta, 0xffec: keyRightMeta,
|
||||
}
|
||||
// Letters: register both lowercase and uppercase keysyms onto the same
|
||||
// KEY_ code. The client sends Shift separately for uppercase.
|
||||
for r, code := range letters {
|
||||
m[uint32(r)] = code
|
||||
m[uint32(r-'a'+'A')] = code
|
||||
}
|
||||
// Function keys F1..F12 (X11 keysyms 0xffbe..0xffc9 → KEY_F1..KEY_F12).
|
||||
xF := uint32(0xffbe)
|
||||
codes := []uint16{59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 87, 88}
|
||||
for i, c := range codes {
|
||||
m[xF+uint32(i)] = c
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// keyForRune maps a printable rune to (keycode, needsShift). Used by
|
||||
// TypeText to synthesize keystrokes for a paste payload.
|
||||
func keyForRune(r rune) (uint16, bool, bool) {
|
||||
if r >= 'a' && r <= 'z' {
|
||||
m := map[rune]uint16{
|
||||
'a': 30, 'b': 48, 'c': 46, 'd': 32, 'e': 18, 'f': 33, 'g': 34,
|
||||
'h': 35, 'i': 23, 'j': 36, 'k': 37, 'l': 38, 'm': 50,
|
||||
'n': 49, 'o': 24, 'p': 25, 'q': 16, 'r': 19, 's': 31, 't': 20,
|
||||
'u': 22, 'v': 47, 'w': 17, 'x': 45, 'y': 21, 'z': 44,
|
||||
}
|
||||
return m[r], false, true
|
||||
}
|
||||
if r >= 'A' && r <= 'Z' {
|
||||
c, _, ok := keyForRune(unicode.ToLower(r))
|
||||
return c, true, ok
|
||||
}
|
||||
if r >= '0' && r <= '9' {
|
||||
nums := []uint16{11, 2, 3, 4, 5, 6, 7, 8, 9, 10}
|
||||
idx := int(r - '0')
|
||||
if idx < 0 || idx >= len(nums) { //nolint:gosec // explicit bound disarms G602
|
||||
return 0, false, false
|
||||
}
|
||||
return nums[idx], false, true
|
||||
}
|
||||
if r == '\n' || r == '\r' {
|
||||
return keyEnter, false, true
|
||||
}
|
||||
if k, ok := punctUnshifted[r]; ok {
|
||||
return k, false, true
|
||||
}
|
||||
if k, ok := punctShifted[r]; ok {
|
||||
return k, true, true
|
||||
}
|
||||
return 0, false, false
|
||||
}
|
||||
|
||||
// punctUnshifted maps ASCII punctuation that needs no Shift to its uinput
|
||||
// KEY_* code. Split out of keyForRune's switch to keep the function's
|
||||
// cognitive complexity below Sonar's threshold.
|
||||
var punctUnshifted = map[rune]uint16{
|
||||
' ': keySpace,
|
||||
'\t': keyTab,
|
||||
'-': keyMinus,
|
||||
'=': keyEqual,
|
||||
'[': keyLeftBracket,
|
||||
']': keyRightBracket,
|
||||
'\\': keyBackslash,
|
||||
';': keySemicolon,
|
||||
'\'': keyApostrophe,
|
||||
'`': keyGrave,
|
||||
',': keyComma,
|
||||
'.': keyDot,
|
||||
'/': keySlash,
|
||||
}
|
||||
|
||||
// punctShifted maps ASCII punctuation that requires Shift to its base KEY_*
|
||||
// code; the caller adds the shift modifier itself.
|
||||
var punctShifted = map[rune]uint16{
|
||||
'!': 2, '@': 3, '#': 4, '$': 5, '%': 6, '^': 7, '&': 8, '*': 9,
|
||||
'(': 10, ')': 11,
|
||||
'_': keyMinus, '+': keyEqual,
|
||||
'{': keyLeftBracket, '}': keyRightBracket, '|': keyBackslash,
|
||||
':': keySemicolon, '"': keyApostrophe, '~': keyGrave,
|
||||
'<': keyComma, '>': keyDot, '?': keySlash,
|
||||
}
|
||||
|
||||
var _ InputInjector = (*UInputInjector)(nil)
|
||||
599
client/vnc/server/input_windows.go
Normal file
599
client/vnc/server/input_windows.go
Normal file
@@ -0,0 +1,599 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var (
|
||||
procOpenEventW = kernel32.NewProc("OpenEventW")
|
||||
procSendInput = user32.NewProc("SendInput")
|
||||
procVkKeyScanA = user32.NewProc("VkKeyScanA")
|
||||
)
|
||||
|
||||
const eventModifyState = 0x0002
|
||||
|
||||
const (
|
||||
inputMouse = 0
|
||||
inputKeyboard = 1
|
||||
|
||||
mouseeventfMove = 0x0001
|
||||
mouseeventfLeftDown = 0x0002
|
||||
mouseeventfLeftUp = 0x0004
|
||||
mouseeventfRightDown = 0x0008
|
||||
mouseeventfRightUp = 0x0010
|
||||
mouseeventfMiddleDown = 0x0020
|
||||
mouseeventfMiddleUp = 0x0040
|
||||
mouseeventfXDown = 0x0080
|
||||
mouseeventfXUp = 0x0100
|
||||
mouseeventfWheel = 0x0800
|
||||
mouseeventfAbsolute = 0x8000
|
||||
|
||||
// X-button identifiers carried in the dwData field of MOUSEEVENTF_X*
|
||||
// events. XBUTTON1 is mouse-back, XBUTTON2 is mouse-forward.
|
||||
xButton1 = 0x0001
|
||||
xButton2 = 0x0002
|
||||
|
||||
wheelDelta = 120
|
||||
|
||||
keyeventfExtendedKey = 0x0001
|
||||
keyeventfKeyUp = 0x0002
|
||||
keyeventfUnicode = 0x0004
|
||||
keyeventfScanCode = 0x0008
|
||||
)
|
||||
|
||||
// maxTypedClipboardChars caps the number of characters we will synthesize as
|
||||
// keystrokes when falling back on the Winlogon desktop. Passwords are short;
|
||||
// a huge clipboard getting typed into the login screen would be surprising.
|
||||
const maxTypedClipboardChars = 4096
|
||||
|
||||
type mouseInput struct {
|
||||
Dx int32
|
||||
Dy int32
|
||||
MouseData uint32
|
||||
DwFlags uint32
|
||||
Time uint32
|
||||
DwExtraInfo uintptr
|
||||
}
|
||||
|
||||
type keybdInput struct {
|
||||
WVk uint16
|
||||
WScan uint16
|
||||
DwFlags uint32
|
||||
Time uint32
|
||||
DwExtraInfo uintptr
|
||||
_ [8]byte
|
||||
}
|
||||
|
||||
type inputUnion [32]byte
|
||||
|
||||
type winInput struct {
|
||||
Type uint32
|
||||
_ [4]byte
|
||||
Data inputUnion
|
||||
}
|
||||
|
||||
func sendMouseInput(flags uint32, dx, dy int32, mouseData uint32) {
|
||||
mi := mouseInput{
|
||||
Dx: dx,
|
||||
Dy: dy,
|
||||
MouseData: mouseData,
|
||||
DwFlags: flags,
|
||||
}
|
||||
inp := winInput{Type: inputMouse}
|
||||
copy(inp.Data[:], (*[unsafe.Sizeof(mi)]byte)(unsafe.Pointer(&mi))[:])
|
||||
r, _, err := procSendInput.Call(1, uintptr(unsafe.Pointer(&inp)), unsafe.Sizeof(inp))
|
||||
if r == 0 {
|
||||
log.Tracef("SendInput(mouse flags=0x%x): %v", flags, err)
|
||||
}
|
||||
}
|
||||
|
||||
func sendKeyInput(vk uint16, scanCode uint16, flags uint32) {
|
||||
ki := keybdInput{
|
||||
WVk: vk,
|
||||
WScan: scanCode,
|
||||
DwFlags: flags,
|
||||
}
|
||||
inp := winInput{Type: inputKeyboard}
|
||||
copy(inp.Data[:], (*[unsafe.Sizeof(ki)]byte)(unsafe.Pointer(&ki))[:])
|
||||
r, _, err := procSendInput.Call(1, uintptr(unsafe.Pointer(&inp)), unsafe.Sizeof(inp))
|
||||
if r == 0 {
|
||||
log.Tracef("SendInput(key vk=0x%x): %v", vk, err)
|
||||
}
|
||||
}
|
||||
|
||||
const sasEventName = `Global\NetBirdVNC_SAS`
|
||||
|
||||
type inputCmd struct {
|
||||
isKey bool
|
||||
isScancode bool
|
||||
isClipboard bool
|
||||
isType bool
|
||||
keysym uint32
|
||||
scancode uint32
|
||||
down bool
|
||||
buttonMask uint16
|
||||
x, y int
|
||||
serverW int
|
||||
serverH int
|
||||
clipText string
|
||||
}
|
||||
|
||||
// WindowsInputInjector delivers input events from a dedicated OS thread that
|
||||
// calls switchToInputDesktop before each injection. SendInput targets the
|
||||
// calling thread's desktop, so the injection thread must be on the same
|
||||
// desktop the user sees.
|
||||
type WindowsInputInjector struct {
|
||||
ch chan inputCmd
|
||||
closed chan struct{}
|
||||
closeOnce sync.Once
|
||||
prevButtonMask uint16
|
||||
// lastQueuedButtonMask is the most recent buttonMask submitted to ch
|
||||
// by InjectPointer. Compared against the incoming sample to decide
|
||||
// whether the new event is move-only (lossy enqueue) or carries a
|
||||
// button/wheel transition (reliable enqueue).
|
||||
lastQueuedButtonMask uint16
|
||||
lastQueuedMaskValid bool
|
||||
queueMu sync.Mutex
|
||||
ctrlDown bool
|
||||
altDown bool
|
||||
}
|
||||
|
||||
// NewWindowsInputInjector creates a desktop-aware input injector.
|
||||
func NewWindowsInputInjector() *WindowsInputInjector {
|
||||
w := &WindowsInputInjector{
|
||||
ch: make(chan inputCmd, 64),
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
go w.loop()
|
||||
return w
|
||||
}
|
||||
|
||||
// Close stops the injector loop. Safe to call multiple times. Subsequent
|
||||
// Inject*/SetClipboard/TypeText calls become no-ops; we use a separate
|
||||
// signal channel rather than closing ch so late senders can't panic.
|
||||
func (w *WindowsInputInjector) Close() {
|
||||
w.closeOnce.Do(func() {
|
||||
close(w.closed)
|
||||
})
|
||||
}
|
||||
|
||||
// tryEnqueue posts a command unless the injector is closed or the channel is
|
||||
// full. Non-blocking so callers (RFB read loop) never stall.
|
||||
func (w *WindowsInputInjector) tryEnqueue(cmd inputCmd) {
|
||||
select {
|
||||
case <-w.closed:
|
||||
case w.ch <- cmd:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// enqueueReliable posts a command and blocks until it's accepted or the
|
||||
// injector closes. Used for edge-triggered events (button/wheel) where a
|
||||
// drop would desynchronize prevButtonMask in dispatch().
|
||||
func (w *WindowsInputInjector) enqueueReliable(cmd inputCmd) {
|
||||
select {
|
||||
case <-w.closed:
|
||||
return
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case w.ch <- cmd:
|
||||
case <-w.closed:
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WindowsInputInjector) loop() {
|
||||
runtime.LockOSThread()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-w.closed:
|
||||
return
|
||||
case cmd := <-w.ch:
|
||||
w.dispatch(cmd)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WindowsInputInjector) dispatch(cmd inputCmd) {
|
||||
// Switch to the current input desktop so SendInput and the clipboard
|
||||
// API target the desktop the user sees. The returned name tells us
|
||||
// whether we are on the secure Winlogon desktop.
|
||||
_, _ = switchToInputDesktop()
|
||||
|
||||
switch {
|
||||
case cmd.isClipboard:
|
||||
w.doSetClipboard(cmd.clipText)
|
||||
case cmd.isType:
|
||||
w.typeUnicodeText(cmd.clipText)
|
||||
case cmd.isScancode:
|
||||
w.doInjectKeyScancode(cmd.scancode, cmd.keysym, cmd.down)
|
||||
case cmd.isKey:
|
||||
w.doInjectKey(cmd.keysym, cmd.down)
|
||||
default:
|
||||
w.doInjectPointer(cmd.buttonMask, cmd.x, cmd.y, cmd.serverW, cmd.serverH)
|
||||
}
|
||||
}
|
||||
|
||||
// InjectKey queues a key event for injection on the input desktop thread.
|
||||
func (w *WindowsInputInjector) InjectKey(keysym uint32, down bool) {
|
||||
w.tryEnqueue(inputCmd{isKey: true, keysym: keysym, down: down})
|
||||
}
|
||||
|
||||
// InjectKeyScancode queues a raw-scancode key event. PC AT Set 1 maps
|
||||
// directly onto what SendInput's KEYEVENTF_SCANCODE flag wants, so the
|
||||
// only translation is splitting the optional 0xE0 prefix off into the
|
||||
// KEYEVENTF_EXTENDEDKEY flag. keysym is the client-provided fallback we
|
||||
// reach for if the scancode is zero.
|
||||
func (w *WindowsInputInjector) InjectKeyScancode(scancode uint32, keysym uint32, down bool) {
|
||||
if scancode == 0 {
|
||||
w.InjectKey(keysym, down)
|
||||
return
|
||||
}
|
||||
w.tryEnqueue(inputCmd{isScancode: true, scancode: scancode, keysym: keysym, down: down})
|
||||
}
|
||||
|
||||
// InjectPointer queues a pointer event for injection on the input desktop
|
||||
// thread. Move-only updates use lossy enqueue (next sample carries fresher
|
||||
// position anyway), but any sample whose buttonMask differs from the last
|
||||
// queued mask is enqueued reliably so wheel ticks and button transitions
|
||||
// can't be dropped under backpressure.
|
||||
func (w *WindowsInputInjector) InjectPointer(buttonMask uint16, x, y, serverW, serverH int) {
|
||||
cmd := inputCmd{buttonMask: buttonMask, x: x, y: y, serverW: serverW, serverH: serverH}
|
||||
w.queueMu.Lock()
|
||||
transition := !w.lastQueuedMaskValid || w.lastQueuedButtonMask != buttonMask
|
||||
w.lastQueuedButtonMask = buttonMask
|
||||
w.lastQueuedMaskValid = true
|
||||
w.queueMu.Unlock()
|
||||
if transition {
|
||||
w.enqueueReliable(cmd)
|
||||
return
|
||||
}
|
||||
w.tryEnqueue(cmd)
|
||||
}
|
||||
|
||||
// doInjectKeyScancode injects a key event using the QEMU scancode directly,
|
||||
// bypassing the keysym→VK lookup. Windows accepts PC AT Set 1 scancodes
|
||||
// natively via KEYEVENTF_SCANCODE, so the only work is splitting the
|
||||
// optional 0xE0 prefix off into the EXTENDEDKEY flag and tracking
|
||||
// modifier state for the SAS Ctrl+Alt+Del shortcut.
|
||||
func (w *WindowsInputInjector) doInjectKeyScancode(scancode, keysym uint32, down bool) {
|
||||
switch keysym {
|
||||
case 0xffe3, 0xffe4:
|
||||
w.ctrlDown = down
|
||||
case 0xffe9, 0xffea:
|
||||
w.altDown = down
|
||||
}
|
||||
if (keysym == 0xff9f || keysym == 0xffff) && w.ctrlDown && w.altDown && down {
|
||||
signalSAS()
|
||||
return
|
||||
}
|
||||
flags := uint32(keyeventfScanCode)
|
||||
if !down {
|
||||
flags |= keyeventfKeyUp
|
||||
}
|
||||
if qemuScancodeIsExtended(scancode) {
|
||||
flags |= keyeventfExtendedKey
|
||||
}
|
||||
sendKeyInput(0, qemuScancodeLowByte(scancode), flags)
|
||||
}
|
||||
|
||||
func (w *WindowsInputInjector) doInjectKey(keysym uint32, down bool) {
|
||||
switch keysym {
|
||||
case 0xffe3, 0xffe4:
|
||||
w.ctrlDown = down
|
||||
case 0xffe9, 0xffea:
|
||||
w.altDown = down
|
||||
}
|
||||
|
||||
if (keysym == 0xff9f || keysym == 0xffff) && w.ctrlDown && w.altDown && down {
|
||||
signalSAS()
|
||||
return
|
||||
}
|
||||
|
||||
vk, _, extended := keysym2VK(keysym)
|
||||
if vk == 0 {
|
||||
return
|
||||
}
|
||||
var flags uint32
|
||||
if !down {
|
||||
flags |= keyeventfKeyUp
|
||||
}
|
||||
if extended {
|
||||
flags |= keyeventfExtendedKey
|
||||
}
|
||||
sendKeyInput(vk, 0, flags)
|
||||
}
|
||||
|
||||
// signalSAS signals the SAS named event. A listener in Session 0
|
||||
// (startSASListener) calls SendSAS to trigger the Secure Attention Sequence.
|
||||
func signalSAS() {
|
||||
namePtr, err := windows.UTF16PtrFromString(sasEventName)
|
||||
if err != nil {
|
||||
log.Warnf("SAS UTF16: %v", err)
|
||||
return
|
||||
}
|
||||
h, _, lerr := procOpenEventW.Call(
|
||||
uintptr(eventModifyState),
|
||||
0,
|
||||
uintptr(unsafe.Pointer(namePtr)),
|
||||
)
|
||||
if h == 0 {
|
||||
log.Warnf("OpenEvent(%s): %v", sasEventName, lerr)
|
||||
return
|
||||
}
|
||||
ev := windows.Handle(h)
|
||||
defer func() { _ = windows.CloseHandle(ev) }()
|
||||
if err := windows.SetEvent(ev); err != nil {
|
||||
log.Warnf("SetEvent SAS: %v", err)
|
||||
} else {
|
||||
log.Info("SAS event signaled")
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WindowsInputInjector) doInjectPointer(buttonMask uint16, x, y, serverW, serverH int) {
|
||||
if serverW == 0 || serverH == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
absX := int32(x * 65535 / serverW)
|
||||
absY := int32(y * 65535 / serverH)
|
||||
|
||||
sendMouseInput(mouseeventfMove|mouseeventfAbsolute, absX, absY, 0)
|
||||
|
||||
changed := buttonMask ^ w.prevButtonMask
|
||||
w.prevButtonMask = buttonMask
|
||||
|
||||
type btnMap struct {
|
||||
bit uint16
|
||||
down uint32
|
||||
up uint32
|
||||
}
|
||||
buttons := [...]btnMap{
|
||||
{0x01, mouseeventfLeftDown, mouseeventfLeftUp},
|
||||
{0x02, mouseeventfMiddleDown, mouseeventfMiddleUp},
|
||||
{0x04, mouseeventfRightDown, mouseeventfRightUp},
|
||||
}
|
||||
for _, b := range buttons {
|
||||
if changed&b.bit == 0 {
|
||||
continue
|
||||
}
|
||||
var flags uint32
|
||||
if buttonMask&b.bit != 0 {
|
||||
flags = b.down
|
||||
} else {
|
||||
flags = b.up
|
||||
}
|
||||
sendMouseInput(flags|mouseeventfAbsolute, absX, absY, 0)
|
||||
}
|
||||
|
||||
negWheelDelta := ^uint32(wheelDelta - 1)
|
||||
if changed&0x08 != 0 && buttonMask&0x08 != 0 {
|
||||
sendMouseInput(mouseeventfWheel|mouseeventfAbsolute, absX, absY, wheelDelta)
|
||||
}
|
||||
if changed&0x10 != 0 && buttonMask&0x10 != 0 {
|
||||
sendMouseInput(mouseeventfWheel|mouseeventfAbsolute, absX, absY, negWheelDelta)
|
||||
}
|
||||
|
||||
// XBUTTON1/back at bit 7, XBUTTON2/forward at bit 8. SendInput
|
||||
// MOUSEEVENTF_X{DOWN,UP} carries the X button number in dwData.
|
||||
xbuttons := [...]struct {
|
||||
bit uint16
|
||||
data uint32
|
||||
}{
|
||||
{1 << 7, xButton1},
|
||||
{1 << 8, xButton2},
|
||||
}
|
||||
for _, b := range xbuttons {
|
||||
if changed&b.bit == 0 {
|
||||
continue
|
||||
}
|
||||
var flags uint32 = mouseeventfXUp
|
||||
if buttonMask&b.bit != 0 {
|
||||
flags = mouseeventfXDown
|
||||
}
|
||||
sendMouseInput(flags|mouseeventfAbsolute, absX, absY, b.data)
|
||||
}
|
||||
}
|
||||
|
||||
// keysym2VK converts an X11 keysym to a Windows virtual key code.
|
||||
func keysym2VK(keysym uint32) (vk uint16, scan uint16, extended bool) {
|
||||
if keysym >= 0x20 && keysym <= 0x7e {
|
||||
r, _, _ := procVkKeyScanA.Call(uintptr(keysym))
|
||||
vk = uint16(r & 0xff)
|
||||
return
|
||||
}
|
||||
|
||||
if keysym >= 0xffbe && keysym <= 0xffc9 {
|
||||
vk = uint16(0x70 + keysym - 0xffbe)
|
||||
return
|
||||
}
|
||||
|
||||
switch keysym {
|
||||
case 0xff08:
|
||||
vk = 0x08 // Backspace
|
||||
case 0xff09:
|
||||
vk = 0x09 // Tab
|
||||
case 0xff0d:
|
||||
vk = 0x0d // Return
|
||||
case 0xff1b:
|
||||
vk = 0x1b // Escape
|
||||
case 0xff63:
|
||||
vk, extended = 0x2d, true // Insert
|
||||
case 0xff9f, 0xffff:
|
||||
vk, extended = 0x2e, true // Delete
|
||||
case 0xff50:
|
||||
vk, extended = 0x24, true // Home
|
||||
case 0xff57:
|
||||
vk, extended = 0x23, true // End
|
||||
case 0xff55:
|
||||
vk, extended = 0x21, true // PageUp
|
||||
case 0xff56:
|
||||
vk, extended = 0x22, true // PageDown
|
||||
case 0xff51:
|
||||
vk, extended = 0x25, true // Left
|
||||
case 0xff52:
|
||||
vk, extended = 0x26, true // Up
|
||||
case 0xff53:
|
||||
vk, extended = 0x27, true // Right
|
||||
case 0xff54:
|
||||
vk, extended = 0x28, true // Down
|
||||
case 0xffe1, 0xffe2:
|
||||
vk = 0x10 // Shift
|
||||
case 0xffe3, 0xffe4:
|
||||
vk = 0x11 // Control
|
||||
case 0xffe9, 0xffea:
|
||||
vk = 0x12 // Alt
|
||||
case 0xffe5:
|
||||
vk = 0x14 // CapsLock
|
||||
case 0xffe7, 0xffeb:
|
||||
vk, extended = 0x5B, true // Meta_L / Super_L -> Left Windows
|
||||
case 0xffe8, 0xffec:
|
||||
vk, extended = 0x5C, true // Meta_R / Super_R -> Right Windows
|
||||
case 0xff61:
|
||||
vk = 0x2c // PrintScreen
|
||||
case 0xff13:
|
||||
vk = 0x13 // Pause
|
||||
case 0xff14:
|
||||
vk = 0x91 // ScrollLock
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
procOpenClipboard = user32.NewProc("OpenClipboard")
|
||||
procCloseClipboard = user32.NewProc("CloseClipboard")
|
||||
procEmptyClipboard = user32.NewProc("EmptyClipboard")
|
||||
procSetClipboardData = user32.NewProc("SetClipboardData")
|
||||
procGetClipboardData = user32.NewProc("GetClipboardData")
|
||||
procIsClipboardFormatAvailable = user32.NewProc("IsClipboardFormatAvailable")
|
||||
|
||||
procGlobalAlloc = kernel32.NewProc("GlobalAlloc")
|
||||
procGlobalLock = kernel32.NewProc("GlobalLock")
|
||||
procGlobalUnlock = kernel32.NewProc("GlobalUnlock")
|
||||
procGlobalFree = kernel32.NewProc("GlobalFree")
|
||||
)
|
||||
|
||||
const (
|
||||
cfUnicodeText = 13
|
||||
gmemMoveable = 0x0002
|
||||
)
|
||||
|
||||
// SetClipboard queues a request to update the Windows clipboard with the
|
||||
// given UTF-8 text. The work runs on the input thread so it follows the
|
||||
// current input desktop. Secure desktops (Winlogon, UAC) have isolated
|
||||
// clipboards we cannot reach, so the call is a no-op there; use TypeText
|
||||
// to enter text into a secure desktop instead.
|
||||
func (w *WindowsInputInjector) SetClipboard(text string) {
|
||||
w.tryEnqueue(inputCmd{isClipboard: true, clipText: text})
|
||||
}
|
||||
|
||||
// TypeText queues a request to synthesize the given text as Unicode
|
||||
// keystrokes on the current input desktop. Targets the secure desktop
|
||||
// when the user is on Winlogon/UAC, where the clipboard is unreachable.
|
||||
func (w *WindowsInputInjector) TypeText(text string) {
|
||||
w.tryEnqueue(inputCmd{isType: true, clipText: text})
|
||||
}
|
||||
|
||||
func (w *WindowsInputInjector) doSetClipboard(text string) {
|
||||
utf16, err := windows.UTF16FromString(text)
|
||||
if err != nil {
|
||||
log.Tracef("clipboard UTF16 encode: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
size := uintptr(len(utf16) * 2)
|
||||
hMem, _, _ := procGlobalAlloc.Call(gmemMoveable, size)
|
||||
if hMem == 0 {
|
||||
log.Tracef("GlobalAlloc for clipboard: allocation returned nil")
|
||||
return
|
||||
}
|
||||
|
||||
ptr, _, _ := procGlobalLock.Call(hMem)
|
||||
if ptr == 0 {
|
||||
log.Tracef("GlobalLock for clipboard: lock returned nil")
|
||||
_, _, _ = procGlobalFree.Call(hMem)
|
||||
return
|
||||
}
|
||||
copy(unsafe.Slice((*uint16)(unsafe.Pointer(ptr)), len(utf16)), utf16)
|
||||
_, _, _ = procGlobalUnlock.Call(hMem)
|
||||
|
||||
r, _, lerr := procOpenClipboard.Call(0)
|
||||
if r == 0 {
|
||||
log.Tracef("OpenClipboard: %v", lerr)
|
||||
_, _, _ = procGlobalFree.Call(hMem)
|
||||
return
|
||||
}
|
||||
defer logCleanupCall("CloseClipboard", procCloseClipboard)
|
||||
|
||||
_, _, _ = procEmptyClipboard.Call()
|
||||
r, _, lerr = procSetClipboardData.Call(cfUnicodeText, hMem)
|
||||
if r == 0 {
|
||||
log.Tracef("SetClipboardData: %v", lerr)
|
||||
// Ownership only transfers to the OS on success; on failure we
|
||||
// still own hMem and must free it.
|
||||
_, _, _ = procGlobalFree.Call(hMem)
|
||||
}
|
||||
}
|
||||
|
||||
// typeUnicodeText synthesizes the given text as Unicode keystrokes via
|
||||
// SendInput+KEYEVENTF_UNICODE. Used on the Winlogon secure desktop where the
|
||||
// clipboard is isolated: this lets a VNC client paste a password into the
|
||||
// login or credential prompt by sending ClientCutText.
|
||||
func (w *WindowsInputInjector) typeUnicodeText(text string) {
|
||||
utf16, err := windows.UTF16FromString(text)
|
||||
if err != nil {
|
||||
log.Tracef("clipboard UTF16 encode: %v", err)
|
||||
return
|
||||
}
|
||||
if len(utf16) > 0 && utf16[len(utf16)-1] == 0 {
|
||||
utf16 = utf16[:len(utf16)-1]
|
||||
}
|
||||
if len(utf16) > maxTypedClipboardChars {
|
||||
log.Warnf("clipboard paste on Winlogon truncated to %d chars", maxTypedClipboardChars)
|
||||
utf16 = utf16[:maxTypedClipboardChars]
|
||||
}
|
||||
for _, c := range utf16 {
|
||||
sendKeyInput(0, c, keyeventfUnicode)
|
||||
sendKeyInput(0, c, keyeventfUnicode|keyeventfKeyUp)
|
||||
}
|
||||
}
|
||||
|
||||
// GetClipboard reads the Windows clipboard as UTF-8 text.
|
||||
func (w *WindowsInputInjector) GetClipboard() string {
|
||||
r, _, _ := procIsClipboardFormatAvailable.Call(cfUnicodeText)
|
||||
if r == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
r, _, lerr := procOpenClipboard.Call(0)
|
||||
if r == 0 {
|
||||
log.Tracef("OpenClipboard for read: %v", lerr)
|
||||
return ""
|
||||
}
|
||||
defer logCleanupCall("CloseClipboard", procCloseClipboard)
|
||||
|
||||
hData, _, _ := procGetClipboardData.Call(cfUnicodeText)
|
||||
if hData == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
ptr, _, _ := procGlobalLock.Call(hData)
|
||||
if ptr == 0 {
|
||||
return ""
|
||||
}
|
||||
defer logCleanupCallArgs("GlobalUnlock", procGlobalUnlock, hData)
|
||||
|
||||
return windows.UTF16PtrToString((*uint16)(unsafe.Pointer(ptr)))
|
||||
}
|
||||
|
||||
var _ InputInjector = (*WindowsInputInjector)(nil)
|
||||
|
||||
var _ ScreenCapturer = (*DesktopCapturer)(nil)
|
||||
312
client/vnc/server/input_x11.go
Normal file
312
client/vnc/server/input_x11.go
Normal file
@@ -0,0 +1,312 @@
|
||||
//go:build unix && !darwin && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/jezek/xgb"
|
||||
"github.com/jezek/xgb/xproto"
|
||||
"github.com/jezek/xgb/xtest"
|
||||
)
|
||||
|
||||
// X11InputInjector injects keyboard and mouse events via the XTest extension.
|
||||
type X11InputInjector struct {
|
||||
conn *xgb.Conn
|
||||
root xproto.Window
|
||||
screen *xproto.ScreenInfo
|
||||
display string
|
||||
keysymMap map[uint32]byte
|
||||
lastButtons uint16
|
||||
clipboardTool string
|
||||
clipboardToolName string
|
||||
}
|
||||
|
||||
// NewX11InputInjector connects to the X11 display and initializes XTest.
|
||||
func NewX11InputInjector(display string) (*X11InputInjector, error) {
|
||||
detectX11Display()
|
||||
|
||||
if display == "" {
|
||||
display = os.Getenv(envDisplay)
|
||||
}
|
||||
if display == "" {
|
||||
return nil, fmt.Errorf("DISPLAY not set and no Xorg process found")
|
||||
}
|
||||
|
||||
conn, err := xgb.NewConnDisplay(display)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connect to X11 display %s: %w", display, err)
|
||||
}
|
||||
|
||||
if err := xtest.Init(conn); err != nil {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("init XTest extension: %w", err)
|
||||
}
|
||||
|
||||
setup := xproto.Setup(conn)
|
||||
if len(setup.Roots) == 0 {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("no X11 screens")
|
||||
}
|
||||
screen := setup.Roots[0]
|
||||
|
||||
inj := &X11InputInjector{
|
||||
conn: conn,
|
||||
root: screen.Root,
|
||||
screen: &screen,
|
||||
display: display,
|
||||
}
|
||||
inj.cacheKeyboardMapping()
|
||||
inj.resolveClipboardTool()
|
||||
|
||||
log.Infof("X11 input injector ready (display=%s)", display)
|
||||
return inj, nil
|
||||
}
|
||||
|
||||
// InjectKey simulates a key press or release. keysym is an X11 KeySym.
|
||||
func (x *X11InputInjector) InjectKey(keysym uint32, down bool) {
|
||||
keycode := x.keysymToKeycode(keysym)
|
||||
if keycode == 0 {
|
||||
return
|
||||
}
|
||||
x.fakeKeyEvent(keycode, down)
|
||||
}
|
||||
|
||||
// InjectKeyScancode injects using the QEMU scancode by translating to a
|
||||
// Linux KEY_ code and then to an X11 keycode (KEY_* + xkbKeycodeOffset).
|
||||
// On a server running a standard XKB keymap this is layout-independent:
|
||||
// the scancode names the physical key, the server's layout determines the
|
||||
// resulting character. Falls back to the keysym path when the scancode
|
||||
// has no Linux mapping.
|
||||
func (x *X11InputInjector) InjectKeyScancode(scancode, keysym uint32, down bool) {
|
||||
linuxKey := qemuScancodeToLinuxKey(scancode)
|
||||
if linuxKey == 0 {
|
||||
x.InjectKey(keysym, down)
|
||||
return
|
||||
}
|
||||
x.fakeKeyEvent(byte(linuxKey+xkbKeycodeOffset), down)
|
||||
}
|
||||
|
||||
// xkbKeycodeOffset is the per-server constant offset between Linux KEY_*
|
||||
// event codes and the X server's keycode space under XKB. The X protocol
|
||||
// reserves keycodes 0..7 for internal use, so any normal XKB keymap
|
||||
// starts at 8 (KEY_ESC=1 → X keycode 9, KEY_A=30 → X keycode 38, etc.).
|
||||
const xkbKeycodeOffset = 8
|
||||
|
||||
// fakeKeyEvent sends an XTest FakeInput for a press or release.
|
||||
func (x *X11InputInjector) fakeKeyEvent(keycode byte, down bool) {
|
||||
var eventType byte
|
||||
if down {
|
||||
eventType = xproto.KeyPress
|
||||
} else {
|
||||
eventType = xproto.KeyRelease
|
||||
}
|
||||
xtest.FakeInput(x.conn, eventType, keycode, 0, x.root, 0, 0, 0)
|
||||
}
|
||||
|
||||
// InjectPointer simulates mouse movement and button events.
|
||||
func (x *X11InputInjector) InjectPointer(buttonMask uint16, px, py, serverW, serverH int) {
|
||||
if serverW == 0 || serverH == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Scale to actual screen coordinates.
|
||||
screenW := int(x.screen.WidthInPixels)
|
||||
screenH := int(x.screen.HeightInPixels)
|
||||
absX := px * screenW / serverW
|
||||
absY := py * screenH / serverH
|
||||
|
||||
// Move pointer.
|
||||
xtest.FakeInput(x.conn, xproto.MotionNotify, 0, 0, x.root, int16(absX), int16(absY), 0)
|
||||
|
||||
// Handle button events. RFB button mask: bit0=left, bit1=middle, bit2=right,
|
||||
// bit3=scrollUp, bit4=scrollDown. X11 buttons: 1=left, 2=middle, 3=right,
|
||||
// 4=scrollUp, 5=scrollDown.
|
||||
type btnMap struct {
|
||||
rfbBit uint16
|
||||
x11Btn byte
|
||||
}
|
||||
// X11 button numbers: 1=left, 2=middle, 3=right, 4/5=scroll up/down,
|
||||
// 6/7=scroll left/right (skipped), 8=back, 9=forward.
|
||||
buttons := [...]btnMap{
|
||||
{0x01, 1},
|
||||
{0x02, 2},
|
||||
{0x04, 3},
|
||||
{0x08, 4},
|
||||
{0x10, 5},
|
||||
{1 << 7, 8},
|
||||
{1 << 8, 9},
|
||||
}
|
||||
|
||||
for _, b := range buttons {
|
||||
pressed := buttonMask&b.rfbBit != 0
|
||||
wasPressed := x.lastButtons&b.rfbBit != 0
|
||||
if b.x11Btn == 4 || b.x11Btn == 5 {
|
||||
// Scroll: send press+release on each scroll event.
|
||||
if pressed {
|
||||
xtest.FakeInput(x.conn, xproto.ButtonPress, b.x11Btn, 0, x.root, 0, 0, 0)
|
||||
xtest.FakeInput(x.conn, xproto.ButtonRelease, b.x11Btn, 0, x.root, 0, 0, 0)
|
||||
}
|
||||
} else {
|
||||
if pressed && !wasPressed {
|
||||
xtest.FakeInput(x.conn, xproto.ButtonPress, b.x11Btn, 0, x.root, 0, 0, 0)
|
||||
} else if !pressed && wasPressed {
|
||||
xtest.FakeInput(x.conn, xproto.ButtonRelease, b.x11Btn, 0, x.root, 0, 0, 0)
|
||||
}
|
||||
}
|
||||
}
|
||||
x.lastButtons = buttonMask
|
||||
}
|
||||
|
||||
// cacheKeyboardMapping fetches the X11 keyboard mapping once and stores it
|
||||
// as a keysym-to-keycode map, avoiding a round-trip per keystroke.
|
||||
func (x *X11InputInjector) cacheKeyboardMapping() {
|
||||
setup := xproto.Setup(x.conn)
|
||||
minKeycode := setup.MinKeycode
|
||||
maxKeycode := setup.MaxKeycode
|
||||
|
||||
reply, err := xproto.GetKeyboardMapping(x.conn, minKeycode,
|
||||
byte(maxKeycode-minKeycode+1)).Reply()
|
||||
if err != nil {
|
||||
log.Debugf("cache keyboard mapping: %v", err)
|
||||
x.keysymMap = make(map[uint32]byte)
|
||||
return
|
||||
}
|
||||
|
||||
m := make(map[uint32]byte, int(maxKeycode-minKeycode+1)*int(reply.KeysymsPerKeycode))
|
||||
keysymsPerKeycode := int(reply.KeysymsPerKeycode)
|
||||
for i := int(minKeycode); i <= int(maxKeycode); i++ {
|
||||
offset := (i - int(minKeycode)) * keysymsPerKeycode
|
||||
for j := 0; j < keysymsPerKeycode; j++ {
|
||||
ks := uint32(reply.Keysyms[offset+j])
|
||||
if ks != 0 {
|
||||
if _, exists := m[ks]; !exists {
|
||||
m[ks] = byte(i)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
x.keysymMap = m
|
||||
}
|
||||
|
||||
// keysymToKeycode looks up a cached keysym-to-keycode mapping.
|
||||
// Returns 0 if the keysym is not mapped.
|
||||
func (x *X11InputInjector) keysymToKeycode(keysym uint32) byte {
|
||||
return x.keysymMap[keysym]
|
||||
}
|
||||
|
||||
// SetClipboard sets the X11 clipboard using xclip or xsel.
|
||||
func (x *X11InputInjector) SetClipboard(text string) {
|
||||
if x.clipboardTool == "" {
|
||||
return
|
||||
}
|
||||
|
||||
var cmd *exec.Cmd
|
||||
if x.clipboardToolName == "xclip" {
|
||||
cmd = exec.Command(x.clipboardTool, "-selection", "clipboard")
|
||||
} else {
|
||||
cmd = exec.Command(x.clipboardTool, "--clipboard", "--input")
|
||||
}
|
||||
cmd.Env = x.clipboardEnv()
|
||||
cmd.Stdin = strings.NewReader(text)
|
||||
if err := cmd.Run(); err != nil {
|
||||
log.Debugf("set clipboard via %s: %v", x.clipboardToolName, err)
|
||||
}
|
||||
}
|
||||
|
||||
// TypeText synthesizes the given text as keystrokes via XTest. Used in
|
||||
// places where the focused application isn't clipboard-aware (e.g. a TTY
|
||||
// login in an X11 session, an SDDM/GDM password field that ignores
|
||||
// XSelection, or a kiosk app), so stuffing the X clipboard and relying on
|
||||
// Ctrl+V would not reach the input.
|
||||
//
|
||||
// Limitation: only ASCII printable characters are typed. Non-ASCII runes
|
||||
// are skipped: a paste workflow for them needs Wayland-aware text input
|
||||
// or layout introspection that this path does not implement.
|
||||
func (x *X11InputInjector) TypeText(text string) {
|
||||
const maxChars = 4096
|
||||
count := 0
|
||||
for _, r := range text {
|
||||
if count >= maxChars {
|
||||
break
|
||||
}
|
||||
count++
|
||||
keysym, shift, ok := keysymForASCIIRune(r)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
keycode := x.keysymToKeycode(keysym)
|
||||
if keycode == 0 {
|
||||
continue
|
||||
}
|
||||
var shiftCode byte
|
||||
if shift {
|
||||
shiftCode = x.keysymToKeycode(0xffe1) // Shift_L
|
||||
if shiftCode != 0 {
|
||||
xtest.FakeInput(x.conn, xproto.KeyPress, shiftCode, 0, x.root, 0, 0, 0)
|
||||
}
|
||||
}
|
||||
xtest.FakeInput(x.conn, xproto.KeyPress, keycode, 0, x.root, 0, 0, 0)
|
||||
xtest.FakeInput(x.conn, xproto.KeyRelease, keycode, 0, x.root, 0, 0, 0)
|
||||
if shift && shiftCode != 0 {
|
||||
xtest.FakeInput(x.conn, xproto.KeyRelease, shiftCode, 0, x.root, 0, 0, 0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (x *X11InputInjector) resolveClipboardTool() {
|
||||
for _, name := range []string{"xclip", "xsel"} {
|
||||
path, err := exec.LookPath(name)
|
||||
if err == nil {
|
||||
x.clipboardTool = path
|
||||
x.clipboardToolName = name
|
||||
log.Debugf("clipboard tool resolved to %s", path)
|
||||
return
|
||||
}
|
||||
}
|
||||
log.Debugf("no clipboard tool (xclip/xsel) found, clipboard sync disabled")
|
||||
}
|
||||
|
||||
// GetClipboard reads the X11 clipboard using xclip or xsel.
|
||||
func (x *X11InputInjector) GetClipboard() string {
|
||||
if x.clipboardTool == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
var cmd *exec.Cmd
|
||||
if x.clipboardToolName == "xclip" {
|
||||
cmd = exec.Command(x.clipboardTool, "-selection", "clipboard", "-o")
|
||||
} else {
|
||||
cmd = exec.Command(x.clipboardTool, "--clipboard", "--output")
|
||||
}
|
||||
cmd.Env = x.clipboardEnv()
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
// Exit status 1 just means there is no STRING selection set yet,
|
||||
// which is the steady state on a fresh Xvfb session, logging it
|
||||
// every clipboard poll (2s) floods the trace stream.
|
||||
return ""
|
||||
}
|
||||
return string(out)
|
||||
}
|
||||
|
||||
func (x *X11InputInjector) clipboardEnv() []string {
|
||||
env := []string{envDisplay + "=" + x.display}
|
||||
if auth := os.Getenv(envXAuthority); auth != "" {
|
||||
env = append(env, envXAuthority+"="+auth)
|
||||
}
|
||||
return env
|
||||
}
|
||||
|
||||
// Close releases X11 resources.
|
||||
func (x *X11InputInjector) Close() {
|
||||
x.conn.Close()
|
||||
}
|
||||
|
||||
var _ InputInjector = (*X11InputInjector)(nil)
|
||||
var _ ScreenCapturer = (*X11Poller)(nil)
|
||||
73
client/vnc/server/keysym_typetext.go
Normal file
73
client/vnc/server/keysym_typetext.go
Normal file
@@ -0,0 +1,73 @@
|
||||
//go:build !windows
|
||||
|
||||
package server
|
||||
|
||||
// keysymForASCIIRune maps an ASCII rune to (X11 keysym for the unshifted
|
||||
// version, needsShift). Used by TypeText implementations on each platform
|
||||
// so the caller can explicitly press Shift instead of relying on the
|
||||
// server-side modifier state. Returns ok=false for runes outside the
|
||||
// supported set; non-ASCII text is dropped by TypeText.
|
||||
func keysymForASCIIRune(r rune) (uint32, bool, bool) {
|
||||
if r >= 'a' && r <= 'z' {
|
||||
return uint32(r), false, true
|
||||
}
|
||||
if r >= 'A' && r <= 'Z' {
|
||||
return uint32(r - 'A' + 'a'), true, true
|
||||
}
|
||||
if r >= '0' && r <= '9' {
|
||||
return uint32(r), false, true
|
||||
}
|
||||
switch r {
|
||||
case ' ':
|
||||
return 0x20, false, true
|
||||
case '\n', '\r':
|
||||
return 0xff0d, false, true // Return
|
||||
case '\t':
|
||||
return 0xff09, false, true // Tab
|
||||
case '-', '=', '[', ']', '\\', ';', '\'', '`', ',', '.', '/':
|
||||
return uint32(r), false, true
|
||||
case '!':
|
||||
return '1', true, true
|
||||
case '@':
|
||||
return '2', true, true
|
||||
case '#':
|
||||
return '3', true, true
|
||||
case '$':
|
||||
return '4', true, true
|
||||
case '%':
|
||||
return '5', true, true
|
||||
case '^':
|
||||
return '6', true, true
|
||||
case '&':
|
||||
return '7', true, true
|
||||
case '*':
|
||||
return '8', true, true
|
||||
case '(':
|
||||
return '9', true, true
|
||||
case ')':
|
||||
return '0', true, true
|
||||
case '_':
|
||||
return '-', true, true
|
||||
case '+':
|
||||
return '=', true, true
|
||||
case '{':
|
||||
return '[', true, true
|
||||
case '}':
|
||||
return ']', true, true
|
||||
case '|':
|
||||
return '\\', true, true
|
||||
case ':':
|
||||
return ';', true, true
|
||||
case '"':
|
||||
return '\'', true, true
|
||||
case '~':
|
||||
return '`', true, true
|
||||
case '<':
|
||||
return ',', true, true
|
||||
case '>':
|
||||
return '.', true, true
|
||||
case '?':
|
||||
return '/', true, true
|
||||
}
|
||||
return 0, false, false
|
||||
}
|
||||
225
client/vnc/server/metrics_conn.go
Normal file
225
client/vnc/server/metrics_conn.go
Normal file
@@ -0,0 +1,225 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SessionTick is one sampling slice of a VNC session's wire activity.
|
||||
// BytesOut / Writes / FBUs are deltas observed during this tick;
|
||||
// Max* fields are the high-water marks observed during this tick (reset
|
||||
// at the start of the next). Period is the wall-clock duration covered
|
||||
// (typically sessionTickInterval, shorter for the final flush).
|
||||
type SessionTick struct {
|
||||
Period time.Duration
|
||||
BytesOut uint64
|
||||
Writes uint64
|
||||
FBUs uint64
|
||||
MaxFBUBytes uint64
|
||||
MaxFBURects uint64
|
||||
MaxWriteBytes uint64
|
||||
WriteNanos uint64
|
||||
}
|
||||
|
||||
// sessionTickInterval is how often metricsConn emits a SessionTick. One
|
||||
// second covers roughly one FBU round-trip at typical client request
|
||||
// cadences during steady-state activity.
|
||||
const sessionTickInterval = time.Second
|
||||
|
||||
// metricsConn wraps a net.Conn and tracks per-session byte / write / FBU
|
||||
// counters. Updates are atomic so the cost is a few atomic ops per Write
|
||||
// (well under 100 ns), negligible against the syscall itself, so the wrap
|
||||
// is always installed. A goroutine emits a SessionTick to the recorder
|
||||
// every sessionTickInterval (only when the tick has activity to report);
|
||||
// a final partial-tick flush runs on Close.
|
||||
type metricsConn struct {
|
||||
net.Conn
|
||||
|
||||
recorder func(SessionTick)
|
||||
|
||||
bytesOut atomic.Uint64
|
||||
writes atomic.Uint64
|
||||
writeNanos atomic.Uint64
|
||||
largestPkt atomic.Uint64
|
||||
fbus atomic.Uint64
|
||||
fbuBytes atomic.Uint64
|
||||
fbuRects atomic.Uint64
|
||||
maxFBUBytes atomic.Uint64
|
||||
maxFBURects atomic.Uint64
|
||||
|
||||
tickMu sync.Mutex
|
||||
tickStart time.Time
|
||||
tickPrevB uint64
|
||||
tickPrevW uint64
|
||||
tickPrevF uint64
|
||||
tickPrevNS uint64
|
||||
|
||||
// busyMu guards the sliding window used by BusyFraction.
|
||||
busyMu sync.Mutex
|
||||
busyLastTime time.Time
|
||||
busyLastNanos uint64
|
||||
busyFraction float64
|
||||
|
||||
closeOnce sync.Once
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
func newMetricsConn(c net.Conn, recorder func(SessionTick)) net.Conn {
|
||||
m := &metricsConn{
|
||||
Conn: c,
|
||||
recorder: recorder,
|
||||
tickStart: time.Now(),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
if recorder != nil {
|
||||
go m.tickLoop()
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// tickLoop emits a SessionTick every sessionTickInterval until done.
|
||||
// Empty ticks (no writes since the last tick) are skipped.
|
||||
func (m *metricsConn) tickLoop() {
|
||||
t := time.NewTicker(sessionTickInterval)
|
||||
defer t.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-m.done:
|
||||
return
|
||||
case <-t.C:
|
||||
m.flushTick(false)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// flushTick computes deltas since the last tick, resets the per-tick max
|
||||
// trackers, and emits a SessionTick to the recorder. final=true forces
|
||||
// emission even if no writes happened (used at session close to record
|
||||
// the trailing partial period).
|
||||
func (m *metricsConn) flushTick(final bool) {
|
||||
m.tickMu.Lock()
|
||||
defer m.tickMu.Unlock()
|
||||
|
||||
b := m.bytesOut.Load()
|
||||
w := m.writes.Load()
|
||||
f := m.fbus.Load()
|
||||
ns := m.writeNanos.Load()
|
||||
|
||||
db := b - m.tickPrevB
|
||||
dw := w - m.tickPrevW
|
||||
df := f - m.tickPrevF
|
||||
dns := ns - m.tickPrevNS
|
||||
m.tickPrevB, m.tickPrevW, m.tickPrevF, m.tickPrevNS = b, w, f, ns
|
||||
|
||||
maxFBU := m.maxFBUBytes.Swap(0)
|
||||
maxRects := m.maxFBURects.Swap(0)
|
||||
maxPkt := m.largestPkt.Swap(0)
|
||||
|
||||
period := time.Since(m.tickStart)
|
||||
m.tickStart = time.Now()
|
||||
|
||||
if dw == 0 && !final {
|
||||
return
|
||||
}
|
||||
m.recorder(SessionTick{
|
||||
Period: period,
|
||||
BytesOut: db,
|
||||
Writes: dw,
|
||||
FBUs: df,
|
||||
MaxFBUBytes: maxFBU,
|
||||
MaxFBURects: maxRects,
|
||||
MaxWriteBytes: maxPkt,
|
||||
WriteNanos: dns,
|
||||
})
|
||||
}
|
||||
|
||||
// BusyFraction reports the fraction of recent wall time that Write spent
|
||||
// blocked in the underlying socket, as an exponentially smoothed value in
|
||||
// [0, 1]. Approximates downstream backpressure: persistent values near 1
|
||||
// mean the socket cannot keep up with the encoder's output. Callers can
|
||||
// throttle JPEG quality or skip frames in response.
|
||||
func (m *metricsConn) BusyFraction() float64 {
|
||||
now := time.Now()
|
||||
ns := m.writeNanos.Load()
|
||||
|
||||
m.busyMu.Lock()
|
||||
defer m.busyMu.Unlock()
|
||||
if m.busyLastTime.IsZero() {
|
||||
m.busyLastTime = now
|
||||
m.busyLastNanos = ns
|
||||
return 0
|
||||
}
|
||||
period := now.Sub(m.busyLastTime)
|
||||
if period < 50*time.Millisecond {
|
||||
return m.busyFraction
|
||||
}
|
||||
delta := ns - m.busyLastNanos
|
||||
sample := float64(delta) / float64(period.Nanoseconds())
|
||||
if sample > 1 {
|
||||
sample = 1
|
||||
}
|
||||
const alpha = 0.4
|
||||
m.busyFraction = alpha*sample + (1-alpha)*m.busyFraction
|
||||
m.busyLastTime = now
|
||||
m.busyLastNanos = ns
|
||||
return m.busyFraction
|
||||
}
|
||||
|
||||
// isFBUHeader reports whether the given Write payload is the 4-byte
|
||||
// FramebufferUpdate header (message type 0, padding 0, rect-count high
|
||||
// byte). Rect bodies are written separately by sendDirtyAndMoves, so the
|
||||
// FBU/rect boundary lines up with Write boundaries.
|
||||
func isFBUHeader(p []byte) bool {
|
||||
return len(p) == 4 && p[0] == serverFramebufferUpdate
|
||||
}
|
||||
|
||||
func (m *metricsConn) Write(p []byte) (int, error) {
|
||||
if isFBUHeader(p) {
|
||||
if b := m.fbuBytes.Swap(0); b > 0 {
|
||||
if b > m.maxFBUBytes.Load() {
|
||||
m.maxFBUBytes.Store(b)
|
||||
}
|
||||
}
|
||||
if r := m.fbuRects.Swap(0); r > 0 {
|
||||
if r > m.maxFBURects.Load() {
|
||||
m.maxFBURects.Store(r)
|
||||
}
|
||||
}
|
||||
m.fbus.Add(1)
|
||||
}
|
||||
|
||||
t0 := time.Now()
|
||||
n, err := m.Conn.Write(p)
|
||||
m.writeNanos.Add(uint64(time.Since(t0).Nanoseconds()))
|
||||
m.bytesOut.Add(uint64(n))
|
||||
m.writes.Add(1)
|
||||
if !isFBUHeader(p) {
|
||||
m.fbuBytes.Add(uint64(n))
|
||||
m.fbuRects.Add(1)
|
||||
}
|
||||
if uint64(n) > m.largestPkt.Load() {
|
||||
m.largestPkt.Store(uint64(n))
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (m *metricsConn) Close() error {
|
||||
m.closeOnce.Do(func() {
|
||||
close(m.done)
|
||||
if m.recorder == nil {
|
||||
return
|
||||
}
|
||||
if b := m.fbuBytes.Swap(0); b > m.maxFBUBytes.Load() {
|
||||
m.maxFBUBytes.Store(b)
|
||||
}
|
||||
if r := m.fbuRects.Swap(0); r > m.maxFBURects.Load() {
|
||||
m.maxFBURects.Store(r)
|
||||
}
|
||||
m.flushTick(true)
|
||||
})
|
||||
return m.Conn.Close()
|
||||
}
|
||||
432
client/vnc/server/noise_auth_test.go
Normal file
432
client/vnc/server/noise_auth_test.go
Normal file
@@ -0,0 +1,432 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/curve25519"
|
||||
|
||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
|
||||
)
|
||||
|
||||
// noiseTestServer starts a VNC server with a freshly generated identity
|
||||
// key and returns the listener address, the server, and the server's
|
||||
// static public key for client-side handshake setup.
|
||||
func noiseTestServer(t *testing.T) (net.Addr, *Server, []byte) {
|
||||
t.Helper()
|
||||
|
||||
kp, err := noise.DH25519.GenerateKeypair(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
srv := New(&testCapturer{}, &StubInputInjector{}, kp.Private)
|
||||
srv.SetDisableAuth(false)
|
||||
|
||||
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
||||
network := netip.MustParsePrefix("127.0.0.0/8")
|
||||
require.NoError(t, srv.Start(t.Context(), addr, network))
|
||||
srv.localAddr = netip.MustParseAddr("10.99.99.1")
|
||||
t.Cleanup(func() { _ = srv.Stop() })
|
||||
|
||||
return srv.listener.Addr(), srv, kp.Public
|
||||
}
|
||||
|
||||
// registerSessionKey enrolls a fresh X25519 keypair under the given user
|
||||
// ID into the server's authorizer with the requested OS-user wildcard
|
||||
// mapping. Returns the keypair so the test can drive the handshake.
|
||||
func registerSessionKey(t *testing.T, srv *Server, userID string) noise.DHKey {
|
||||
t.Helper()
|
||||
|
||||
kp, err := noise.DH25519.GenerateKeypair(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
userHash, err := sshuserhash.HashUserID(userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
srv.UpdateVNCAuth(&sshauth.Config{
|
||||
AuthorizedUsers: []sshuserhash.UserIDHash{userHash},
|
||||
MachineUsers: map[string][]uint32{sshauth.Wildcard: {0}},
|
||||
SessionPubKeys: []sshauth.SessionPubKey{
|
||||
{PubKey: kp.Public, UserIDHash: userHash},
|
||||
},
|
||||
})
|
||||
return kp
|
||||
}
|
||||
|
||||
// writeHeaderPrefix writes the mode + zero-length-username prefix that
|
||||
// precedes the optional Noise handshake in the NetBird VNC header.
|
||||
func writeHeaderPrefix(t *testing.T, conn net.Conn, mode byte) {
|
||||
t.Helper()
|
||||
prefix := []byte{mode, 0, 0}
|
||||
_, err := conn.Write(prefix)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// writeHeaderTail writes the sessionID/width/height fields that follow
|
||||
// either the Noise msg2 (auth path) or the prefix alone (no-auth path).
|
||||
func writeHeaderTail(t *testing.T, conn net.Conn) {
|
||||
t.Helper()
|
||||
tail := make([]byte, 8)
|
||||
_, err := conn.Write(tail)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// performInitiator drives the initiator side of Noise_IK against the
|
||||
// server's identity public key, returns the resulting state. The Noise
|
||||
// msg2 produced by the server is read and consumed.
|
||||
func performInitiator(t *testing.T, conn net.Conn, clientKey noise.DHKey, serverPub []byte) {
|
||||
t.Helper()
|
||||
|
||||
state, err := noise.NewHandshakeState(noise.Config{
|
||||
CipherSuite: vncNoiseSuite,
|
||||
Pattern: noise.HandshakeIK,
|
||||
Initiator: true,
|
||||
StaticKeypair: clientKey,
|
||||
PeerStatic: serverPub,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
msg1, _, _, err := state.WriteMessage(nil, nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, noiseInitiatorMsgLen, len(msg1))
|
||||
|
||||
_, err = conn.Write(append([]byte("NBV3"), msg1...))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, conn.SetReadDeadline(time.Now().Add(5*time.Second)))
|
||||
msg2 := make([]byte, noiseResponderMsgLen)
|
||||
_, err = io.ReadFull(conn, msg2)
|
||||
require.NoError(t, err)
|
||||
_, _, _, err = state.ReadMessage(nil, msg2)
|
||||
require.NoError(t, err, "server responder message must decrypt with the correct peer static")
|
||||
}
|
||||
|
||||
// readRFBFailure consumes the RFB version exchange and returns the
|
||||
// security-failure reason string. Fails the test if the server did not
|
||||
// send a failure (i.e. produced a non-zero security-types list).
|
||||
func readRFBFailure(t *testing.T, conn net.Conn) string {
|
||||
t.Helper()
|
||||
require.NoError(t, conn.SetReadDeadline(time.Now().Add(5*time.Second)))
|
||||
|
||||
var ver [12]byte
|
||||
_, err := io.ReadFull(conn, ver[:])
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "RFB 003.008\n", string(ver[:]))
|
||||
|
||||
_, err = conn.Write(ver[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
var n [1]byte
|
||||
_, err = io.ReadFull(conn, n[:])
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, byte(0), n[0], "expected security-failure (0 types)")
|
||||
|
||||
var rl [4]byte
|
||||
_, err = io.ReadFull(conn, rl[:])
|
||||
require.NoError(t, err)
|
||||
reason := make([]byte, binary.BigEndian.Uint32(rl[:]))
|
||||
_, err = io.ReadFull(conn, reason)
|
||||
require.NoError(t, err)
|
||||
return string(reason)
|
||||
}
|
||||
|
||||
// readRFBGreetingNoFailure asserts the server proceeded past auth: it
|
||||
// must offer at least one security type rather than a 0 failure.
|
||||
func readRFBGreetingNoFailure(t *testing.T, conn net.Conn) {
|
||||
t.Helper()
|
||||
require.NoError(t, conn.SetReadDeadline(time.Now().Add(5*time.Second)))
|
||||
|
||||
var ver [12]byte
|
||||
_, err := io.ReadFull(conn, ver[:])
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "RFB 003.008\n", string(ver[:]))
|
||||
|
||||
_, err = conn.Write(ver[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
var n [1]byte
|
||||
_, err = io.ReadFull(conn, n[:])
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, byte(0), n[0], "server must offer security types after a valid handshake")
|
||||
}
|
||||
|
||||
// TestNoise_RegisteredKey_AccessGranted exercises the happy path: a
|
||||
// session key enrolled in the authorizer completes a Noise_IK handshake
|
||||
// and the server proceeds to the RFB greeting.
|
||||
func TestNoise_RegisteredKey_AccessGranted(t *testing.T) {
|
||||
addr, srv, serverPub := noiseTestServer(t)
|
||||
clientKey := registerSessionKey(t, srv, "alice@example")
|
||||
|
||||
conn, err := net.Dial("tcp", addr.String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
writeHeaderPrefix(t, conn, ModeAttach)
|
||||
performInitiator(t, conn, clientKey, serverPub)
|
||||
writeHeaderTail(t, conn)
|
||||
|
||||
readRFBGreetingNoFailure(t, conn)
|
||||
}
|
||||
|
||||
// TestNoise_UnregisteredClientStatic_Rejected proves the authorizer is
|
||||
// consulted: a syntactically-valid handshake from a key the server has
|
||||
// never been told about must be rejected fail-closed.
|
||||
func TestNoise_UnregisteredClientStatic_Rejected(t *testing.T) {
|
||||
addr, _, serverPub := noiseTestServer(t)
|
||||
// Auth is enabled but the authorizer was not updated, so the lookup
|
||||
// path returns ErrSessionKeyNotKnown.
|
||||
attackerKey, err := noise.DH25519.GenerateKeypair(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
conn, err := net.Dial("tcp", addr.String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
writeHeaderPrefix(t, conn, ModeAttach)
|
||||
performInitiator(t, conn, attackerKey, serverPub)
|
||||
writeHeaderTail(t, conn)
|
||||
|
||||
reason := readRFBFailure(t, conn)
|
||||
assert.Contains(t, reason, RejectCodeAuthForbidden)
|
||||
assert.Contains(t, reason, "session pubkey not registered")
|
||||
}
|
||||
|
||||
// TestNoise_WrongServerStatic_HandshakeFails proves the server's
|
||||
// identity is bound into the handshake: an initiator using the wrong
|
||||
// peer static encrypts msg1 under keys the real server can't derive, so
|
||||
// the server fails the handshake and closes without RFB output.
|
||||
func TestNoise_WrongServerStatic_HandshakeFails(t *testing.T) {
|
||||
addr, srv, _ := noiseTestServer(t)
|
||||
clientKey := registerSessionKey(t, srv, "alice@example")
|
||||
|
||||
bogusServerKey, err := noise.DH25519.GenerateKeypair(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
conn, err := net.Dial("tcp", addr.String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
writeHeaderPrefix(t, conn, ModeAttach)
|
||||
|
||||
state, err := noise.NewHandshakeState(noise.Config{
|
||||
CipherSuite: vncNoiseSuite,
|
||||
Pattern: noise.HandshakeIK,
|
||||
Initiator: true,
|
||||
StaticKeypair: clientKey,
|
||||
PeerStatic: bogusServerKey.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
msg1, _, _, err := state.WriteMessage(nil, nil)
|
||||
require.NoError(t, err)
|
||||
_, err = conn.Write(append([]byte("NBV3"), msg1...))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, conn.SetReadDeadline(time.Now().Add(5*time.Second)))
|
||||
var b [1]byte
|
||||
_, err = io.ReadFull(conn, b[:])
|
||||
require.Error(t, err, "server must close without RFB greeting when msg1 is sealed for a different server identity")
|
||||
}
|
||||
|
||||
// TestNoise_MalformedMsg1_ClosesConnection covers the case where the
|
||||
// magic prefix is correct but the following 96 bytes are random: the
|
||||
// noise library fails ReadMessage and the server closes silently.
|
||||
func TestNoise_MalformedMsg1_ClosesConnection(t *testing.T) {
|
||||
addr, _, _ := noiseTestServer(t)
|
||||
|
||||
conn, err := net.Dial("tcp", addr.String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
writeHeaderPrefix(t, conn, ModeAttach)
|
||||
junk := make([]byte, noiseInitiatorMsgLen)
|
||||
for i := range junk {
|
||||
junk[i] = byte(i)
|
||||
}
|
||||
_, err = conn.Write(append([]byte("NBV3"), junk...))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, conn.SetReadDeadline(time.Now().Add(5*time.Second)))
|
||||
var b [1]byte
|
||||
_, err = io.ReadFull(conn, b[:])
|
||||
require.Error(t, err, "garbage msg1 must terminate the connection before any RFB output")
|
||||
}
|
||||
|
||||
// TestNoise_TruncatedMsg1_ClosesConnection sends fewer than the 96
|
||||
// bytes a Noise_IK msg1 must contain. The server's io.ReadFull short-
|
||||
// reads and closes; no RFB greeting must leak.
|
||||
func TestNoise_TruncatedMsg1_ClosesConnection(t *testing.T) {
|
||||
addr, _, _ := noiseTestServer(t)
|
||||
|
||||
conn, err := net.Dial("tcp", addr.String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
writeHeaderPrefix(t, conn, ModeAttach)
|
||||
_, err = conn.Write([]byte("NBV3"))
|
||||
require.NoError(t, err)
|
||||
_, err = conn.Write(make([]byte, 8))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, conn.(*net.TCPConn).CloseWrite())
|
||||
|
||||
require.NoError(t, conn.SetReadDeadline(time.Now().Add(2*time.Second)))
|
||||
buf := make([]byte, 64)
|
||||
n, err := conn.Read(buf)
|
||||
require.Equal(t, 0, n, "server must not emit RFB bytes after a truncated handshake")
|
||||
require.ErrorIs(t, err, io.EOF, "server must close the connection on truncated msg1")
|
||||
}
|
||||
|
||||
// TestNoise_AuthEnabled_NoHandshake_Rejected proves that with auth on,
|
||||
// a connection that skips the Noise prefix (older client / VNC client)
|
||||
// is rejected with AUTH_FORBIDDEN: identity proof missing.
|
||||
func TestNoise_AuthEnabled_NoHandshake_Rejected(t *testing.T) {
|
||||
addr, _, _ := noiseTestServer(t)
|
||||
|
||||
conn, err := net.Dial("tcp", addr.String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
writeHeaderPrefix(t, conn, ModeAttach)
|
||||
writeHeaderTail(t, conn)
|
||||
|
||||
reason := readRFBFailure(t, conn)
|
||||
assert.Contains(t, reason, RejectCodeAuthForbidden)
|
||||
assert.Contains(t, reason, "identity proof missing")
|
||||
}
|
||||
|
||||
// TestNoise_RevokedKey_RejectedAfterAuthUpdate verifies the authorizer
|
||||
// honors revocations: a key that worked before a UpdateVNCAuth call
|
||||
// must stop working as soon as the new config omits it.
|
||||
func TestNoise_RevokedKey_RejectedAfterAuthUpdate(t *testing.T) {
|
||||
addr, srv, serverPub := noiseTestServer(t)
|
||||
clientKey := registerSessionKey(t, srv, "alice@example")
|
||||
|
||||
// First connection succeeds.
|
||||
conn1, err := net.Dial("tcp", addr.String())
|
||||
require.NoError(t, err)
|
||||
defer conn1.Close()
|
||||
writeHeaderPrefix(t, conn1, ModeAttach)
|
||||
performInitiator(t, conn1, clientKey, serverPub)
|
||||
writeHeaderTail(t, conn1)
|
||||
readRFBGreetingNoFailure(t, conn1)
|
||||
|
||||
// Revoke by pushing a fresh config that drops the pubkey entry.
|
||||
srv.UpdateVNCAuth(&sshauth.Config{})
|
||||
|
||||
// Same client, same Noise key, should now be denied.
|
||||
conn2, err := net.Dial("tcp", addr.String())
|
||||
require.NoError(t, err)
|
||||
defer conn2.Close()
|
||||
writeHeaderPrefix(t, conn2, ModeAttach)
|
||||
performInitiator(t, conn2, clientKey, serverPub)
|
||||
writeHeaderTail(t, conn2)
|
||||
|
||||
reason := readRFBFailure(t, conn2)
|
||||
assert.Contains(t, reason, RejectCodeAuthForbidden)
|
||||
assert.Contains(t, reason, "session pubkey not registered")
|
||||
}
|
||||
|
||||
// TestNoise_NoIdentityKey_FailsClosed ensures a server constructed
|
||||
// without a static private key still rejects authenticated connections
|
||||
// fail-closed; it must not silently accept the client.
|
||||
func TestNoise_NoIdentityKey_FailsClosed(t *testing.T) {
|
||||
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
|
||||
srv.SetDisableAuth(false)
|
||||
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
||||
network := netip.MustParsePrefix("127.0.0.0/8")
|
||||
require.NoError(t, srv.Start(t.Context(), addr, network))
|
||||
srv.localAddr = netip.MustParseAddr("10.99.99.1")
|
||||
t.Cleanup(func() { _ = srv.Stop() })
|
||||
|
||||
clientKey, err := noise.DH25519.GenerateKeypair(nil)
|
||||
require.NoError(t, err)
|
||||
fakeServerKey, err := noise.DH25519.GenerateKeypair(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
conn, err := net.Dial("tcp", srv.listener.Addr().String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
writeHeaderPrefix(t, conn, ModeAttach)
|
||||
|
||||
state, err := noise.NewHandshakeState(noise.Config{
|
||||
CipherSuite: vncNoiseSuite,
|
||||
Pattern: noise.HandshakeIK,
|
||||
Initiator: true,
|
||||
StaticKeypair: clientKey,
|
||||
PeerStatic: fakeServerKey.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
msg1, _, _, err := state.WriteMessage(nil, nil)
|
||||
require.NoError(t, err)
|
||||
_, err = conn.Write(append([]byte("NBV3"), msg1...))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, conn.SetReadDeadline(time.Now().Add(5*time.Second)))
|
||||
var b [1]byte
|
||||
_, err = io.ReadFull(conn, b[:])
|
||||
require.Error(t, err, "server without identity key must not write the RFB greeting")
|
||||
}
|
||||
|
||||
// TestNoise_DerivedIdentityPublicMatchesPrivate sanity-checks the
|
||||
// derivation done in New(): the identityPublic must be Curve25519.
|
||||
// Basepoint multiplied with identityKey.
|
||||
func TestNoise_DerivedIdentityPublicMatchesPrivate(t *testing.T) {
|
||||
priv := make([]byte, 32)
|
||||
for i := range priv {
|
||||
priv[i] = byte(i + 1)
|
||||
}
|
||||
srv := New(&testCapturer{}, &StubInputInjector{}, priv)
|
||||
|
||||
expected, err := curve25519.X25519(priv, curve25519.Basepoint)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expected, srv.identityPublic)
|
||||
}
|
||||
|
||||
// TestNoise_SessionMode_OSUserCheckRunsAfterHandshake verifies that a
|
||||
// successful Noise handshake doesn't bypass OS-user authorization: an
|
||||
// authenticated key whose user index isn't mapped to the requested OS
|
||||
// user must be rejected.
|
||||
func TestNoise_SessionMode_OSUserCheckRunsAfterHandshake(t *testing.T) {
|
||||
addr, srv, serverPub := noiseTestServer(t)
|
||||
|
||||
clientKey, err := noise.DH25519.GenerateKeypair(nil)
|
||||
require.NoError(t, err)
|
||||
userHash, err := sshuserhash.HashUserID("alice@example")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Map Alice only to "alice" OS user, not the wildcard.
|
||||
srv.UpdateVNCAuth(&sshauth.Config{
|
||||
AuthorizedUsers: []sshuserhash.UserIDHash{userHash},
|
||||
MachineUsers: map[string][]uint32{"alice": {0}},
|
||||
SessionPubKeys: []sshauth.SessionPubKey{
|
||||
{PubKey: clientKey.Public, UserIDHash: userHash},
|
||||
},
|
||||
})
|
||||
|
||||
// Request session for "bob" — Noise succeeds, OS-user check denies.
|
||||
conn, err := net.Dial("tcp", addr.String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
bob := []byte("bob")
|
||||
prefix := []byte{ModeSession, 0, byte(len(bob))}
|
||||
prefix = append(prefix, bob...)
|
||||
_, err = conn.Write(prefix)
|
||||
require.NoError(t, err)
|
||||
|
||||
performInitiator(t, conn, clientKey, serverPub)
|
||||
writeHeaderTail(t, conn)
|
||||
|
||||
reason := readRFBFailure(t, conn)
|
||||
assert.Contains(t, reason, RejectCodeAuthForbidden)
|
||||
assert.Contains(t, reason, "authorize OS user")
|
||||
}
|
||||
59
client/vnc/server/pseudo_encodings_test.go
Normal file
59
client/vnc/server/pseudo_encodings_test.go
Normal file
@@ -0,0 +1,59 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestEncodeDesktopSizeBody(t *testing.T) {
|
||||
got := encodeDesktopSizeBody(1920, 1080)
|
||||
if len(got) != 12 {
|
||||
t.Fatalf("DesktopSize body length: want 12, got %d", len(got))
|
||||
}
|
||||
if got[0] != 0 || got[1] != 0 || got[2] != 0 || got[3] != 0 {
|
||||
t.Fatalf("DesktopSize: x and y must be zero; got % x", got[0:4])
|
||||
}
|
||||
if got[4] != 0x07 || got[5] != 0x80 {
|
||||
t.Fatalf("DesktopSize: width should be 1920 (0x0780); got % x", got[4:6])
|
||||
}
|
||||
if got[6] != 0x04 || got[7] != 0x38 {
|
||||
t.Fatalf("DesktopSize: height should be 1080 (0x0438); got % x", got[6:8])
|
||||
}
|
||||
// Encoding = -223 → 0xFFFFFF21 in two's complement big-endian.
|
||||
if got[8] != 0xFF || got[9] != 0xFF || got[10] != 0xFF || got[11] != 0x21 {
|
||||
t.Fatalf("DesktopSize: encoding bytes wrong: % x", got[8:12])
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeDesktopNameBody(t *testing.T) {
|
||||
name := "vma@debian3"
|
||||
got := encodeDesktopNameBody(name)
|
||||
if len(got) != 12+4+len(name) {
|
||||
t.Fatalf("DesktopName body length: want %d, got %d", 12+4+len(name), len(got))
|
||||
}
|
||||
// Encoding = -307 → 0xFFFFFECD.
|
||||
if got[8] != 0xFF || got[9] != 0xFF || got[10] != 0xFE || got[11] != 0xCD {
|
||||
t.Fatalf("DesktopName: encoding bytes wrong: % x", got[8:12])
|
||||
}
|
||||
if got[12] != 0 || got[13] != 0 || got[14] != 0 || got[15] != byte(len(name)) {
|
||||
t.Fatalf("DesktopName: name length prefix wrong: % x", got[12:16])
|
||||
}
|
||||
if string(got[16:]) != name {
|
||||
t.Fatalf("DesktopName: name body wrong: %q", got[16:])
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeLastRectBody(t *testing.T) {
|
||||
got := encodeLastRectBody()
|
||||
if len(got) != 12 {
|
||||
t.Fatalf("LastRect body length: want 12, got %d", len(got))
|
||||
}
|
||||
for i := 0; i < 8; i++ {
|
||||
if got[i] != 0 {
|
||||
t.Fatalf("LastRect: header bytes 0..7 must be zero; got byte %d = 0x%02x", i, got[i])
|
||||
}
|
||||
}
|
||||
// Encoding = -224 → 0xFFFFFF20.
|
||||
if got[8] != 0xFF || got[9] != 0xFF || got[10] != 0xFF || got[11] != 0x20 {
|
||||
t.Fatalf("LastRect: encoding bytes wrong: % x", got[8:12])
|
||||
}
|
||||
}
|
||||
806
client/vnc/server/rfb.go
Normal file
806
client/vnc/server/rfb.go
Normal file
@@ -0,0 +1,806 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/zlib"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"image"
|
||||
"image/jpeg"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// rect describes a rectangle on the framebuffer in pixels.
|
||||
type rect struct {
|
||||
x, y, w, h int
|
||||
}
|
||||
|
||||
const (
|
||||
rfbProtocolVersion = "RFB 003.008\n"
|
||||
|
||||
secNone = 1
|
||||
|
||||
// Client message types.
|
||||
clientSetPixelFormat = 0
|
||||
clientSetEncodings = 2
|
||||
clientFramebufferUpdateRequest = 3
|
||||
clientKeyEvent = 4
|
||||
clientPointerEvent = 5
|
||||
clientCutText = 6
|
||||
// clientQEMUMessage is the QEMU vendor message wrapper. The subtype
|
||||
// byte that follows selects the actual operation; we only handle the
|
||||
// Extended Key Event (subtype 0) which carries a hardware scancode in
|
||||
// addition to the X11 keysym. Layout-independent key entry.
|
||||
clientQEMUMessage = 255
|
||||
|
||||
// QEMU Extended Key Event subtype carried inside clientQEMUMessage.
|
||||
qemuSubtypeExtendedKeyEvent = 0
|
||||
|
||||
// clientNetbirdTypeText is a NetBird-specific message that asks the
|
||||
// server to synthesize the given text as keystrokes regardless of the
|
||||
// active desktop. Lets a client push host clipboard content into a
|
||||
// Windows secure desktop (Winlogon, UAC), where the OS clipboard is
|
||||
// isolated. Format mirrors clientCutText: 1-byte message type + 3-byte
|
||||
// padding + 4-byte length + text bytes. The opcode is in the
|
||||
// vendor-specific range (>=128).
|
||||
clientNetbirdTypeText = 250
|
||||
|
||||
// clientNetbirdShowRemoteCursor toggles "show remote cursor" mode.
|
||||
// When enabled the encoder composites the server cursor sprite into
|
||||
// the captured framebuffer and suppresses the Cursor pseudo-encoding
|
||||
// so the client sees a single pointer at the remote position.
|
||||
// Wire format: 1-byte msgType + 1-byte enable flag + 6 padding bytes
|
||||
// reserved for future arguments (so the message is fixed-size).
|
||||
clientNetbirdShowRemoteCursor = 251
|
||||
|
||||
// Server message types.
|
||||
serverFramebufferUpdate = 0
|
||||
serverCutText = 3
|
||||
|
||||
// Encoding types.
|
||||
encRaw = 0
|
||||
encCopyRect = 1
|
||||
encHextile = 5
|
||||
encZlib = 6
|
||||
encTight = 7
|
||||
|
||||
// Pseudo-encodings carried over wire as rects with a negative
|
||||
// encoding value. The client advertises supported optional protocol
|
||||
// extensions by listing these in SetEncodings.
|
||||
pseudoEncCursor = -239
|
||||
pseudoEncDesktopSize = -223
|
||||
pseudoEncLastRect = -224
|
||||
pseudoEncQEMUExtendedKeyEvent = -258
|
||||
pseudoEncDesktopName = -307
|
||||
pseudoEncExtendedDesktopSize = -308
|
||||
pseudoEncExtendedMouseButtons = -316
|
||||
|
||||
// Quality/Compression level pseudo-encodings. The client picks one
|
||||
// value from each range to tune JPEG quality and zlib effort. 0 is
|
||||
// lowest quality / fastest, 9 is highest quality / best compression.
|
||||
pseudoEncQualityLevelMin = -32
|
||||
pseudoEncQualityLevelMax = -23
|
||||
pseudoEncCompressLevelMin = -256
|
||||
pseudoEncCompressLevelMax = -247
|
||||
|
||||
// Hextile sub-encoding bits used by the SolidFill fast path.
|
||||
hextileBackgroundSpecified = 0x02
|
||||
hextileSubSize = 16
|
||||
|
||||
// Tight compression-control byte top nibble. Stream-reset bits 0-3
|
||||
// (one per zlib stream) are unused while we run a single stream.
|
||||
tightFillSubenc = 0x80
|
||||
tightJPEGSubenc = 0x90
|
||||
tightBasicFilter = 0x40 // Bit 6 set = explicit filter byte follows.
|
||||
tightFilterCopy = 0x00 // No-op filter, raw pixel stream.
|
||||
|
||||
// JPEG quality used by the Tight encoder. 70 is a reasonable speed/
|
||||
// quality knee; bandwidth roughly halves vs raw RGB while staying
|
||||
// visually clean for typical desktop content. Large rects (e.g. a
|
||||
// fullscreen video region) drop to a lower quality so the encoder
|
||||
// keeps up at 30+ fps; the visual hit is small for moving content.
|
||||
tightJPEGQuality = 70
|
||||
tightJPEGQualityMedium = 55
|
||||
tightJPEGQualityLarge = 40
|
||||
tightJPEGMediumPixels = 800 * 600 // ≈ SVGA, applies medium tier
|
||||
tightJPEGLargePixels = 1280 * 720 // ≈ 720p, applies large tier
|
||||
// Minimum rect area before we consider JPEG. Below this, header
|
||||
// overhead dominates and Basic+zlib wins.
|
||||
tightJPEGMinArea = 4096 // 64×64 ≈ 1 tile
|
||||
// Distinct-colour cap below which we still prefer Basic+zlib (text,
|
||||
// UI). Sampled, not exhaustive: cheap to compute, good enough.
|
||||
tightJPEGMinColors = 64
|
||||
)
|
||||
|
||||
// serverPixelFormat is the pixel format the server advertises and requires:
|
||||
// 32bpp RGBA, little-endian, true-colour, 8 bits per channel at standard
|
||||
// shifts (R=16, G=8, B=0). handleSetPixelFormat rejects any client that
|
||||
// negotiates a different format. Browser-side decoders are little-endian
|
||||
// natively, so advertising little-endian skips a byte-swap on every pixel.
|
||||
var serverPixelFormat = [16]byte{
|
||||
32, // bits-per-pixel
|
||||
24, // depth
|
||||
0, // big-endian-flag
|
||||
1, // true-colour-flag
|
||||
0, 255, // red-max
|
||||
0, 255, // green-max
|
||||
0, 255, // blue-max
|
||||
16, // red-shift
|
||||
8, // green-shift
|
||||
0, // blue-shift
|
||||
0, 0, 0, // padding
|
||||
}
|
||||
|
||||
// clientPixelFormat holds the negotiated pixel format. Only RGB channel
|
||||
// shifts are tracked: every other field is constrained by the server to
|
||||
// the values in serverPixelFormat (32bpp / little-endian / truecolour /
|
||||
// 8-bit channels) and rejected at SetPixelFormat time if the client tries
|
||||
// to negotiate otherwise.
|
||||
type clientPixelFormat struct {
|
||||
rShift uint8
|
||||
gShift uint8
|
||||
bShift uint8
|
||||
}
|
||||
|
||||
func defaultClientPixelFormat() clientPixelFormat {
|
||||
return clientPixelFormat{
|
||||
rShift: serverPixelFormat[10],
|
||||
gShift: serverPixelFormat[11],
|
||||
bShift: serverPixelFormat[12],
|
||||
}
|
||||
}
|
||||
|
||||
// parsePixelFormat returns the negotiated client pixel format, or an error
|
||||
// if the client tried to negotiate an unsupported format. The server only
|
||||
// supports 32bpp truecolour little-endian with 8-bit channels; arbitrary
|
||||
// shifts within that constraint are allowed because they are cheap to honour.
|
||||
func parsePixelFormat(pf []byte) (clientPixelFormat, error) {
|
||||
bpp := pf[0]
|
||||
bigEndian := pf[2]
|
||||
trueColour := pf[3]
|
||||
rMax := binary.BigEndian.Uint16(pf[4:6])
|
||||
gMax := binary.BigEndian.Uint16(pf[6:8])
|
||||
bMax := binary.BigEndian.Uint16(pf[8:10])
|
||||
if bpp != 32 || bigEndian != 0 || trueColour != 1 ||
|
||||
rMax != 255 || gMax != 255 || bMax != 255 {
|
||||
return clientPixelFormat{}, fmt.Errorf(
|
||||
"unsupported pixel format (bpp=%d be=%d tc=%d rgb-max=%d/%d/%d): "+
|
||||
"server only supports 32bpp truecolour little-endian 8-bit channels",
|
||||
bpp, bigEndian, trueColour, rMax, gMax, bMax)
|
||||
}
|
||||
return clientPixelFormat{
|
||||
rShift: pf[10],
|
||||
gShift: pf[11],
|
||||
bShift: pf[12],
|
||||
}, nil
|
||||
}
|
||||
|
||||
// encodeCopyRectBody emits the per-rect payload for a CopyRect rectangle:
|
||||
// the 12-byte rect header (dst position + size + encoding=1) plus a 4-byte
|
||||
// source position. Used inside multi-rect FramebufferUpdate messages, so
|
||||
// the 4-byte FU header is the caller's responsibility.
|
||||
func encodeCopyRectBody(srcX, srcY, dstX, dstY, w, h int) []byte {
|
||||
buf := make([]byte, 12+4)
|
||||
binary.BigEndian.PutUint16(buf[0:2], uint16(dstX))
|
||||
binary.BigEndian.PutUint16(buf[2:4], uint16(dstY))
|
||||
binary.BigEndian.PutUint16(buf[4:6], uint16(w))
|
||||
binary.BigEndian.PutUint16(buf[6:8], uint16(h))
|
||||
binary.BigEndian.PutUint32(buf[8:12], uint32(encCopyRect))
|
||||
binary.BigEndian.PutUint16(buf[12:14], uint16(srcX))
|
||||
binary.BigEndian.PutUint16(buf[14:16], uint16(srcY))
|
||||
return buf
|
||||
}
|
||||
|
||||
// encodeDesktopSizeBody emits a DesktopSize pseudo-encoded rectangle. The
|
||||
// "rect" carries no pixel data: x and y are zero, w and h are the new
|
||||
// framebuffer dimensions, and encoding=-223 signals to the client that the
|
||||
// framebuffer was resized. Clients reallocate their backing buffer and
|
||||
// expect a full update at the new size to follow.
|
||||
func encodeDesktopSizeBody(w, h int) []byte {
|
||||
buf := make([]byte, 12)
|
||||
binary.BigEndian.PutUint16(buf[0:2], 0)
|
||||
binary.BigEndian.PutUint16(buf[2:4], 0)
|
||||
binary.BigEndian.PutUint16(buf[4:6], uint16(w))
|
||||
binary.BigEndian.PutUint16(buf[6:8], uint16(h))
|
||||
enc := int32(pseudoEncDesktopSize)
|
||||
binary.BigEndian.PutUint32(buf[8:12], uint32(enc))
|
||||
return buf
|
||||
}
|
||||
|
||||
// encodeDesktopNameBody emits a DesktopName pseudo-encoded rectangle. The
|
||||
// rect header is all zeros and encoding=-307; the body is a 4-byte
|
||||
// big-endian length followed by the UTF-8 name. Clients update their
|
||||
// window title or label without reconnecting.
|
||||
func encodeDesktopNameBody(name string) []byte {
|
||||
nameBytes := []byte(name)
|
||||
buf := make([]byte, 12+4+len(nameBytes))
|
||||
binary.BigEndian.PutUint16(buf[0:2], 0)
|
||||
binary.BigEndian.PutUint16(buf[2:4], 0)
|
||||
binary.BigEndian.PutUint16(buf[4:6], 0)
|
||||
binary.BigEndian.PutUint16(buf[6:8], 0)
|
||||
enc := int32(pseudoEncDesktopName)
|
||||
binary.BigEndian.PutUint32(buf[8:12], uint32(enc))
|
||||
binary.BigEndian.PutUint32(buf[12:16], uint32(len(nameBytes)))
|
||||
copy(buf[16:], nameBytes)
|
||||
return buf
|
||||
}
|
||||
|
||||
// encodeLastRectBody emits a LastRect sentinel. When the server sets
|
||||
// numRects=0xFFFF in the FramebufferUpdate header, the client reads rects
|
||||
// until it sees one with this encoding. Lets us stream rects from a
|
||||
// goroutine without committing to a count up front.
|
||||
func encodeLastRectBody() []byte {
|
||||
buf := make([]byte, 12)
|
||||
// x, y, w, h all zero; encoding = -224.
|
||||
enc := int32(pseudoEncLastRect)
|
||||
binary.BigEndian.PutUint32(buf[8:12], uint32(enc))
|
||||
return buf
|
||||
}
|
||||
|
||||
// encodeRawRect encodes a framebuffer region as a raw RFB rectangle.
|
||||
// The returned buffer includes the FramebufferUpdate header (1 rectangle).
|
||||
func encodeRawRect(img *image.RGBA, pf clientPixelFormat, x, y, w, h int) []byte {
|
||||
buf := make([]byte, 4+12+w*h*4)
|
||||
|
||||
// FramebufferUpdate header.
|
||||
buf[0] = serverFramebufferUpdate
|
||||
buf[1] = 0 // padding
|
||||
binary.BigEndian.PutUint16(buf[2:4], 1)
|
||||
|
||||
// Rectangle header.
|
||||
binary.BigEndian.PutUint16(buf[4:6], uint16(x))
|
||||
binary.BigEndian.PutUint16(buf[6:8], uint16(y))
|
||||
binary.BigEndian.PutUint16(buf[8:10], uint16(w))
|
||||
binary.BigEndian.PutUint16(buf[10:12], uint16(h))
|
||||
binary.BigEndian.PutUint32(buf[12:16], uint32(encRaw))
|
||||
|
||||
writePixels(buf[16:], img, pf, rect{x, y, w, h})
|
||||
return buf
|
||||
}
|
||||
|
||||
// encodeZlibRect encodes a framebuffer region using the standalone Zlib
|
||||
// encoding. The zlib stream is continuous for the entire VNC session: the
|
||||
// client keeps a single inflate context and reuses it across rects. The
|
||||
// returned buffer includes the 4-byte FramebufferUpdate header.
|
||||
func encodeZlibRect(img *image.RGBA, pf clientPixelFormat, x, y, w, h int, z *zlibState) []byte {
|
||||
zw, zbuf := z.w, z.buf
|
||||
zbuf.Reset()
|
||||
|
||||
rowBytes := w * 4
|
||||
total := rowBytes * h
|
||||
if cap(z.scratch) < total {
|
||||
z.scratch = make([]byte, total)
|
||||
}
|
||||
scratch := z.scratch[:total]
|
||||
writePixels(scratch, img, pf, rect{x, y, w, h})
|
||||
for row := 0; row < h; row++ {
|
||||
if _, err := zw.Write(scratch[row*rowBytes : (row+1)*rowBytes]); err != nil {
|
||||
log.Debugf("zlib write row %d: %v", row, err)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if err := zw.Flush(); err != nil {
|
||||
log.Debugf("zlib flush: %v", err)
|
||||
return nil
|
||||
}
|
||||
compressed := zbuf.Bytes()
|
||||
|
||||
buf := make([]byte, 4+12+4+len(compressed))
|
||||
buf[0] = serverFramebufferUpdate
|
||||
binary.BigEndian.PutUint16(buf[2:4], 1)
|
||||
binary.BigEndian.PutUint16(buf[4:6], uint16(x))
|
||||
binary.BigEndian.PutUint16(buf[6:8], uint16(y))
|
||||
binary.BigEndian.PutUint16(buf[8:10], uint16(w))
|
||||
binary.BigEndian.PutUint16(buf[10:12], uint16(h))
|
||||
binary.BigEndian.PutUint32(buf[12:16], uint32(encZlib))
|
||||
binary.BigEndian.PutUint32(buf[16:20], uint32(len(compressed)))
|
||||
copy(buf[20:], compressed)
|
||||
return buf
|
||||
}
|
||||
|
||||
// encodeHextileSolidRect emits a Hextile-encoded rectangle whose every
|
||||
// pixel is the same colour. The first sub-tile carries the background
|
||||
// pixel; remaining sub-tiles inherit it via a zero sub-encoding byte,
|
||||
// collapsing a uniform 64×64 tile down to ~20 bytes. The returned buffer
|
||||
// starts with the 12-byte rect header; callers prepend a FramebufferUpdate
|
||||
// header.
|
||||
func encodeHextileSolidRect(r, g, b byte, pf clientPixelFormat, rc rect) []byte {
|
||||
cols := (rc.w + hextileSubSize - 1) / hextileSubSize
|
||||
rows := (rc.h + hextileSubSize - 1) / hextileSubSize
|
||||
subs := cols * rows
|
||||
// One sub-encoding byte plus a 32bpp pixel for the first sub-tile, then
|
||||
// one zero byte per remaining sub-tile to inherit the background.
|
||||
bodySize := 1 + 4 + (subs - 1)
|
||||
buf := make([]byte, 12+bodySize)
|
||||
|
||||
binary.BigEndian.PutUint16(buf[0:2], uint16(rc.x))
|
||||
binary.BigEndian.PutUint16(buf[2:4], uint16(rc.y))
|
||||
binary.BigEndian.PutUint16(buf[4:6], uint16(rc.w))
|
||||
binary.BigEndian.PutUint16(buf[6:8], uint16(rc.h))
|
||||
binary.BigEndian.PutUint32(buf[8:12], uint32(encHextile))
|
||||
|
||||
buf[12] = hextileBackgroundSpecified
|
||||
pixel := (uint32(r) << pf.rShift) | (uint32(g) << pf.gShift) | (uint32(b) << pf.bShift)
|
||||
binary.LittleEndian.PutUint32(buf[13:17], pixel)
|
||||
return buf
|
||||
}
|
||||
|
||||
// writePixels writes a rectangle of img into dst as 32bpp little-endian
|
||||
// pixels at the negotiated RGB shifts. The pixel format is constrained at
|
||||
// SetPixelFormat time so we can assume 4 bytes per pixel, 8-bit channels,
|
||||
// and little-endian byte order; arbitrary shifts (R/G/B order) are honoured.
|
||||
func writePixels(dst []byte, img *image.RGBA, pf clientPixelFormat, r rect) {
|
||||
stride := img.Stride
|
||||
rShift, gShift, bShift := pf.rShift, pf.gShift, pf.bShift
|
||||
off := 0
|
||||
for row := r.y; row < r.y+r.h; row++ {
|
||||
p := row*stride + r.x*4
|
||||
for col := 0; col < r.w; col++ {
|
||||
pixel := (uint32(img.Pix[p]) << rShift) |
|
||||
(uint32(img.Pix[p+1]) << gShift) |
|
||||
(uint32(img.Pix[p+2]) << bShift)
|
||||
binary.LittleEndian.PutUint32(dst[off:off+4], pixel)
|
||||
p += 4
|
||||
off += 4
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// diffTiles compares two RGBA images and returns a tile-ordered list of
|
||||
// dirty tiles, one entry per tile. Tile order is top-to-bottom, left-to-
|
||||
// right within each row. The caller decides whether to coalesce or hand
|
||||
// the list off to the CopyRect detector first.
|
||||
func diffTiles(prev, cur *image.RGBA, w, h, tileSize int) [][4]int {
|
||||
if prev == nil {
|
||||
return [][4]int{{0, 0, w, h}}
|
||||
}
|
||||
var rects [][4]int
|
||||
for ty := 0; ty < h; ty += tileSize {
|
||||
th := min(tileSize, h-ty)
|
||||
for tx := 0; tx < w; tx += tileSize {
|
||||
tw := min(tileSize, w-tx)
|
||||
if tileChanged(prev, cur, tx, ty, tw, th) {
|
||||
rects = append(rects, [4]int{tx, ty, tw, th})
|
||||
}
|
||||
}
|
||||
}
|
||||
return rects
|
||||
}
|
||||
|
||||
// diffRects is the legacy convenience: diff then coalesce. Used by paths
|
||||
// that don't go through the CopyRect detector and by tests that exercise
|
||||
// the diff-plus-coalesce pipeline as one unit.
|
||||
func diffRects(prev, cur *image.RGBA, w, h, tileSize int) [][4]int {
|
||||
return coalesceRects(diffTiles(prev, cur, w, h, tileSize))
|
||||
}
|
||||
|
||||
// coalesceRects merges adjacent dirty tiles into larger rectangles to cut
|
||||
// per-rect framing overhead. Input must be tile-ordered (top-to-bottom rows,
|
||||
// left-to-right within each row), as produced by diffRects. Two passes:
|
||||
// 1. Horizontal: within a row, merge tiles whose x-extents touch.
|
||||
// 2. Vertical: merge a row's run with the run directly above it when they
|
||||
// share the same [x, x+w] extent and are vertically adjacent.
|
||||
//
|
||||
// Larger merged rects still encode correctly: Hextile-solid and Zlib paths
|
||||
// both work on arbitrary sizes, and uniform-tile detection still fires when
|
||||
// the merged region happens to be a single colour.
|
||||
func coalesceRects(in [][4]int) [][4]int {
|
||||
if len(in) < 2 {
|
||||
return in
|
||||
}
|
||||
c := newRectCoalescer(len(in))
|
||||
c.curY = in[0][1]
|
||||
for _, r := range in {
|
||||
c.consume(r)
|
||||
}
|
||||
c.flushCurrentRow()
|
||||
return c.out
|
||||
}
|
||||
|
||||
// rectCoalescer is the working state for coalesceRects, lifted out so the
|
||||
// algorithm can be split across small methods without long parameter lists
|
||||
// and to keep each method's cognitive complexity below Sonar's threshold.
|
||||
type rectCoalescer struct {
|
||||
out [][4]int
|
||||
prevRowStart, prevRowEnd int
|
||||
curRowStart int
|
||||
curY int
|
||||
}
|
||||
|
||||
func newRectCoalescer(capacity int) *rectCoalescer {
|
||||
return &rectCoalescer{out: make([][4]int, 0, capacity)}
|
||||
}
|
||||
|
||||
// consume processes one rect from the (row-ordered) input.
|
||||
func (c *rectCoalescer) consume(r [4]int) {
|
||||
if r[1] != c.curY {
|
||||
c.flushCurrentRow()
|
||||
c.prevRowEnd = len(c.out)
|
||||
c.curRowStart = len(c.out)
|
||||
c.curY = r[1]
|
||||
}
|
||||
if c.tryHorizontalMerge(r) {
|
||||
return
|
||||
}
|
||||
c.out = append(c.out, r)
|
||||
}
|
||||
|
||||
// tryHorizontalMerge extends the last run in the current row when r is
|
||||
// vertically aligned and horizontally adjacent to it.
|
||||
func (c *rectCoalescer) tryHorizontalMerge(r [4]int) bool {
|
||||
if len(c.out) <= c.curRowStart {
|
||||
return false
|
||||
}
|
||||
last := &c.out[len(c.out)-1]
|
||||
if last[1] == r[1] && last[3] == r[3] && last[0]+last[2] == r[0] {
|
||||
last[2] += r[2]
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// flushCurrentRow merges each run in the current row with any run from the
|
||||
// previous row that has identical x extent and is vertically adjacent.
|
||||
func (c *rectCoalescer) flushCurrentRow() {
|
||||
i := c.curRowStart
|
||||
for i < len(c.out) {
|
||||
if c.mergeWithPrevRow(i) {
|
||||
continue
|
||||
}
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
// mergeWithPrevRow tries to extend a previous-row run downward to absorb
|
||||
// out[i]. Returns true and removes out[i] from the slice on success.
|
||||
func (c *rectCoalescer) mergeWithPrevRow(i int) bool {
|
||||
for j := c.prevRowStart; j < c.prevRowEnd; j++ {
|
||||
if c.out[j][0] == c.out[i][0] &&
|
||||
c.out[j][2] == c.out[i][2] &&
|
||||
c.out[j][1]+c.out[j][3] == c.out[i][1] {
|
||||
c.out[j][3] += c.out[i][3]
|
||||
copy(c.out[i:], c.out[i+1:])
|
||||
c.out = c.out[:len(c.out)-1]
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func tileChanged(prev, cur *image.RGBA, x, y, w, h int) bool {
|
||||
stride := prev.Stride
|
||||
for row := y; row < y+h; row++ {
|
||||
off := row*stride + x*4
|
||||
end := off + w*4
|
||||
prevRow := prev.Pix[off:end]
|
||||
curRow := cur.Pix[off:end]
|
||||
if !bytes.Equal(prevRow, curRow) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// tileIsUniform reports whether every pixel in the given rectangle of img is
|
||||
// the same RGBA value, and returns that pixel packed as 0xRRGGBBAA when so.
|
||||
// Uses uint32 comparisons across rows; returns early on the first mismatch.
|
||||
func tileIsUniform(img *image.RGBA, x, y, w, h int) (uint32, bool) {
|
||||
if w <= 0 || h <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
stride := img.Stride
|
||||
base := y*stride + x*4
|
||||
first := *(*uint32)(unsafe.Pointer(&img.Pix[base]))
|
||||
rowBytes := w * 4
|
||||
for row := 0; row < h; row++ {
|
||||
p := base + row*stride
|
||||
for col := 0; col < rowBytes; col += 4 {
|
||||
if *(*uint32)(unsafe.Pointer(&img.Pix[p+col])) != first {
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
}
|
||||
return first, true
|
||||
}
|
||||
|
||||
// tightState holds the per-session JPEG scratch buffer and reused encoders
|
||||
// so per-rect encoding stays alloc-free in the steady state.
|
||||
type tightState struct {
|
||||
jpegBuf *bytes.Buffer
|
||||
zlib *zlibState
|
||||
scratch []byte // RGB-packed pixel scratch for JPEG and Basic paths.
|
||||
// colorSeen is reused by sampledColorCount per rect; cleared via the Go
|
||||
// runtime's map-clear fast path to avoid a fresh allocation each call.
|
||||
colorSeen map[uint32]struct{}
|
||||
// jpegQualityOverride forces a fixed JPEG quality on every rect when
|
||||
// non-zero (set from the client's QualityLevel pseudo-encoding). Zero
|
||||
// falls back to the area-based tiers in tightQualityFor.
|
||||
jpegQualityOverride int
|
||||
// qualityLevel and compressLevel are the 0..9 levels currently applied,
|
||||
// or -1 if the client did not express a preference. Used to decide
|
||||
// whether a SetEncodings refresh needs to recreate the tight state.
|
||||
qualityLevel int
|
||||
compressLevel int
|
||||
// pendingZlibReset becomes true when this tightState replaces an
|
||||
// in-use one (e.g. CompressLevel change mid-session). The next Basic
|
||||
// rect we emit ORs the stream-0 reset bit into its sub-encoding byte
|
||||
// so the client's inflater drops its now-stale dictionary; cleared
|
||||
// after one emission.
|
||||
pendingZlibReset bool
|
||||
}
|
||||
|
||||
func newTightState() *tightState {
|
||||
return newTightStateWithLevels(-1, -1)
|
||||
}
|
||||
|
||||
// newTightStateWithLevels builds a tightState whose zlib stream and JPEG
|
||||
// quality reflect the client's QualityLevel / CompressLevel pseudo-encodings.
|
||||
// Pass -1 for either level to keep our defaults (BestSpeed zlib and the
|
||||
// area-tiered JPEG quality in tightQualityFor).
|
||||
func newTightStateWithLevels(qualityLevel, compressLevel int) *tightState {
|
||||
return &tightState{
|
||||
jpegBuf: &bytes.Buffer{},
|
||||
zlib: newZlibStateLevel(zlibLevelFor(compressLevel)),
|
||||
colorSeen: make(map[uint32]struct{}, 64),
|
||||
jpegQualityOverride: jpegQualityForLevel(qualityLevel),
|
||||
qualityLevel: qualityLevel,
|
||||
compressLevel: compressLevel,
|
||||
}
|
||||
}
|
||||
|
||||
// jpegQualityForLevel maps a 0..9 client preference to a JPEG quality value.
|
||||
// Returns 0 when no preference is set (-1), letting the encoder fall back
|
||||
// to the area-based tiers. The encoder lowers this dynamically when the
|
||||
// socket is backpressured, so this routine emits the unclamped, client-
|
||||
// requested value.
|
||||
func jpegQualityForLevel(level int) int {
|
||||
if level < 0 {
|
||||
return 0
|
||||
}
|
||||
if level > 9 {
|
||||
level = 9
|
||||
}
|
||||
return 30 + level*7
|
||||
}
|
||||
|
||||
// zlibLevelFor maps a 0..9 client preference to a zlib compression level.
|
||||
// Level 0 ("no compression") would emit larger output than input on most
|
||||
// rects, so we floor to BestSpeed (1). -1 (no preference) also picks
|
||||
// BestSpeed: matches the historical default before the pseudo-encoding
|
||||
// was honoured.
|
||||
func zlibLevelFor(level int) int {
|
||||
if level < 1 {
|
||||
return zlib.BestSpeed
|
||||
}
|
||||
if level > zlib.BestCompression {
|
||||
return zlib.BestCompression
|
||||
}
|
||||
return level
|
||||
}
|
||||
|
||||
// tightMaxLength is the maximum payload size representable in the Tight
|
||||
// compact length prefix (RFB §7.7.6: 22 bits, three 7+7+8 bit groups).
|
||||
// Exceeding this would silently truncate the high byte; callers must fall
|
||||
// back to a different encoding when an attempt would overflow.
|
||||
const tightMaxLength = (1 << 22) - 1
|
||||
|
||||
// encodeTightRect emits a single Tight-encoded rect. Picks Fill for uniform
|
||||
// content, JPEG for photo-like rects above a size and color-count threshold,
|
||||
// and Basic+zlib otherwise. When Tight's 22-bit length cap would be exceeded
|
||||
// (huge full-frame rects under bad compression), falls back to Raw. Returns
|
||||
// the rect header + body (no FramebufferUpdate header).
|
||||
func encodeTightRect(img *image.RGBA, pf clientPixelFormat, x, y, w, h int, t *tightState) []byte {
|
||||
if pixel, uniform := tileIsUniform(img, x, y, w, h); uniform {
|
||||
return encodeTightFill(x, y, w, h, byte(pixel), byte(pixel>>8), byte(pixel>>16))
|
||||
}
|
||||
if w*h >= tightJPEGMinArea && sampledColorCountInto(t.colorSeen, img, x, y, w, h, tightJPEGMinColors) >= tightJPEGMinColors {
|
||||
if buf, ok := encodeTightJPEG(img, x, y, w, h, t); ok {
|
||||
return buf
|
||||
}
|
||||
}
|
||||
if buf, ok := encodeTightBasic(img, x, y, w, h, t); ok {
|
||||
return buf
|
||||
}
|
||||
// Fall back to Raw rect body (skip the 4-byte FU header that encodeRawRect
|
||||
// prepends, since callers compose their own FU header).
|
||||
return encodeRawRect(img, pf, x, y, w, h)[4:]
|
||||
}
|
||||
|
||||
func writeTightRectHeader(buf []byte, x, y, w, h int) {
|
||||
binary.BigEndian.PutUint16(buf[0:2], uint16(x))
|
||||
binary.BigEndian.PutUint16(buf[2:4], uint16(y))
|
||||
binary.BigEndian.PutUint16(buf[4:6], uint16(w))
|
||||
binary.BigEndian.PutUint16(buf[6:8], uint16(h))
|
||||
binary.BigEndian.PutUint32(buf[8:12], uint32(encTight))
|
||||
}
|
||||
|
||||
// appendTightLength encodes a Tight compact length prefix (1, 2, or 3 bytes
|
||||
// LE-ish, top bit of each byte signals continuation). Lengths exceeding
|
||||
// tightMaxLength would silently truncate the high byte; callers must clamp
|
||||
// or fall back before reaching here.
|
||||
func appendTightLength(buf []byte, n int) []byte {
|
||||
if n < 0 || n > tightMaxLength {
|
||||
panic(fmt.Sprintf("tight length out of range: %d", n))
|
||||
}
|
||||
b0 := byte(n & 0x7f)
|
||||
if n <= 0x7f {
|
||||
return append(buf, b0)
|
||||
}
|
||||
b0 |= 0x80
|
||||
b1 := byte((n >> 7) & 0x7f)
|
||||
if n <= 0x3fff {
|
||||
return append(buf, b0, b1)
|
||||
}
|
||||
b1 |= 0x80
|
||||
// High group is 8 bits per spec, but our cap guarantees the top 2 bits
|
||||
// are zero; mask defensively.
|
||||
b2 := byte((n >> 14) & 0xff)
|
||||
return append(buf, b0, b1, b2)
|
||||
}
|
||||
|
||||
// encodeTightFill emits a uniform rect: 12-byte rect header + 1-byte
|
||||
// subenc (0x80) + 3-byte RGB pixel. Tight Fill always uses 24-bit RGB
|
||||
// regardless of the negotiated pixel format.
|
||||
func encodeTightFill(x, y, w, h int, r, g, b byte) []byte {
|
||||
buf := make([]byte, 12+1+3)
|
||||
writeTightRectHeader(buf, x, y, w, h)
|
||||
buf[12] = tightFillSubenc
|
||||
buf[13] = r
|
||||
buf[14] = g
|
||||
buf[15] = b
|
||||
return buf
|
||||
}
|
||||
|
||||
// encodeTightJPEG compresses the rect as a baseline JPEG. Returns ok=false
|
||||
// if the encoder errors so the caller can fall back to Basic.
|
||||
func encodeTightJPEG(img *image.RGBA, x, y, w, h int, t *tightState) ([]byte, bool) {
|
||||
t.jpegBuf.Reset()
|
||||
sub := img.SubImage(image.Rect(img.Rect.Min.X+x, img.Rect.Min.Y+y, img.Rect.Min.X+x+w, img.Rect.Min.Y+y+h))
|
||||
q := t.jpegQualityOverride
|
||||
if q == 0 {
|
||||
q = tightQualityFor(w * h)
|
||||
}
|
||||
if err := jpeg.Encode(t.jpegBuf, sub, &jpeg.Options{Quality: q}); err != nil {
|
||||
return nil, false
|
||||
}
|
||||
jpegBytes := t.jpegBuf.Bytes()
|
||||
if len(jpegBytes) > tightMaxLength {
|
||||
return nil, false
|
||||
}
|
||||
buf := make([]byte, 0, 12+1+3+len(jpegBytes))
|
||||
buf = buf[:12]
|
||||
writeTightRectHeader(buf, x, y, w, h)
|
||||
buf = append(buf, tightJPEGSubenc)
|
||||
buf = appendTightLength(buf, len(jpegBytes))
|
||||
buf = append(buf, jpegBytes...)
|
||||
return buf, true
|
||||
}
|
||||
|
||||
// encodeTightBasic emits Basic+zlib with the no-op (CopyFilter) filter.
|
||||
// Pixels are sent as 24-bit RGB ("TPIXEL" format) which most clients
|
||||
// negotiate when the server advertises 32bpp true colour. Streams under
|
||||
// 12 bytes ship uncompressed per RFB Tight spec. Returns ok=false when the
|
||||
// compressed payload would exceed Tight's 22-bit length cap or when zlib
|
||||
// errors, signalling the caller to fall back to Raw.
|
||||
func encodeTightBasic(img *image.RGBA, x, y, w, h int, t *tightState) ([]byte, bool) {
|
||||
pixelStream := w * h * 3
|
||||
if cap(t.scratch) < pixelStream {
|
||||
t.scratch = make([]byte, pixelStream)
|
||||
}
|
||||
scratch := t.scratch[:pixelStream]
|
||||
stride := img.Stride
|
||||
off := 0
|
||||
for row := y; row < y+h; row++ {
|
||||
p := row*stride + x*4
|
||||
for col := 0; col < w; col++ {
|
||||
scratch[off+0] = img.Pix[p]
|
||||
scratch[off+1] = img.Pix[p+1]
|
||||
scratch[off+2] = img.Pix[p+2]
|
||||
p += 4
|
||||
off += 3
|
||||
}
|
||||
}
|
||||
|
||||
// Sub-encoding byte: stream 0, basic encoding (top nibble = 0x40 =
|
||||
// explicit filter follows). The low nibble carries per-stream reset
|
||||
// flags; bit 0 here tells the client to reset its stream-0 inflater
|
||||
// when our deflater was just recreated.
|
||||
subenc := byte(tightBasicFilter)
|
||||
if t.pendingZlibReset {
|
||||
subenc |= 0x01
|
||||
t.pendingZlibReset = false
|
||||
}
|
||||
filter := byte(tightFilterCopy)
|
||||
|
||||
if pixelStream < 12 {
|
||||
buf := make([]byte, 0, 12+2+pixelStream)
|
||||
buf = buf[:12]
|
||||
writeTightRectHeader(buf, x, y, w, h)
|
||||
buf = append(buf, subenc, filter)
|
||||
buf = append(buf, scratch...)
|
||||
return buf, true
|
||||
}
|
||||
|
||||
z := t.zlib
|
||||
z.buf.Reset()
|
||||
if _, err := z.w.Write(scratch); err != nil {
|
||||
log.Debugf("tight zlib write: %v", err)
|
||||
return nil, false
|
||||
}
|
||||
if err := z.w.Flush(); err != nil {
|
||||
log.Debugf("tight zlib flush: %v", err)
|
||||
return nil, false
|
||||
}
|
||||
compressed := z.buf.Bytes()
|
||||
if len(compressed) > tightMaxLength {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
buf := make([]byte, 0, 12+2+5+len(compressed))
|
||||
buf = buf[:12]
|
||||
writeTightRectHeader(buf, x, y, w, h)
|
||||
buf = append(buf, subenc, filter)
|
||||
buf = appendTightLength(buf, len(compressed))
|
||||
buf = append(buf, compressed...)
|
||||
return buf, true
|
||||
}
|
||||
|
||||
func tightQualityFor(pixels int) int {
|
||||
switch {
|
||||
case pixels >= tightJPEGLargePixels:
|
||||
return tightJPEGQualityLarge
|
||||
case pixels >= tightJPEGMediumPixels:
|
||||
return tightJPEGQualityMedium
|
||||
default:
|
||||
return tightJPEGQuality
|
||||
}
|
||||
}
|
||||
|
||||
// sampledColorCountInto estimates distinct-colour count by checking up to
|
||||
// maxColors samples. The caller-provided `seen` map is cleared and reused so
|
||||
// per-rect Tight encoding stays alloc-free. Cheap O(maxColors) per call.
|
||||
func sampledColorCountInto(seen map[uint32]struct{}, img *image.RGBA, x, y, w, h, maxColors int) int {
|
||||
clear(seen)
|
||||
stride := img.Stride
|
||||
step := max((w*h)/(maxColors*4), 1)
|
||||
var idx int
|
||||
for row := 0; row < h; row++ {
|
||||
p := (y+row)*stride + x*4
|
||||
for col := 0; col < w; col++ {
|
||||
if idx%step == 0 {
|
||||
px := *(*uint32)(unsafe.Pointer(&img.Pix[p+col*4]))
|
||||
seen[px&0x00ffffff] = struct{}{}
|
||||
if len(seen) > maxColors {
|
||||
return len(seen)
|
||||
}
|
||||
}
|
||||
idx++
|
||||
}
|
||||
}
|
||||
return len(seen)
|
||||
}
|
||||
|
||||
// zlibState holds the persistent zlib writer and its output buffer, reused
|
||||
// across rects so steady-state Tight encoding stays alloc-free.
|
||||
type zlibState struct {
|
||||
buf *bytes.Buffer
|
||||
w *zlib.Writer
|
||||
// scratch stages the packed pixel stream for a rect before it is fed
|
||||
// to the deflater. Grown to the largest rect seen in the session and
|
||||
// reused to keep the steady-state encode allocation-free.
|
||||
scratch []byte
|
||||
}
|
||||
|
||||
func newZlibStateLevel(level int) *zlibState {
|
||||
buf := &bytes.Buffer{}
|
||||
w, _ := zlib.NewWriterLevel(buf, level)
|
||||
return &zlibState{buf: buf, w: w}
|
||||
}
|
||||
|
||||
func (z *zlibState) Close() error {
|
||||
return z.w.Close()
|
||||
}
|
||||
364
client/vnc/server/rfb_bench_test.go
Normal file
364
client/vnc/server/rfb_bench_test.go
Normal file
@@ -0,0 +1,364 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"image"
|
||||
"math/rand"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Representative frame sizes.
|
||||
var benchRects = []struct {
|
||||
name string
|
||||
w, h int
|
||||
}{
|
||||
{"1080p_full", 1920, 1080},
|
||||
{"720p_full", 1280, 720},
|
||||
{"256x256_tile", 256, 256},
|
||||
{"64x64_tile", 64, 64},
|
||||
}
|
||||
|
||||
func makeBenchImage(w, h int, seed int64) *image.RGBA {
|
||||
img := image.NewRGBA(image.Rect(0, 0, w, h))
|
||||
r := rand.New(rand.NewSource(seed))
|
||||
_, _ = r.Read(img.Pix)
|
||||
// Force alpha byte so the fast path and slow path produce identical output.
|
||||
for i := 3; i < len(img.Pix); i += 4 {
|
||||
img.Pix[i] = 0xff
|
||||
}
|
||||
return img
|
||||
}
|
||||
|
||||
func makeBenchImagePartial(w, h, changedRows int) (*image.RGBA, *image.RGBA) {
|
||||
prev := makeBenchImage(w, h, 1)
|
||||
cur := image.NewRGBA(prev.Rect)
|
||||
copy(cur.Pix, prev.Pix)
|
||||
if changedRows > h {
|
||||
changedRows = h
|
||||
}
|
||||
// Dirty the first `changedRows` rows.
|
||||
r := rand.New(rand.NewSource(2))
|
||||
_, _ = r.Read(cur.Pix[:changedRows*cur.Stride])
|
||||
for i := 3; i < len(cur.Pix); i += 4 {
|
||||
cur.Pix[i] = 0xff
|
||||
}
|
||||
return prev, cur
|
||||
}
|
||||
|
||||
func BenchmarkEncodeRawRect(b *testing.B) {
|
||||
pf := defaultClientPixelFormat()
|
||||
for _, r := range benchRects {
|
||||
img := makeBenchImage(r.w, r.h, 1)
|
||||
b.Run(r.name, func(b *testing.B) {
|
||||
b.SetBytes(int64(r.w * r.h * 4))
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = encodeRawRect(img, pf, 0, 0, r.w, r.h)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkEncodeTightRect(b *testing.B) {
|
||||
pf := defaultClientPixelFormat()
|
||||
for _, r := range benchRects {
|
||||
img := makeBenchImage(r.w, r.h, 1)
|
||||
t := newTightState()
|
||||
b.Run(r.name, func(b *testing.B) {
|
||||
b.SetBytes(int64(r.w * r.h * 4))
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = encodeTightRect(img, pf, 0, 0, r.w, r.h, t)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkWritePixels isolates the per-pixel pack loop from the allocation
|
||||
// and FramebufferUpdate-header overhead.
|
||||
func BenchmarkWritePixels(b *testing.B) {
|
||||
pf := defaultClientPixelFormat()
|
||||
for _, r := range benchRects {
|
||||
img := makeBenchImage(r.w, r.h, 1)
|
||||
dst := make([]byte, r.w*r.h*4)
|
||||
b.Run(r.name, func(b *testing.B) {
|
||||
b.SetBytes(int64(r.w * r.h * 4))
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
writePixels(dst, img, pf, rect{0, 0, r.w, r.h})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSwizzleBGRAtoRGBA(b *testing.B) {
|
||||
for _, r := range benchRects {
|
||||
size := r.w * r.h * 4
|
||||
src := make([]byte, size)
|
||||
dst := make([]byte, size)
|
||||
rng := rand.New(rand.NewSource(1))
|
||||
_, _ = rng.Read(src)
|
||||
b.Run(r.name, func(b *testing.B) {
|
||||
b.SetBytes(int64(size))
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
swizzleBGRAtoRGBA(dst, src)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkSwizzleBGRAtoRGBANaive is the naive byte-by-byte implementation
|
||||
// that the Linux SHM capturer used before the uint32 rewrite, kept here so
|
||||
// we can compare the cost directly.
|
||||
func BenchmarkSwizzleBGRAtoRGBANaive(b *testing.B) {
|
||||
for _, r := range benchRects {
|
||||
size := r.w * r.h * 4
|
||||
src := make([]byte, size)
|
||||
dst := make([]byte, size)
|
||||
rng := rand.New(rand.NewSource(1))
|
||||
_, _ = rng.Read(src)
|
||||
b.Run(r.name, func(b *testing.B) {
|
||||
b.SetBytes(int64(size))
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
for j := 0; j < size; j += 4 {
|
||||
dst[j+0] = src[j+2]
|
||||
dst[j+1] = src[j+1]
|
||||
dst[j+2] = src[j+0]
|
||||
dst[j+3] = 0xff
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkEncodeUniformTile_TightFill measures the fast path for a uniform
|
||||
// 64×64 tile via Tight's Fill subencoding (16 wire bytes regardless of size).
|
||||
func BenchmarkEncodeUniformTile_TightFill(b *testing.B) {
|
||||
pf := defaultClientPixelFormat()
|
||||
img := image.NewRGBA(image.Rect(0, 0, 64, 64))
|
||||
for i := 0; i < len(img.Pix); i += 4 {
|
||||
img.Pix[i+0] = 0x33
|
||||
img.Pix[i+1] = 0x66
|
||||
img.Pix[i+2] = 0x99
|
||||
img.Pix[i+3] = 0xff
|
||||
}
|
||||
t := newTightState()
|
||||
b.ReportAllocs()
|
||||
var bytesOut int
|
||||
for i := 0; i < b.N; i++ {
|
||||
out := encodeTightRect(img, pf, 0, 0, 64, 64, t)
|
||||
bytesOut = len(out)
|
||||
}
|
||||
b.ReportMetric(float64(bytesOut), "wire_bytes")
|
||||
}
|
||||
|
||||
func BenchmarkTileIsUniform(b *testing.B) {
|
||||
img := image.NewRGBA(image.Rect(0, 0, 64, 64))
|
||||
for i := 0; i < len(img.Pix); i += 4 {
|
||||
img.Pix[i+3] = 0xff
|
||||
}
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = tileIsUniform(img, 0, 0, 64, 64)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkEncodeManyTilesVsFullFrame exercises the bandwidth + CPU
|
||||
// trade-off that motivates the full-frame promotion path: encoding a burst
|
||||
// of N dirty 64×64 tiles as separate Tight rects vs emitting one big Tight
|
||||
// rect for the whole frame.
|
||||
func BenchmarkEncodeManyTilesVsFullFrame(b *testing.B) {
|
||||
pf := defaultClientPixelFormat()
|
||||
const w, h = 1920, 1080
|
||||
img := makeBenchImage(w, h, 1)
|
||||
|
||||
// Build the list of every tile in the frame (worst case: entire screen dirty).
|
||||
var tiles [][4]int
|
||||
for ty := 0; ty < h; ty += tileSize {
|
||||
th := tileSize
|
||||
if ty+th > h {
|
||||
th = h - ty
|
||||
}
|
||||
for tx := 0; tx < w; tx += tileSize {
|
||||
tw := tileSize
|
||||
if tx+tw > w {
|
||||
tw = w - tx
|
||||
}
|
||||
tiles = append(tiles, [4]int{tx, ty, tw, th})
|
||||
}
|
||||
}
|
||||
nTiles := len(tiles)
|
||||
|
||||
b.Run("per_tile_tight", func(b *testing.B) {
|
||||
t := newTightState()
|
||||
b.SetBytes(int64(w * h * 4))
|
||||
b.ReportAllocs()
|
||||
var totalOut int
|
||||
for i := 0; i < b.N; i++ {
|
||||
totalOut = 0
|
||||
for _, r := range tiles {
|
||||
out := encodeTightRect(img, pf, r[0], r[1], r[2], r[3], t)
|
||||
totalOut += len(out)
|
||||
}
|
||||
}
|
||||
b.ReportMetric(float64(totalOut), "wire_bytes")
|
||||
b.ReportMetric(float64(nTiles), "tiles")
|
||||
})
|
||||
|
||||
b.Run("full_frame_tight", func(b *testing.B) {
|
||||
t := newTightState()
|
||||
b.SetBytes(int64(w * h * 4))
|
||||
b.ReportAllocs()
|
||||
var totalOut int
|
||||
for i := 0; i < b.N; i++ {
|
||||
out := encodeTightRect(img, pf, 0, 0, w, h, t)
|
||||
totalOut = len(out)
|
||||
}
|
||||
b.ReportMetric(float64(totalOut), "wire_bytes")
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkShouldPromoteToFullFrame verifies the threshold check itself is
|
||||
// cheap. It runs on every frame, so regressions here hit all workloads.
|
||||
func BenchmarkShouldPromoteToFullFrame(b *testing.B) {
|
||||
const w, h = 1920, 1080
|
||||
s := &session{serverW: w, serverH: h}
|
||||
// Build a worst-case rect list (every tile dirty, 510 entries).
|
||||
var rects [][4]int
|
||||
for ty := 0; ty < h; ty += tileSize {
|
||||
th := tileSize
|
||||
if ty+th > h {
|
||||
th = h - ty
|
||||
}
|
||||
for tx := 0; tx < w; tx += tileSize {
|
||||
tw := tileSize
|
||||
if tx+tw > w {
|
||||
tw = w - tx
|
||||
}
|
||||
rects = append(rects, [4]int{tx, ty, tw, th})
|
||||
}
|
||||
}
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = s.shouldPromoteToFullFrame(rects)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkEncodeCoalescedVsPerTile compares per-tile encoding vs the
|
||||
// coalesced rect list emitted by diffRects, on a horizontal-band dirty
|
||||
// pattern (e.g. a scrolling status bar) where coalescing pays off.
|
||||
func BenchmarkEncodeCoalescedVsPerTile(b *testing.B) {
|
||||
pf := defaultClientPixelFormat()
|
||||
const w, h = 1920, 1080
|
||||
img := makeBenchImage(w, h, 1)
|
||||
|
||||
// Dirty band: rows 200..264 (one tile-row), full width.
|
||||
var perTile [][4]int
|
||||
for tx := 0; tx < w; tx += tileSize {
|
||||
tw := tileSize
|
||||
if tx+tw > w {
|
||||
tw = w - tx
|
||||
}
|
||||
perTile = append(perTile, [4]int{tx, 200, tw, tileSize})
|
||||
}
|
||||
coalesced := coalesceRects(append([][4]int(nil), perTile...))
|
||||
|
||||
b.Run("per_tile", func(b *testing.B) {
|
||||
t := newTightState()
|
||||
b.ReportAllocs()
|
||||
var bytesOut int
|
||||
for i := 0; i < b.N; i++ {
|
||||
bytesOut = 0
|
||||
for _, r := range perTile {
|
||||
out := encodeTightRect(img, pf, r[0], r[1], r[2], r[3], t)
|
||||
bytesOut += len(out)
|
||||
}
|
||||
}
|
||||
b.ReportMetric(float64(bytesOut), "wire_bytes")
|
||||
b.ReportMetric(float64(len(perTile)), "rects")
|
||||
})
|
||||
|
||||
b.Run("coalesced", func(b *testing.B) {
|
||||
t := newTightState()
|
||||
b.ReportAllocs()
|
||||
var bytesOut int
|
||||
for i := 0; i < b.N; i++ {
|
||||
bytesOut = 0
|
||||
for _, r := range coalesced {
|
||||
out := encodeTightRect(img, pf, r[0], r[1], r[2], r[3], t)
|
||||
bytesOut += len(out)
|
||||
}
|
||||
}
|
||||
b.ReportMetric(float64(bytesOut), "wire_bytes")
|
||||
b.ReportMetric(float64(len(coalesced)), "rects")
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkCoalesceRects(b *testing.B) {
|
||||
const w, h = 1920, 1080
|
||||
// Worst case: every tile dirty.
|
||||
var allTiles [][4]int
|
||||
for ty := 0; ty < h; ty += tileSize {
|
||||
th := tileSize
|
||||
if ty+th > h {
|
||||
th = h - ty
|
||||
}
|
||||
for tx := 0; tx < w; tx += tileSize {
|
||||
tw := tileSize
|
||||
if tx+tw > w {
|
||||
tw = w - tx
|
||||
}
|
||||
allTiles = append(allTiles, [4]int{tx, ty, tw, th})
|
||||
}
|
||||
}
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
in := make([][4]int, len(allTiles))
|
||||
copy(in, allTiles)
|
||||
_ = coalesceRects(in)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkEncodeTight_Photo measures Tight on random/photographic content.
|
||||
// The internal sampledColorCount gate routes large many-colour rects to JPEG
|
||||
// at quality 70.
|
||||
func BenchmarkEncodeTight_Photo(b *testing.B) {
|
||||
pf := defaultClientPixelFormat()
|
||||
for _, r := range []struct {
|
||||
name string
|
||||
w, h int
|
||||
}{
|
||||
{"256x256", 256, 256},
|
||||
{"512x512", 512, 512},
|
||||
{"1080p", 1920, 1080},
|
||||
} {
|
||||
img := makeBenchImage(r.w, r.h, 1)
|
||||
b.Run(r.name+"/tight", func(b *testing.B) {
|
||||
t := newTightState()
|
||||
b.SetBytes(int64(r.w * r.h * 4))
|
||||
b.ReportAllocs()
|
||||
var bytesOut int
|
||||
for i := 0; i < b.N; i++ {
|
||||
out := encodeTightRect(img, pf, 0, 0, r.w, r.h, t)
|
||||
bytesOut = len(out)
|
||||
}
|
||||
b.ReportMetric(float64(bytesOut), "wire_bytes")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDiffRects(b *testing.B) {
|
||||
for _, r := range benchRects {
|
||||
prev, cur := makeBenchImagePartial(r.w, r.h, 100)
|
||||
b.Run(r.name, func(b *testing.B) {
|
||||
b.SetBytes(int64(r.w * r.h * 4))
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = diffRects(prev, cur, r.w, r.h, tileSize)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
274
client/vnc/server/scancodes.go
Normal file
274
client/vnc/server/scancodes.go
Normal file
@@ -0,0 +1,274 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
// QEMU Extended Key Event carries hardware scancodes encoded as PC AT Set 1.
|
||||
// Single-byte codes cover the standard keys; the "extended" prefix 0xE0 is
|
||||
// merged into the high byte (so 0xE048 is the extended-Up arrow). This file
|
||||
// translates those scancodes into the per-platform identifiers each input
|
||||
// backend wants:
|
||||
//
|
||||
// - Linux uinput wants Linux KEY_* codes (defined in
|
||||
// linux/input-event-codes.h). uinput is what we use for virtual Xvfb
|
||||
// sessions on Linux.
|
||||
// - X11 XTest wants XKB keycodes, which on a standard layout equal
|
||||
// Linux KEY_* + 8 (the per-server offset between the Linux event code
|
||||
// and the X server's keycode space).
|
||||
// - Windows SendInput accepts the PC AT scancode directly via
|
||||
// KEYEVENTF_SCANCODE, so no mapping table is needed there; the
|
||||
// extended-key bit is set when the QEMU scancode high byte is 0xE0.
|
||||
// - macOS CGEventCreateKeyboardEvent takes a "virtual keycode" from
|
||||
// Apple's HID set, which is unrelated to PC AT and needs its own
|
||||
// table (see qemuToMacVK in input_darwin.go).
|
||||
//
|
||||
// Linux KEY_* codes. Only the ones we reference, since the full
|
||||
// linux/input-event-codes.h list isn't useful here. Naming mirrors the
|
||||
// existing constants in input_uinput_linux.go (mixed case, no underscores).
|
||||
const (
|
||||
keyEsc = 1
|
||||
key1 = 2
|
||||
key2 = 3
|
||||
key3 = 4
|
||||
key4 = 5
|
||||
key5 = 6
|
||||
key6 = 7
|
||||
key7 = 8
|
||||
key8 = 9
|
||||
key9 = 10
|
||||
key0 = 11
|
||||
keyMinus = 12
|
||||
keyEqual = 13
|
||||
keyBackspace = 14
|
||||
keyTab = 15
|
||||
keyQ = 16
|
||||
keyW = 17
|
||||
keyE = 18
|
||||
keyR = 19
|
||||
keyT = 20
|
||||
keyY = 21
|
||||
keyU = 22
|
||||
keyI = 23
|
||||
keyO = 24
|
||||
keyP = 25
|
||||
keyLeftBracket = 26
|
||||
keyRightBracket = 27
|
||||
keyEnter = 28
|
||||
keyLeftCtrl = 29
|
||||
keyA = 30
|
||||
keyS = 31
|
||||
keyD = 32
|
||||
keyF = 33
|
||||
keyG = 34
|
||||
keyH = 35
|
||||
keyJ = 36
|
||||
keyK = 37
|
||||
keyL = 38
|
||||
keySemicolon = 39
|
||||
keyApostrophe = 40
|
||||
keyGrave = 41
|
||||
keyLeftShift = 42
|
||||
keyBackslash = 43
|
||||
keyZ = 44
|
||||
keyX = 45
|
||||
keyC = 46
|
||||
keyV = 47
|
||||
keyB = 48
|
||||
keyN = 49
|
||||
keyM = 50
|
||||
keyComma = 51
|
||||
keyDot = 52
|
||||
keySlash = 53
|
||||
keyRightShift = 54
|
||||
keyKPAsterisk = 55
|
||||
keyLeftAlt = 56
|
||||
keySpace = 57
|
||||
keyCapsLock = 58
|
||||
keyF1 = 59
|
||||
keyF2 = 60
|
||||
keyF3 = 61
|
||||
keyF4 = 62
|
||||
keyF5 = 63
|
||||
keyF6 = 64
|
||||
keyF7 = 65
|
||||
keyF8 = 66
|
||||
keyF9 = 67
|
||||
keyF10 = 68
|
||||
keyNumLock = 69
|
||||
keyScrollLock = 70
|
||||
keyKP7 = 71
|
||||
keyKP8 = 72
|
||||
keyKP9 = 73
|
||||
keyKPMinus = 74
|
||||
keyKP4 = 75
|
||||
keyKP5 = 76
|
||||
keyKP6 = 77
|
||||
keyKPPlus = 78
|
||||
keyKP1 = 79
|
||||
keyKP2 = 80
|
||||
keyKP3 = 81
|
||||
keyKP0 = 82
|
||||
keyKPDot = 83
|
||||
key102nd = 86
|
||||
keyF11 = 87
|
||||
keyF12 = 88
|
||||
keyKPEnter = 96
|
||||
keyRightCtrl = 97
|
||||
keyKPSlash = 98
|
||||
keySysRq = 99
|
||||
keyRightAlt = 100
|
||||
keyHome = 102
|
||||
keyUp = 103
|
||||
keyPageUp = 104
|
||||
keyLeft = 105
|
||||
keyRight = 106
|
||||
keyEnd = 107
|
||||
keyDown = 108
|
||||
keyPageDown = 109
|
||||
keyInsert = 110
|
||||
keyDelete = 111
|
||||
keyMute = 113
|
||||
keyVolumeDown = 114
|
||||
keyVolumeUp = 115
|
||||
keyLeftMeta = 125
|
||||
keyRightMeta = 126
|
||||
keyCompose = 127
|
||||
)
|
||||
|
||||
// qemuToLinuxKey maps the PC AT Set 1 scancode QEMU sends to a Linux KEY_*
|
||||
// code. The high byte 0xE0 marks "extended" scancodes (arrows, the right-
|
||||
// side modifier keys, keypad enter/divide, browser keys, etc.).
|
||||
//
|
||||
// Keep this table dense so a reviewer sees the whole keyboard at a glance,
|
||||
// and so adding a new key is a single line.
|
||||
var qemuToLinuxKey = map[uint32]int{
|
||||
// Single-byte (non-extended) scancodes.
|
||||
0x01: keyEsc,
|
||||
0x02: key1,
|
||||
0x03: key2,
|
||||
0x04: key3,
|
||||
0x05: key4,
|
||||
0x06: key5,
|
||||
0x07: key6,
|
||||
0x08: key7,
|
||||
0x09: key8,
|
||||
0x0A: key9,
|
||||
0x0B: key0,
|
||||
0x0C: keyMinus,
|
||||
0x0D: keyEqual,
|
||||
0x0E: keyBackspace,
|
||||
0x0F: keyTab,
|
||||
0x10: keyQ,
|
||||
0x11: keyW,
|
||||
0x12: keyE,
|
||||
0x13: keyR,
|
||||
0x14: keyT,
|
||||
0x15: keyY,
|
||||
0x16: keyU,
|
||||
0x17: keyI,
|
||||
0x18: keyO,
|
||||
0x19: keyP,
|
||||
0x1A: keyLeftBracket,
|
||||
0x1B: keyRightBracket,
|
||||
0x1C: keyEnter,
|
||||
0x1D: keyLeftCtrl,
|
||||
0x1E: keyA,
|
||||
0x1F: keyS,
|
||||
0x20: keyD,
|
||||
0x21: keyF,
|
||||
0x22: keyG,
|
||||
0x23: keyH,
|
||||
0x24: keyJ,
|
||||
0x25: keyK,
|
||||
0x26: keyL,
|
||||
0x27: keySemicolon,
|
||||
0x28: keyApostrophe,
|
||||
0x29: keyGrave,
|
||||
0x2A: keyLeftShift,
|
||||
0x2B: keyBackslash,
|
||||
0x2C: keyZ,
|
||||
0x2D: keyX,
|
||||
0x2E: keyC,
|
||||
0x2F: keyV,
|
||||
0x30: keyB,
|
||||
0x31: keyN,
|
||||
0x32: keyM,
|
||||
0x33: keyComma,
|
||||
0x34: keyDot,
|
||||
0x35: keySlash,
|
||||
0x36: keyRightShift,
|
||||
0x37: keyKPAsterisk,
|
||||
0x38: keyLeftAlt,
|
||||
0x39: keySpace,
|
||||
0x3A: keyCapsLock,
|
||||
0x3B: keyF1,
|
||||
0x3C: keyF2,
|
||||
0x3D: keyF3,
|
||||
0x3E: keyF4,
|
||||
0x3F: keyF5,
|
||||
0x40: keyF6,
|
||||
0x41: keyF7,
|
||||
0x42: keyF8,
|
||||
0x43: keyF9,
|
||||
0x44: keyF10,
|
||||
0x45: keyNumLock,
|
||||
0x46: keyScrollLock,
|
||||
0x47: keyKP7,
|
||||
0x48: keyKP8,
|
||||
0x49: keyKP9,
|
||||
0x4A: keyKPMinus,
|
||||
0x4B: keyKP4,
|
||||
0x4C: keyKP5,
|
||||
0x4D: keyKP6,
|
||||
0x4E: keyKPPlus,
|
||||
0x4F: keyKP1,
|
||||
0x50: keyKP2,
|
||||
0x51: keyKP3,
|
||||
0x52: keyKP0,
|
||||
0x53: keyKPDot,
|
||||
0x56: key102nd,
|
||||
0x57: keyF11,
|
||||
0x58: keyF12,
|
||||
|
||||
// Extended (0xE0-prefixed) scancodes.
|
||||
0xE01C: keyKPEnter,
|
||||
0xE01D: keyRightCtrl,
|
||||
0xE020: keyMute,
|
||||
0xE02E: keyVolumeDown,
|
||||
0xE030: keyVolumeUp,
|
||||
0xE035: keyKPSlash,
|
||||
0xE037: keySysRq, // PrintScreen
|
||||
0xE038: keyRightAlt,
|
||||
0xE047: keyHome,
|
||||
0xE048: keyUp,
|
||||
0xE049: keyPageUp,
|
||||
0xE04B: keyLeft,
|
||||
0xE04D: keyRight,
|
||||
0xE04F: keyEnd,
|
||||
0xE050: keyDown,
|
||||
0xE051: keyPageDown,
|
||||
0xE052: keyInsert,
|
||||
0xE053: keyDelete,
|
||||
0xE05B: keyLeftMeta,
|
||||
0xE05C: keyRightMeta,
|
||||
0xE05D: keyCompose,
|
||||
}
|
||||
|
||||
// qemuScancodeToLinuxKey is the lookup the uinput and X11 paths use.
|
||||
// Returns 0 (which Linux treats as KEY_RESERVED) when the scancode has no
|
||||
// mapping, signalling "fall back to the keysym path".
|
||||
func qemuScancodeToLinuxKey(scancode uint32) int {
|
||||
return qemuToLinuxKey[scancode]
|
||||
}
|
||||
|
||||
// qemuScancodeIsExtended reports whether a QEMU scancode is in the
|
||||
// 0xE0-prefixed extended range. Used by Windows SendInput to set the
|
||||
// KEYEVENTF_EXTENDEDKEY flag.
|
||||
func qemuScancodeIsExtended(scancode uint32) bool {
|
||||
return scancode&0xFF00 == 0xE000
|
||||
}
|
||||
|
||||
// qemuScancodeLowByte returns the byte SendInput's wScan field actually
|
||||
// stores: the low byte of the scancode regardless of any extended prefix.
|
||||
func qemuScancodeLowByte(scancode uint32) uint16 {
|
||||
return uint16(scancode & 0xFF)
|
||||
}
|
||||
238
client/vnc/server/scancodes_darwin.go
Normal file
238
client/vnc/server/scancodes_darwin.go
Normal file
@@ -0,0 +1,238 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package server
|
||||
|
||||
// Apple keyboard virtual-key codes used with CGEventCreateKeyboardEvent.
|
||||
// These are the kVK_ANSI_* / kVK_* values from Apple's
|
||||
// HIToolbox/Events.h; reproduced here so we don't need to drag in the
|
||||
// HIToolbox framework just for the constants.
|
||||
const (
|
||||
macKeyA uint16 = 0x00
|
||||
macKeyS uint16 = 0x01
|
||||
macKeyD uint16 = 0x02
|
||||
macKeyF uint16 = 0x03
|
||||
macKeyH uint16 = 0x04
|
||||
macKeyG uint16 = 0x05
|
||||
macKeyZ uint16 = 0x06
|
||||
macKeyX uint16 = 0x07
|
||||
macKeyC uint16 = 0x08
|
||||
macKeyV uint16 = 0x09
|
||||
macKeyNonUSBackslash uint16 = 0x0A // ISO_Section / 102nd
|
||||
macKeyB uint16 = 0x0B
|
||||
macKeyQ uint16 = 0x0C
|
||||
macKeyW uint16 = 0x0D
|
||||
macKeyE uint16 = 0x0E
|
||||
macKeyR uint16 = 0x0F
|
||||
macKeyY uint16 = 0x10
|
||||
macKeyT uint16 = 0x11
|
||||
macKey1 uint16 = 0x12
|
||||
macKey2 uint16 = 0x13
|
||||
macKey3 uint16 = 0x14
|
||||
macKey4 uint16 = 0x15
|
||||
macKey6 uint16 = 0x16
|
||||
macKey5 uint16 = 0x17
|
||||
macKeyEqual uint16 = 0x18
|
||||
macKey9 uint16 = 0x19
|
||||
macKey7 uint16 = 0x1A
|
||||
macKeyMinus uint16 = 0x1B
|
||||
macKey8 uint16 = 0x1C
|
||||
macKey0 uint16 = 0x1D
|
||||
macKeyRightBracket uint16 = 0x1E
|
||||
macKeyO uint16 = 0x1F
|
||||
macKeyU uint16 = 0x20
|
||||
macKeyLeftBracket uint16 = 0x21
|
||||
macKeyI uint16 = 0x22
|
||||
macKeyP uint16 = 0x23
|
||||
macKeyReturn uint16 = 0x24
|
||||
macKeyL uint16 = 0x25
|
||||
macKeyJ uint16 = 0x26
|
||||
macKeyApostrophe uint16 = 0x27
|
||||
macKeyK uint16 = 0x28
|
||||
macKeySemicolon uint16 = 0x29
|
||||
macKeyBackslash uint16 = 0x2A
|
||||
macKeyComma uint16 = 0x2B
|
||||
macKeySlash uint16 = 0x2C
|
||||
macKeyN uint16 = 0x2D
|
||||
macKeyM uint16 = 0x2E
|
||||
macKeyPeriod uint16 = 0x2F
|
||||
macKeyTab uint16 = 0x30
|
||||
macKeySpace uint16 = 0x31
|
||||
macKeyGrave uint16 = 0x32
|
||||
macKeyDelete uint16 = 0x33 // Backspace
|
||||
macKeyEscape uint16 = 0x35
|
||||
macKeyCommand uint16 = 0x37
|
||||
macKeyShift uint16 = 0x38
|
||||
macKeyCapsLock uint16 = 0x39
|
||||
macKeyOption uint16 = 0x3A // Alt
|
||||
macKeyControl uint16 = 0x3B
|
||||
macKeyRightShift uint16 = 0x3C
|
||||
macKeyRightOption uint16 = 0x3D
|
||||
macKeyRightControl uint16 = 0x3E
|
||||
macKeyFunction uint16 = 0x3F
|
||||
macKeyF17 uint16 = 0x40
|
||||
macKeyKPDecimal uint16 = 0x41
|
||||
macKeyKPMultiply uint16 = 0x43
|
||||
macKeyKPPlus uint16 = 0x45
|
||||
macKeyKPClear uint16 = 0x47 // numlock
|
||||
macKeyVolumeUp uint16 = 0x48
|
||||
macKeyVolumeDown uint16 = 0x49
|
||||
macKeyMute uint16 = 0x4A
|
||||
macKeyKPDivide uint16 = 0x4B
|
||||
macKeyKPEnter uint16 = 0x4C
|
||||
macKeyKPMinus uint16 = 0x4E
|
||||
macKeyF18 uint16 = 0x4F
|
||||
macKeyF19 uint16 = 0x50
|
||||
macKeyKPEqual uint16 = 0x51
|
||||
macKeyKP0 uint16 = 0x52
|
||||
macKeyKP1 uint16 = 0x53
|
||||
macKeyKP2 uint16 = 0x54
|
||||
macKeyKP3 uint16 = 0x55
|
||||
macKeyKP4 uint16 = 0x56
|
||||
macKeyKP5 uint16 = 0x57
|
||||
macKeyKP6 uint16 = 0x58
|
||||
macKeyKP7 uint16 = 0x59
|
||||
macKeyF20 uint16 = 0x5A
|
||||
macKeyKP8 uint16 = 0x5B
|
||||
macKeyKP9 uint16 = 0x5C
|
||||
macKeyF5 uint16 = 0x60
|
||||
macKeyF6 uint16 = 0x61
|
||||
macKeyF7 uint16 = 0x62
|
||||
macKeyF3 uint16 = 0x63
|
||||
macKeyF8 uint16 = 0x64
|
||||
macKeyF9 uint16 = 0x65
|
||||
macKeyF11 uint16 = 0x67
|
||||
macKeyF13 uint16 = 0x69 // PrintScreen on most layouts
|
||||
macKeyF16 uint16 = 0x6A
|
||||
macKeyF14 uint16 = 0x6B
|
||||
macKeyF10 uint16 = 0x6D
|
||||
macKeyF12 uint16 = 0x6F
|
||||
macKeyF15 uint16 = 0x71
|
||||
macKeyHelp uint16 = 0x72 // Insert on PC keyboards
|
||||
macKeyHome uint16 = 0x73
|
||||
macKeyPageUp uint16 = 0x74
|
||||
macKeyForwardDelete uint16 = 0x75
|
||||
macKeyF4 uint16 = 0x76
|
||||
macKeyEnd uint16 = 0x77
|
||||
macKeyF2 uint16 = 0x78
|
||||
macKeyPageDown uint16 = 0x79
|
||||
macKeyF1 uint16 = 0x7A
|
||||
macKeyLeft uint16 = 0x7B
|
||||
macKeyRight uint16 = 0x7C
|
||||
macKeyDown uint16 = 0x7D
|
||||
macKeyUp uint16 = 0x7E
|
||||
)
|
||||
|
||||
// qemuToMacVK maps PC AT Set 1 scancodes (as QEMU emits them, with the
|
||||
// 0xE0 prefix merged into the high byte) onto Apple virtual-key codes.
|
||||
// Layout-independent: the scancode names the physical key, the user's
|
||||
// active keyboard layout on the Mac decides what the key produces.
|
||||
var qemuToMacVK = map[uint32]uint16{
|
||||
// Single-byte (non-extended).
|
||||
0x01: macKeyEscape,
|
||||
0x02: macKey1,
|
||||
0x03: macKey2,
|
||||
0x04: macKey3,
|
||||
0x05: macKey4,
|
||||
0x06: macKey5,
|
||||
0x07: macKey6,
|
||||
0x08: macKey7,
|
||||
0x09: macKey8,
|
||||
0x0A: macKey9,
|
||||
0x0B: macKey0,
|
||||
0x0C: macKeyMinus,
|
||||
0x0D: macKeyEqual,
|
||||
0x0E: macKeyDelete, // PC Backspace -> mac "Delete"
|
||||
0x0F: macKeyTab,
|
||||
0x10: macKeyQ,
|
||||
0x11: macKeyW,
|
||||
0x12: macKeyE,
|
||||
0x13: macKeyR,
|
||||
0x14: macKeyT,
|
||||
0x15: macKeyY,
|
||||
0x16: macKeyU,
|
||||
0x17: macKeyI,
|
||||
0x18: macKeyO,
|
||||
0x19: macKeyP,
|
||||
0x1A: macKeyLeftBracket,
|
||||
0x1B: macKeyRightBracket,
|
||||
0x1C: macKeyReturn,
|
||||
0x1D: macKeyControl,
|
||||
0x1E: macKeyA,
|
||||
0x1F: macKeyS,
|
||||
0x20: macKeyD,
|
||||
0x21: macKeyF,
|
||||
0x22: macKeyG,
|
||||
0x23: macKeyH,
|
||||
0x24: macKeyJ,
|
||||
0x25: macKeyK,
|
||||
0x26: macKeyL,
|
||||
0x27: macKeySemicolon,
|
||||
0x28: macKeyApostrophe,
|
||||
0x29: macKeyGrave,
|
||||
0x2A: macKeyShift,
|
||||
0x2B: macKeyBackslash,
|
||||
0x2C: macKeyZ,
|
||||
0x2D: macKeyX,
|
||||
0x2E: macKeyC,
|
||||
0x2F: macKeyV,
|
||||
0x30: macKeyB,
|
||||
0x31: macKeyN,
|
||||
0x32: macKeyM,
|
||||
0x33: macKeyComma,
|
||||
0x34: macKeyPeriod,
|
||||
0x35: macKeySlash,
|
||||
0x36: macKeyRightShift,
|
||||
0x37: macKeyKPMultiply,
|
||||
0x38: macKeyOption, // Left Alt -> Option
|
||||
0x39: macKeySpace,
|
||||
0x3A: macKeyCapsLock,
|
||||
0x3B: macKeyF1,
|
||||
0x3C: macKeyF2,
|
||||
0x3D: macKeyF3,
|
||||
0x3E: macKeyF4,
|
||||
0x3F: macKeyF5,
|
||||
0x40: macKeyF6,
|
||||
0x41: macKeyF7,
|
||||
0x42: macKeyF8,
|
||||
0x43: macKeyF9,
|
||||
0x44: macKeyF10,
|
||||
0x45: macKeyKPClear, // PC NumLock -> mac Clear
|
||||
0x47: macKeyKP7,
|
||||
0x48: macKeyKP8,
|
||||
0x49: macKeyKP9,
|
||||
0x4A: macKeyKPMinus,
|
||||
0x4B: macKeyKP4,
|
||||
0x4C: macKeyKP5,
|
||||
0x4D: macKeyKP6,
|
||||
0x4E: macKeyKPPlus,
|
||||
0x4F: macKeyKP1,
|
||||
0x50: macKeyKP2,
|
||||
0x51: macKeyKP3,
|
||||
0x52: macKeyKP0,
|
||||
0x53: macKeyKPDecimal,
|
||||
0x56: macKeyNonUSBackslash,
|
||||
0x57: macKeyF11,
|
||||
0x58: macKeyF12,
|
||||
|
||||
// Extended (0xE0 prefix).
|
||||
0xE01C: macKeyKPEnter,
|
||||
0xE01D: macKeyRightControl,
|
||||
0xE020: macKeyMute,
|
||||
0xE02E: macKeyVolumeDown,
|
||||
0xE030: macKeyVolumeUp,
|
||||
0xE035: macKeyKPDivide,
|
||||
0xE037: macKeyF13, // PrintScreen
|
||||
0xE038: macKeyRightOption,
|
||||
0xE047: macKeyHome,
|
||||
0xE048: macKeyUp,
|
||||
0xE049: macKeyPageUp,
|
||||
0xE04B: macKeyLeft,
|
||||
0xE04D: macKeyRight,
|
||||
0xE04F: macKeyEnd,
|
||||
0xE050: macKeyDown,
|
||||
0xE051: macKeyPageDown,
|
||||
0xE052: macKeyHelp, // PC Insert -> mac Help
|
||||
0xE053: macKeyForwardDelete,
|
||||
0xE05B: macKeyCommand, // Left Windows -> Command
|
||||
0xE05C: macKeyCommand, // Right Windows -> Command (no separate code)
|
||||
}
|
||||
100
client/vnc/server/scancodes_test.go
Normal file
100
client/vnc/server/scancodes_test.go
Normal file
@@ -0,0 +1,100 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestQemuScancodeToLinuxKey_KnownLetters(t *testing.T) {
|
||||
// Spot-check a few familiar letter keys against the Linux KEY_*
|
||||
// values they're supposed to land on.
|
||||
tests := []struct {
|
||||
name string
|
||||
scancode uint32
|
||||
want int
|
||||
}{
|
||||
{"A", 0x1E, keyA},
|
||||
{"S", 0x1F, keyS},
|
||||
{"D", 0x20, keyD},
|
||||
{"Q", 0x10, keyQ},
|
||||
{"Z", 0x2C, keyZ},
|
||||
{"1", 0x02, key1},
|
||||
{"Esc", 0x01, keyEsc},
|
||||
{"Tab", 0x0F, keyTab},
|
||||
{"Space", 0x39, keySpace},
|
||||
{"LeftShift", 0x2A, keyLeftShift},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
got := qemuScancodeToLinuxKey(tc.scancode)
|
||||
if got != tc.want {
|
||||
t.Errorf("%s: scancode 0x%X => %d, want %d", tc.name, tc.scancode, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestQemuScancodeToLinuxKey_Extended(t *testing.T) {
|
||||
// Extended (0xE0-prefixed) scancodes for arrow + navigation cluster.
|
||||
tests := []struct {
|
||||
name string
|
||||
scancode uint32
|
||||
want int
|
||||
}{
|
||||
{"Up", 0xE048, keyUp},
|
||||
{"Down", 0xE050, keyDown},
|
||||
{"Left", 0xE04B, keyLeft},
|
||||
{"Right", 0xE04D, keyRight},
|
||||
{"Home", 0xE047, keyHome},
|
||||
{"End", 0xE04F, keyEnd},
|
||||
{"PageUp", 0xE049, keyPageUp},
|
||||
{"PageDown", 0xE051, keyPageDown},
|
||||
{"Insert", 0xE052, keyInsert},
|
||||
{"Delete", 0xE053, keyDelete},
|
||||
{"RightCtrl", 0xE01D, keyRightCtrl},
|
||||
{"RightAlt", 0xE038, keyRightAlt},
|
||||
{"KPEnter", 0xE01C, keyKPEnter},
|
||||
{"KPSlash", 0xE035, keyKPSlash},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
got := qemuScancodeToLinuxKey(tc.scancode)
|
||||
if got != tc.want {
|
||||
t.Errorf("%s: scancode 0x%X => %d, want %d", tc.name, tc.scancode, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestQemuScancodeToLinuxKey_Miss(t *testing.T) {
|
||||
// 0xE0FF is in the extended range but not a real key. Must return 0
|
||||
// so the caller can fall back to the keysym path.
|
||||
if got := qemuScancodeToLinuxKey(0xE0FF); got != 0 {
|
||||
t.Errorf("unknown scancode should miss: got %d, want 0", got)
|
||||
}
|
||||
if got := qemuScancodeToLinuxKey(0xFF); got != 0 {
|
||||
t.Errorf("unknown non-extended scancode should miss: got %d, want 0", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQemuScancodeIsExtended(t *testing.T) {
|
||||
cases := []struct {
|
||||
scancode uint32
|
||||
want bool
|
||||
}{
|
||||
{0x1E, false},
|
||||
{0xE048, true},
|
||||
{0xE000, true},
|
||||
{0xFF, false},
|
||||
{0xE0FF, true},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
if got := qemuScancodeIsExtended(tc.scancode); got != tc.want {
|
||||
t.Errorf("isExtended(0x%X) = %v, want %v", tc.scancode, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestQemuScancodeLowByte(t *testing.T) {
|
||||
if got := qemuScancodeLowByte(0xE048); got != 0x48 {
|
||||
t.Errorf("lowByte(0xE048) = 0x%X, want 0x48", got)
|
||||
}
|
||||
if got := qemuScancodeLowByte(0x1E); got != 0x1E {
|
||||
t.Errorf("lowByte(0x1E) = 0x%X, want 0x1E", got)
|
||||
}
|
||||
}
|
||||
1054
client/vnc/server/server.go
Normal file
1054
client/vnc/server/server.go
Normal file
File diff suppressed because it is too large
Load Diff
119
client/vnc/server/server_darwin.go
Normal file
119
client/vnc/server/server_darwin.go
Normal file
@@ -0,0 +1,119 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func (s *Server) platformInit() {
|
||||
// no-op on macOS
|
||||
}
|
||||
|
||||
func (s *Server) platformShutdown() {
|
||||
// no-op on macOS
|
||||
}
|
||||
|
||||
func (s *Server) platformSessionManager() virtualSessionManager {
|
||||
return nil
|
||||
}
|
||||
|
||||
// serviceAcceptLoop runs in a LaunchDaemon and proxies each VNC
|
||||
// connection to a per-user agent. The agent is spawned lazily on the
|
||||
// first connection (and respawned after a console-user change) via
|
||||
// launchctl asuser, which is the only mechanism that lands a child
|
||||
// inside the user's Aqua session, where WindowServer and TCC grants
|
||||
// for screen capture work.
|
||||
func (s *Server) serviceAcceptLoop() {
|
||||
mgr := newDarwinAgentManager(s.ctx)
|
||||
defer mgr.stop()
|
||||
|
||||
log.Infof("service mode, proxying connections to per-user agent on 127.0.0.1:%d", agentPort)
|
||||
|
||||
for {
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
s.log.Debugf("accept VNC connection: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if !s.tryAcquireConnSlot() {
|
||||
s.log.Warnf("rejecting VNC connection from %s: %d concurrent connections in flight", conn.RemoteAddr(), maxConcurrentVNCConns)
|
||||
_ = conn.Close()
|
||||
continue
|
||||
}
|
||||
enableTCPKeepAlive(conn, s.log)
|
||||
conn = newMetricsConn(conn, s.sessionRecorder)
|
||||
s.trackConn(conn)
|
||||
go func(c net.Conn) {
|
||||
defer s.releaseConnSlot()
|
||||
defer s.untrackConn(c)
|
||||
s.handleServiceConnectionDarwin(c, mgr)
|
||||
}(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleServiceConnectionDarwin(conn net.Conn, mgr *darwinAgentManager) {
|
||||
connLog := s.log.WithField("remote", conn.RemoteAddr().String())
|
||||
|
||||
if !s.isAllowedSource(conn.RemoteAddr()) {
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
var headerBuf bytes.Buffer
|
||||
tee := io.TeeReader(conn, &headerBuf)
|
||||
teeConn := &darwinPrefixConn{Reader: tee, Conn: conn}
|
||||
|
||||
header, err := s.readConnectionHeader(teeConn)
|
||||
if err != nil {
|
||||
connLog.Debugf("read connection header: %v", err)
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
if !s.disableAuth {
|
||||
if _, err := s.authenticateSession(header); err != nil {
|
||||
rejectConnection(conn, codeMessage(RejectCodeAuthForbidden, err.Error()))
|
||||
connLog.Warnf("auth rejected: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
s.registerConnAuth(conn, header)
|
||||
|
||||
token, err := mgr.ensure(s.ctx)
|
||||
if err != nil {
|
||||
code := RejectCodeCapturerError
|
||||
if errors.Is(err, errNoConsoleUser) {
|
||||
code = RejectCodeNoConsoleUser
|
||||
}
|
||||
rejectConnection(conn, codeMessage(code, err.Error()))
|
||||
connLog.Warnf("spawn per-user agent: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
replayConn := &darwinPrefixConn{
|
||||
Reader: io.MultiReader(&headerBuf, conn),
|
||||
Conn: conn,
|
||||
}
|
||||
proxyToAgent(s.ctx, replayConn, agentPort, token)
|
||||
}
|
||||
|
||||
// darwinPrefixConn replays the already-consumed connection-header bytes
|
||||
// in front of the proxy stream, mirroring the Windows prefixConn shape.
|
||||
type darwinPrefixConn struct {
|
||||
io.Reader
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (p *darwinPrefixConn) Read(b []byte) (int, error) { return p.Reader.Read(b) }
|
||||
318
client/vnc/server/server_test.go
Normal file
318
client/vnc/server/server_test.go
Normal file
@@ -0,0 +1,318 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"image"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// testCapturer returns a 100x100 image for test sessions.
|
||||
type testCapturer struct{}
|
||||
|
||||
func (t *testCapturer) Width() int { return 100 }
|
||||
func (t *testCapturer) Height() int { return 100 }
|
||||
func (t *testCapturer) Capture() (*image.RGBA, error) {
|
||||
return image.NewRGBA(image.Rect(0, 0, 100, 100)), nil
|
||||
}
|
||||
|
||||
func startTestServer(t *testing.T, disableAuth bool) (net.Addr, *Server) {
|
||||
t.Helper()
|
||||
|
||||
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
|
||||
srv.SetDisableAuth(disableAuth)
|
||||
|
||||
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
||||
network := netip.MustParsePrefix("127.0.0.0/8")
|
||||
require.NoError(t, srv.Start(t.Context(), addr, network))
|
||||
// Override local address so source validation doesn't reject 127.0.0.1 as "own IP".
|
||||
srv.localAddr = netip.MustParseAddr("10.99.99.1")
|
||||
t.Cleanup(func() { _ = srv.Stop() })
|
||||
|
||||
return srv.listener.Addr(), srv
|
||||
}
|
||||
|
||||
func TestAuthEnabled_NoSessionAuth_RejectsConnection(t *testing.T) {
|
||||
addr, _ := startTestServer(t, false)
|
||||
|
||||
conn, err := net.Dial("tcp", addr.String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
// Header with no Noise handshake. Auth-required servers must reject
|
||||
// because no client static was authenticated.
|
||||
header := make([]byte, 11) // mode + usernameLen + sessionID + w + h
|
||||
header[0] = ModeAttach
|
||||
_, err = conn.Write(header)
|
||||
require.NoError(t, err)
|
||||
|
||||
var version [12]byte
|
||||
_, err = io.ReadFull(conn, version[:])
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "RFB 003.008\n", string(version[:]))
|
||||
|
||||
_, err = conn.Write(version[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
var numTypes [1]byte
|
||||
_, err = io.ReadFull(conn, numTypes[:])
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, byte(0), numTypes[0], "should have 0 security types (failure)")
|
||||
|
||||
var reasonLen [4]byte
|
||||
_, err = io.ReadFull(conn, reasonLen[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
reason := make([]byte, binary.BigEndian.Uint32(reasonLen[:]))
|
||||
_, err = io.ReadFull(conn, reason)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, string(reason), "identity proof missing", "rejection reason should mention missing identity proof")
|
||||
}
|
||||
|
||||
func TestAuthDisabled_AllowsConnection(t *testing.T) {
|
||||
addr, _ := startTestServer(t, true)
|
||||
|
||||
conn, err := net.Dial("tcp", addr.String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
header := make([]byte, 11) // mode + usernameLen + sessionID + w + h
|
||||
header[0] = ModeAttach
|
||||
_, err = conn.Write(header)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Server should send RFB version.
|
||||
var version [12]byte
|
||||
_, err = io.ReadFull(conn, version[:])
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "RFB 003.008\n", string(version[:]))
|
||||
|
||||
// Write client version.
|
||||
_, err = conn.Write(version[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should get security types (not 0 = failure).
|
||||
var numTypes [1]byte
|
||||
_, err = io.ReadFull(conn, numTypes[:])
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, byte(0), numTypes[0], "should have at least one security type (auth disabled)")
|
||||
}
|
||||
|
||||
// TestAuth_NoUnauthBytesPastHeader proves the server does not send any RFB
|
||||
// content to a connection that fails source validation. Specifically, the
|
||||
// server must close immediately and the client must see EOF before any RFB
|
||||
// version greeting is written.
|
||||
func TestAuth_NoUnauthBytesPastHeader(t *testing.T) {
|
||||
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
|
||||
srv.SetDisableAuth(true)
|
||||
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
||||
// Tight overlay that excludes 127.0.0.0/8 and a non-loopback local IP, so
|
||||
// the loopback short-circuit in isAllowedSource doesn't apply.
|
||||
require.NoError(t, srv.Start(t.Context(), addr, netip.MustParsePrefix("10.99.0.0/16")))
|
||||
srv.localAddr = netip.MustParseAddr("10.99.99.1")
|
||||
t.Cleanup(func() { _ = srv.Stop() })
|
||||
|
||||
conn, err := net.Dial("tcp", srv.listener.Addr().String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
require.NoError(t, conn.SetDeadline(time.Now().Add(5*time.Second)))
|
||||
|
||||
// Reading even one byte must EOF: the source IP (127.0.0.1) is outside
|
||||
// the configured overlay, so handleConnection closes before writing.
|
||||
var b [1]byte
|
||||
_, err = io.ReadFull(conn, b[:])
|
||||
require.Error(t, err, "non-overlay client must see EOF, not an RFB greeting")
|
||||
}
|
||||
|
||||
func TestIsAllowedSource(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
localAddr netip.Addr
|
||||
network netip.Prefix
|
||||
remote net.Addr
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "non-tcp address rejected",
|
||||
localAddr: netip.MustParseAddr("10.99.99.1"),
|
||||
network: netip.MustParsePrefix("10.99.0.0/16"),
|
||||
remote: &net.UDPAddr{IP: net.ParseIP("10.99.99.2"), Port: 1234},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "own IP rejected",
|
||||
localAddr: netip.MustParseAddr("10.99.99.1"),
|
||||
network: netip.MustParsePrefix("10.99.0.0/16"),
|
||||
remote: &net.TCPAddr{IP: net.ParseIP("10.99.99.1"), Port: 5900},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "non-overlay IP rejected",
|
||||
localAddr: netip.MustParseAddr("10.99.99.1"),
|
||||
network: netip.MustParsePrefix("10.99.0.0/16"),
|
||||
remote: &net.TCPAddr{IP: net.ParseIP("192.168.1.1"), Port: 5900},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "overlay IP allowed",
|
||||
localAddr: netip.MustParseAddr("10.99.99.1"),
|
||||
network: netip.MustParsePrefix("10.99.0.0/16"),
|
||||
remote: &net.TCPAddr{IP: net.ParseIP("10.99.99.2"), Port: 5900},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "v4-mapped v6 in overlay allowed (unmapped)",
|
||||
localAddr: netip.MustParseAddr("10.99.99.1"),
|
||||
network: netip.MustParsePrefix("10.99.0.0/16"),
|
||||
remote: &net.TCPAddr{IP: net.ParseIP("::ffff:10.99.99.2"), Port: 5900},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "loopback allowed only when local is loopback",
|
||||
localAddr: netip.MustParseAddr("127.0.0.1"),
|
||||
network: netip.MustParsePrefix("127.0.0.0/8"),
|
||||
remote: &net.TCPAddr{IP: net.ParseIP("127.0.0.5"), Port: 5900},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "invalid network rejected (fail-closed)",
|
||||
localAddr: netip.MustParseAddr("10.99.99.1"),
|
||||
network: netip.Prefix{},
|
||||
remote: &net.TCPAddr{IP: net.ParseIP("10.99.99.2"), Port: 5900},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
|
||||
srv.localAddr = tc.localAddr
|
||||
srv.network = tc.network
|
||||
assert.Equal(t, tc.want, srv.isAllowedSource(tc.remote))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStart_InvalidNetworkRejected(t *testing.T) {
|
||||
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
|
||||
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
||||
err := srv.Start(t.Context(), addr, netip.Prefix{})
|
||||
require.Error(t, err, "Start must refuse an invalid overlay prefix")
|
||||
assert.Contains(t, err.Error(), "invalid overlay network prefix")
|
||||
}
|
||||
|
||||
func TestAgentToken_MismatchClosesConnection(t *testing.T) {
|
||||
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
|
||||
srv.SetDisableAuth(true)
|
||||
srv.SetAgentToken("deadbeefcafebabe")
|
||||
|
||||
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
||||
network := netip.MustParsePrefix("127.0.0.0/8")
|
||||
require.NoError(t, srv.Start(t.Context(), addr, network))
|
||||
srv.localAddr = netip.MustParseAddr("10.99.99.1")
|
||||
t.Cleanup(func() { _ = srv.Stop() })
|
||||
|
||||
conn, err := net.Dial("tcp", srv.listener.Addr().String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
require.NoError(t, conn.SetDeadline(time.Now().Add(10*time.Second)))
|
||||
|
||||
// Send a wrong token of the right length (8 bytes hex-decoded).
|
||||
if _, err := conn.Write([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}); err != nil {
|
||||
// Server may already have closed; either way the read below must EOF.
|
||||
_ = err
|
||||
}
|
||||
|
||||
// Server must close without sending the RFB greeting.
|
||||
var version [12]byte
|
||||
_, err = io.ReadFull(conn, version[:])
|
||||
require.Error(t, err, "server must close the connection on bad agent token")
|
||||
}
|
||||
|
||||
func TestAgentToken_MatchAllowsHandshake(t *testing.T) {
|
||||
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
|
||||
srv.SetDisableAuth(true)
|
||||
const tokenHex = "deadbeefcafebabe"
|
||||
srv.SetAgentToken(tokenHex)
|
||||
token, err := hex.DecodeString(tokenHex)
|
||||
require.NoError(t, err)
|
||||
|
||||
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
||||
network := netip.MustParsePrefix("127.0.0.0/8")
|
||||
require.NoError(t, srv.Start(t.Context(), addr, network))
|
||||
srv.localAddr = netip.MustParseAddr("10.99.99.1")
|
||||
t.Cleanup(func() { _ = srv.Stop() })
|
||||
|
||||
conn, err := net.Dial("tcp", srv.listener.Addr().String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
require.NoError(t, conn.SetDeadline(time.Now().Add(10*time.Second)))
|
||||
|
||||
_, err = conn.Write(token)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Send session header so handleConnection can proceed past readConnectionHeader.
|
||||
header := make([]byte, 11) // ModeAttach + usernameLen=0 + sessionID=0 + width=0 + height=0
|
||||
header[0] = ModeAttach
|
||||
_, err = conn.Write(header)
|
||||
require.NoError(t, err)
|
||||
|
||||
// With a matching token the server proceeds to the RFB greeting.
|
||||
var version [12]byte
|
||||
_, err = io.ReadFull(conn, version[:])
|
||||
require.NoError(t, err, "server must keep the connection open after a valid agent token")
|
||||
assert.Equal(t, "RFB 003.008\n", string(version[:]))
|
||||
}
|
||||
|
||||
func TestSessionMode_RejectedWhenNoVMGR(t *testing.T) {
|
||||
// Default platformSessionManager() on non-Linux returns nil, so ModeSession
|
||||
// must be rejected with the UNSUPPORTED reason rather than crashing.
|
||||
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
|
||||
srv.SetDisableAuth(true)
|
||||
|
||||
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
||||
network := netip.MustParsePrefix("127.0.0.0/8")
|
||||
require.NoError(t, srv.Start(t.Context(), addr, network))
|
||||
srv.localAddr = netip.MustParseAddr("10.99.99.1")
|
||||
// Force vmgr to nil regardless of platform so the test is deterministic.
|
||||
srv.vmgr = nil
|
||||
t.Cleanup(func() { _ = srv.Stop() })
|
||||
|
||||
conn, err := net.Dial("tcp", srv.listener.Addr().String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
require.NoError(t, conn.SetDeadline(time.Now().Add(10*time.Second)))
|
||||
|
||||
// ModeSession with no username, so we exit on the vmgr==nil branch
|
||||
// before username validation runs.
|
||||
header := []byte{ModeSession, 0, 0, 0, 0}
|
||||
_, err = conn.Write(header)
|
||||
require.NoError(t, err)
|
||||
|
||||
var version [12]byte
|
||||
_, err = io.ReadFull(conn, version[:])
|
||||
require.NoError(t, err)
|
||||
_, err = conn.Write(version[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
var numTypes [1]byte
|
||||
_, err = io.ReadFull(conn, numTypes[:])
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, byte(0), numTypes[0])
|
||||
|
||||
var reasonLen [4]byte
|
||||
_, err = io.ReadFull(conn, reasonLen[:])
|
||||
require.NoError(t, err)
|
||||
reason := make([]byte, binary.BigEndian.Uint32(reasonLen[:]))
|
||||
_, err = io.ReadFull(conn, reason)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, string(reason), RejectCodeUnsupportedOS)
|
||||
}
|
||||
322
client/vnc/server/server_windows.go
Normal file
322
client/vnc/server/server_windows.go
Normal file
@@ -0,0 +1,322 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
)
|
||||
|
||||
var (
|
||||
sasDLL = windows.NewLazySystemDLL("sas.dll")
|
||||
procSendSAS = sasDLL.NewProc("SendSAS")
|
||||
|
||||
procConvertStringSecurityDescriptorToSecurityDescriptor = advapi32.NewProc("ConvertStringSecurityDescriptorToSecurityDescriptorW")
|
||||
)
|
||||
|
||||
// sasSecurityAttributes builds a SECURITY_ATTRIBUTES that grants
|
||||
// EVENT_MODIFY_STATE only to the SYSTEM account, preventing unprivileged
|
||||
// local processes from triggering the Secure Attention Sequence.
|
||||
func sasSecurityAttributes() (*windows.SecurityAttributes, error) {
|
||||
// SDDL: grant full access to SYSTEM (creates/waits) and EVENT_MODIFY_STATE
|
||||
// to the interactive user (IU) so the VNC agent in the console session can
|
||||
// signal it. Other local users and network users are denied.
|
||||
sddl, err := windows.UTF16PtrFromString("D:(A;;GA;;;SY)(A;;0x0002;;;IU)")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var sd uintptr
|
||||
r, _, lerr := procConvertStringSecurityDescriptorToSecurityDescriptor.Call(
|
||||
uintptr(unsafe.Pointer(sddl)),
|
||||
1, // SDDL_REVISION_1
|
||||
uintptr(unsafe.Pointer(&sd)),
|
||||
0,
|
||||
)
|
||||
if r == 0 {
|
||||
return nil, lerr
|
||||
}
|
||||
return &windows.SecurityAttributes{
|
||||
Length: uint32(unsafe.Sizeof(windows.SecurityAttributes{})),
|
||||
SecurityDescriptor: (*windows.SECURITY_DESCRIPTOR)(unsafe.Pointer(sd)),
|
||||
InheritHandle: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// sasOriginalState tracks the SoftwareSASGeneration value present before we
|
||||
// changed it, so disableSoftwareSAS can restore the machine to its prior
|
||||
// state on shutdown instead of leaving the policy enabled.
|
||||
type sasOriginalState struct {
|
||||
had bool // true if the value existed before we wrote
|
||||
value uint32 // its prior DWORD value, if had == true
|
||||
}
|
||||
|
||||
var savedSASState sasOriginalState
|
||||
|
||||
// enableSoftwareSAS sets the SoftwareSASGeneration registry key to allow
|
||||
// services to trigger the Secure Attention Sequence via SendSAS. Without this,
|
||||
// SendSAS silently does nothing on most Windows editions. The original value
|
||||
// is snapshotted so disableSoftwareSAS can put the system back as it was.
|
||||
func enableSoftwareSAS() {
|
||||
key, _, err := registry.CreateKey(
|
||||
registry.LOCAL_MACHINE,
|
||||
`SOFTWARE\Microsoft\Windows\CurrentVersion\Policies\System`,
|
||||
registry.SET_VALUE|registry.QUERY_VALUE,
|
||||
)
|
||||
if err != nil {
|
||||
log.Warnf("open SoftwareSASGeneration registry key: %v", err)
|
||||
return
|
||||
}
|
||||
defer key.Close()
|
||||
|
||||
if prev, _, err := key.GetIntegerValue("SoftwareSASGeneration"); err == nil {
|
||||
savedSASState = sasOriginalState{had: true, value: uint32(prev)}
|
||||
} else {
|
||||
savedSASState = sasOriginalState{had: false}
|
||||
}
|
||||
|
||||
if err := key.SetDWordValue("SoftwareSASGeneration", 1); err != nil {
|
||||
log.Warnf("set SoftwareSASGeneration: %v", err)
|
||||
return
|
||||
}
|
||||
log.Debug("SoftwareSASGeneration registry key set to 1 (services allowed)")
|
||||
}
|
||||
|
||||
// disableSoftwareSAS restores the SoftwareSASGeneration value to its
|
||||
// pre-enable state. Idempotent; safe to call when enableSoftwareSAS never ran.
|
||||
func disableSoftwareSAS() {
|
||||
key, err := registry.OpenKey(
|
||||
registry.LOCAL_MACHINE,
|
||||
`SOFTWARE\Microsoft\Windows\CurrentVersion\Policies\System`,
|
||||
registry.SET_VALUE,
|
||||
)
|
||||
if err != nil {
|
||||
log.Debugf("open SoftwareSASGeneration for restore: %v", err)
|
||||
return
|
||||
}
|
||||
defer key.Close()
|
||||
|
||||
if savedSASState.had {
|
||||
if err := key.SetDWordValue("SoftwareSASGeneration", savedSASState.value); err != nil {
|
||||
log.Warnf("restore SoftwareSASGeneration to %d: %v", savedSASState.value, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err := key.DeleteValue("SoftwareSASGeneration"); err != nil {
|
||||
log.Debugf("delete SoftwareSASGeneration: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// startSASListener creates a named event with a restricted DACL and waits for
|
||||
// the VNC input injector to signal it. When signaled, it calls SendSAS(FALSE)
|
||||
// from Session 0 to trigger the Secure Attention Sequence (Ctrl+Alt+Del).
|
||||
// Only SYSTEM processes can open the event.
|
||||
//
|
||||
// sas.dll / SendSAS is part of the Desktop Experience feature: present on
|
||||
// client SKUs (Win10/11) and Server SKUs with Desktop Experience installed,
|
||||
// missing on Server Core. We probe for the symbol at startup; if absent we
|
||||
// don't register the listener and the agent will silently drop SAS keysyms,
|
||||
// rather than panicking the entire service every time the user clicks
|
||||
// Ctrl+Alt+Del.
|
||||
func startSASListener(ctx context.Context) {
|
||||
ev, ok := createSASEvent()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
log.Info("SAS listener ready (Session 0)")
|
||||
go runSASListenerLoop(ctx, ev)
|
||||
}
|
||||
|
||||
// createSASEvent prepares the named event handle on which the SAS listener
|
||||
// waits for client signals. Returns ok=false (with the failure already
|
||||
// logged) when the platform doesn't support SAS or the event cannot be
|
||||
// created; the caller must not spawn the listener goroutine in that case.
|
||||
func createSASEvent() (windows.Handle, bool) {
|
||||
if err := procSendSAS.Find(); err != nil {
|
||||
log.Warnf("SAS unavailable on this Windows SKU (sas.dll/SendSAS not present): %v", err)
|
||||
return 0, false
|
||||
}
|
||||
enableSoftwareSAS()
|
||||
namePtr, err := windows.UTF16PtrFromString(sasEventName)
|
||||
if err != nil {
|
||||
log.Warnf("SAS listener UTF16: %v", err)
|
||||
return 0, false
|
||||
}
|
||||
sa, err := sasSecurityAttributes()
|
||||
if err != nil {
|
||||
log.Warnf("build SAS security descriptor: %v", err)
|
||||
return 0, false
|
||||
}
|
||||
ev, err := windows.CreateEvent(sa, 0, 0, namePtr)
|
||||
if err != nil {
|
||||
log.Warnf("SAS CreateEvent: %v", err)
|
||||
return 0, false
|
||||
}
|
||||
return ev, true
|
||||
}
|
||||
|
||||
// runSASListenerLoop blocks on ev and invokes SendSAS each time it is
|
||||
// signalled, until ctx is cancelled. Recovers from panics inside SendSAS so
|
||||
// a future ABI surprise doesn't tear down the service.
|
||||
func runSASListenerLoop(ctx context.Context, ev windows.Handle) {
|
||||
defer func() { _ = windows.CloseHandle(ev) }()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Warnf("SAS listener recovered from panic: %v", r)
|
||||
}
|
||||
}()
|
||||
const pollMillis = 500
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
ret, _ := windows.WaitForSingleObject(ev, pollMillis)
|
||||
if ret != windows.WAIT_OBJECT_0 {
|
||||
continue
|
||||
}
|
||||
r, _, sasErr := procSendSAS.Call(0) // FALSE = not from service desktop
|
||||
if r == 0 {
|
||||
log.Warnf("SendSAS: %v", sasErr)
|
||||
continue
|
||||
}
|
||||
log.Info("SendSAS called from Session 0")
|
||||
}
|
||||
}
|
||||
|
||||
// enablePrivilege enables a named privilege on the current process token.
|
||||
func enablePrivilege(name string) error {
|
||||
var token windows.Token
|
||||
if err := windows.OpenProcessToken(windows.CurrentProcess(),
|
||||
windows.TOKEN_ADJUST_PRIVILEGES|windows.TOKEN_QUERY, &token); err != nil {
|
||||
return err
|
||||
}
|
||||
defer token.Close()
|
||||
|
||||
var luid windows.LUID
|
||||
namePtr, err := windows.UTF16PtrFromString(name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("UTF16 privilege name: %w", err)
|
||||
}
|
||||
if err := windows.LookupPrivilegeValue(nil, namePtr, &luid); err != nil {
|
||||
return err
|
||||
}
|
||||
tp := windows.Tokenprivileges{PrivilegeCount: 1}
|
||||
tp.Privileges[0].Luid = luid
|
||||
tp.Privileges[0].Attributes = windows.SE_PRIVILEGE_ENABLED
|
||||
return windows.AdjustTokenPrivileges(token, false, &tp, 0, nil, nil)
|
||||
}
|
||||
|
||||
func (s *Server) platformSessionManager() virtualSessionManager {
|
||||
return nil
|
||||
}
|
||||
|
||||
// platformShutdown restores any machine state mutated by platformInit.
|
||||
func (s *Server) platformShutdown() {
|
||||
disableSoftwareSAS()
|
||||
}
|
||||
|
||||
// platformInit starts the SAS listener and enables privileges needed for
|
||||
// Session 0 operations (agent spawning, SendSAS).
|
||||
func (s *Server) platformInit() {
|
||||
for _, priv := range []string{"SeTcbPrivilege", "SeAssignPrimaryTokenPrivilege"} {
|
||||
if err := enablePrivilege(priv); err != nil {
|
||||
log.Debugf("enable %s: %v", priv, err)
|
||||
}
|
||||
}
|
||||
startSASListener(s.ctx)
|
||||
}
|
||||
|
||||
// serviceAcceptLoop runs in Session 0. It validates the source IP and
|
||||
// hands accepted connections to handleServiceConnection, which runs the
|
||||
// Noise_IK handshake before proxying to the user-session agent.
|
||||
func (s *Server) serviceAcceptLoop() {
|
||||
|
||||
sm := newSessionManager(agentPort)
|
||||
go sm.run()
|
||||
|
||||
log.Infof("service mode, proxying connections to agent on 127.0.0.1:%d", agentPort)
|
||||
|
||||
for {
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
sm.Stop()
|
||||
return
|
||||
default:
|
||||
}
|
||||
s.log.Debugf("accept VNC connection: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if !s.tryAcquireConnSlot() {
|
||||
s.log.Warnf("rejecting VNC connection from %s: %d concurrent connections in flight", conn.RemoteAddr(), maxConcurrentVNCConns)
|
||||
_ = conn.Close()
|
||||
continue
|
||||
}
|
||||
enableTCPKeepAlive(conn, s.log)
|
||||
conn = newMetricsConn(conn, s.sessionRecorder)
|
||||
s.trackConn(conn)
|
||||
go func(c net.Conn) {
|
||||
defer s.releaseConnSlot()
|
||||
defer s.untrackConn(c)
|
||||
s.handleServiceConnection(c, sm)
|
||||
}(conn)
|
||||
}
|
||||
}
|
||||
|
||||
// handleServiceConnection runs the connection-header handshake (including
|
||||
// Noise_IK), then proxies the connection (with header bytes replayed) to
|
||||
// the agent listening on loopback.
|
||||
func (s *Server) handleServiceConnection(conn net.Conn, sm *sessionManager) {
|
||||
connLog := s.log.WithField("remote", conn.RemoteAddr().String())
|
||||
|
||||
if !s.isAllowedSource(conn.RemoteAddr()) {
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
var headerBuf bytes.Buffer
|
||||
tee := io.TeeReader(conn, &headerBuf)
|
||||
teeConn := &prefixConn{Reader: tee, Conn: conn}
|
||||
|
||||
header, err := s.readConnectionHeader(teeConn)
|
||||
if err != nil {
|
||||
connLog.Debugf("read connection header: %v", err)
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
if !s.disableAuth {
|
||||
if _, err := s.authenticateSession(header); err != nil {
|
||||
rejectConnection(conn, codeMessage(RejectCodeAuthForbidden, err.Error()))
|
||||
connLog.Warnf("auth rejected: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
s.registerConnAuth(conn, header)
|
||||
|
||||
// Replay buffered header bytes + remaining stream to the agent.
|
||||
replayConn := &prefixConn{
|
||||
Reader: io.MultiReader(&headerBuf, conn),
|
||||
Conn: conn,
|
||||
}
|
||||
proxyToAgent(s.ctx, replayConn, agentPort, sm.AuthToken())
|
||||
}
|
||||
|
||||
// prefixConn wraps a net.Conn, overriding Read to use a different reader.
|
||||
type prefixConn struct {
|
||||
io.Reader
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (p *prefixConn) Read(b []byte) (int, error) {
|
||||
return p.Reader.Read(b)
|
||||
}
|
||||
21
client/vnc/server/server_x11.go
Normal file
21
client/vnc/server/server_x11.go
Normal file
@@ -0,0 +1,21 @@
|
||||
//go:build unix && !darwin && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
func (s *Server) platformInit() {
|
||||
// no-op on X11
|
||||
}
|
||||
|
||||
// serviceAcceptLoop is not supported on Linux.
|
||||
func (s *Server) serviceAcceptLoop() {
|
||||
s.log.Warn("service mode not supported on Linux, falling back to direct mode")
|
||||
s.acceptLoop()
|
||||
}
|
||||
|
||||
func (s *Server) platformSessionManager() virtualSessionManager {
|
||||
return newSessionManager(s.log)
|
||||
}
|
||||
|
||||
func (s *Server) platformShutdown() {
|
||||
// no-op on this platform
|
||||
}
|
||||
633
client/vnc/server/session.go
Normal file
633
client/vnc/server/session.go
Normal file
@@ -0,0 +1,633 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"image"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
readDeadline = 60 * time.Second
|
||||
maxCutTextBytes = 1 << 20 // 1 MiB
|
||||
)
|
||||
|
||||
const tileSize = 64 // pixels per tile for dirty-rect detection
|
||||
|
||||
// fullFramePromoteNum/Den trigger full-frame encoding when the dirty area
|
||||
// exceeds num/den of the screen. Once past the crossover (benchmarks put it
|
||||
// around 60% at 1080p) a single zlib rect is faster than many per-tile
|
||||
// encodes AND produces about the same wire bytes: the per-tile path keeps
|
||||
// restarting zlib dictionaries and re-emitting rect headers.
|
||||
const (
|
||||
fullFramePromoteNum = 60
|
||||
fullFramePromoteDen = 100
|
||||
)
|
||||
|
||||
// bboxPromoteDensityPct collapses the coalesced rect list down to its
|
||||
// bounding box when the dirty pixels occupy at least this fraction of the
|
||||
// bbox. Catches the "windowed video" case where the player area dirties as
|
||||
// a dense block but is split into many sibling rects by overlays or by
|
||||
// non-uniform tile coverage. Sending one JPEG over the bbox beats sending
|
||||
// dozens of small JPEGs that each carry their own header and Tight stream
|
||||
// restart.
|
||||
const (
|
||||
bboxPromoteDensityPct = 70
|
||||
// bboxPromoteMinArea avoids promoting a handful of small scattered
|
||||
// rects whose bbox would span most of the screen and pull in mostly
|
||||
// clean pixels.
|
||||
bboxPromoteMinArea = tileSize * tileSize * 16
|
||||
)
|
||||
|
||||
type session struct {
|
||||
conn net.Conn
|
||||
capturer ScreenCapturer
|
||||
injector InputInjector
|
||||
serverW int
|
||||
serverH int
|
||||
desktopName string
|
||||
log *log.Entry
|
||||
|
||||
writeMu sync.Mutex
|
||||
// encMu guards the negotiated pixel format and encoding state below.
|
||||
// messageLoop writes these on SetPixelFormat/SetEncodings, which RFB
|
||||
// clients may send at any time after the handshake, while encoderLoop
|
||||
// reads them on every frame.
|
||||
encMu sync.RWMutex
|
||||
pf clientPixelFormat
|
||||
useTight bool
|
||||
useCopyRect bool
|
||||
useZlib bool
|
||||
useHextile bool
|
||||
tight *tightState
|
||||
zlib *zlibState
|
||||
copyRectDet *copyRectDetector
|
||||
// Pseudo-encodings the client advertised support for. Updated under
|
||||
// encMu by handleSetEncodings and read by the encoder goroutine.
|
||||
clientSupportsDesktopSize bool
|
||||
clientSupportsExtendedDesktopSize bool
|
||||
clientSupportsDesktopName bool
|
||||
clientSupportsLastRect bool
|
||||
clientSupportsQEMUKey bool
|
||||
clientSupportsExtClipboard bool
|
||||
clientSupportsCursor bool
|
||||
// clientSupportsExtMouseButtons is set when the client advertises the
|
||||
// ExtendedMouseButtons pseudo-encoding (-316). Once the server emits
|
||||
// the ack rect, the client switches its pointer events to the 6-byte
|
||||
// extended format that carries back/forward buttons in a second mask
|
||||
// byte. Without this gate the byte after the type field would still
|
||||
// be a standard 7-bit mask and our parser must not look further.
|
||||
clientSupportsExtMouseButtons bool
|
||||
// extMouseAckSent is set once we've emitted the pseudo-rect ack that
|
||||
// flips the client into extended-pointer mode. Sticky for the
|
||||
// session because the client only needs to see it once.
|
||||
extMouseAckSent bool
|
||||
extClipCapsSent bool
|
||||
// lastCursorSerial is the serial of the cursor sprite last emitted.
|
||||
// The encoder re-queries the source each cycle and only emits when
|
||||
// the serial changes.
|
||||
lastCursorSerial uint64
|
||||
// cursorSourceFailed latches a permanent failure from the cursor
|
||||
// source so the encoder stops polling for the rest of the session.
|
||||
// Reset on SetEncodings so a reconnect can retry.
|
||||
cursorSourceFailed bool
|
||||
// showRemoteCursor switches the encoder to compositing the server
|
||||
// cursor sprite into the captured framebuffer at the remote position
|
||||
// instead of emitting the Cursor pseudo-encoding. Toggled by the
|
||||
// client via clientNetbirdShowRemoteCursor.
|
||||
showRemoteCursor bool
|
||||
// cursorWarnOnce throttles the diagnostic emitted when remote-cursor
|
||||
// compositing falls back to a no-op (capturer cannot supply a sprite
|
||||
// or position). One line per session is enough to point at the cause.
|
||||
cursorWarnOnce sync.Once
|
||||
// clientJPEGQuality and clientZlibLevel hold the 0..9 levels the client
|
||||
// advertised via the QualityLevel / CompressLevel pseudo-encodings, or
|
||||
// -1 when the client has not expressed a preference. Applied to the
|
||||
// tight encoder state after every SetEncodings.
|
||||
clientJPEGQuality int
|
||||
clientZlibLevel int
|
||||
// prevFrame, curFrame and idleFrames live on the encoder goroutine and
|
||||
// must not be touched elsewhere. curFrame holds a session-owned copy of
|
||||
// the capturer's latest frame so the encoder works on a stable buffer
|
||||
// even when the capturer double-buffers and recycles memory underneath.
|
||||
prevFrame *image.RGBA
|
||||
curFrame *image.RGBA
|
||||
idleFrames int
|
||||
|
||||
// captureErrLast throttles "capture (transient)" logs while the
|
||||
// capturer is in a sustained failure state (e.g. X server died but a
|
||||
// client is still connected). Owned by the encoder goroutine.
|
||||
captureErrLast time.Time
|
||||
captureErrSeen bool
|
||||
|
||||
// encodeCh carries framebuffer-update requests from the read loop to the
|
||||
// encoder goroutine. Buffered size 1: RFB clients have one outstanding
|
||||
// request at a time, so a new request always replaces any pending one.
|
||||
encodeCh chan fbRequest
|
||||
|
||||
// pointerMu guards the cached last cursor position used by
|
||||
// releaseStickyInput so the disconnect-time button-release event
|
||||
// targets the cursor's current spot instead of warping to (0, 0).
|
||||
pointerMu sync.Mutex
|
||||
lastPointerX int
|
||||
lastPointerY int
|
||||
}
|
||||
|
||||
type fbRequest struct {
|
||||
incremental bool
|
||||
}
|
||||
|
||||
func (s *session) addr() string { return s.conn.RemoteAddr().String() }
|
||||
|
||||
// serve runs the full RFB session lifecycle.
|
||||
func (s *session) serve() {
|
||||
defer s.conn.Close()
|
||||
s.pf = defaultClientPixelFormat()
|
||||
s.clientJPEGQuality = -1
|
||||
s.clientZlibLevel = -1
|
||||
s.encodeCh = make(chan fbRequest, 1)
|
||||
|
||||
if err := s.handshake(); err != nil {
|
||||
s.log.Warnf("handshake with %s: %v", s.addr(), err)
|
||||
return
|
||||
}
|
||||
s.log.Infof("client connected: %s", s.addr())
|
||||
|
||||
// On any exit path (clean disconnect, transport error, panic) release
|
||||
// modifier keys and mouse buttons so the host doesn't end up with
|
||||
// Shift/Ctrl/Alt or a mouse button stuck because the client dropped
|
||||
// while holding them.
|
||||
defer s.releaseStickyInput()
|
||||
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
go s.clipboardPoll(done)
|
||||
|
||||
encoderDone := make(chan struct{})
|
||||
go s.encoderLoop(encoderDone)
|
||||
defer func() {
|
||||
close(s.encodeCh)
|
||||
<-encoderDone
|
||||
}()
|
||||
|
||||
if err := s.messageLoop(); err != nil && err != io.EOF {
|
||||
s.log.Warnf("client %s disconnected: %v", s.addr(), err)
|
||||
} else {
|
||||
s.log.Infof("client disconnected: %s", s.addr())
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) handshake() error {
|
||||
// Send protocol version.
|
||||
if _, err := io.WriteString(s.conn, rfbProtocolVersion); err != nil {
|
||||
return fmt.Errorf("send version: %w", err)
|
||||
}
|
||||
|
||||
// Read client version.
|
||||
var clientVer [12]byte
|
||||
if _, err := io.ReadFull(s.conn, clientVer[:]); err != nil {
|
||||
return fmt.Errorf("read client version: %w", err)
|
||||
}
|
||||
|
||||
// Send supported security types.
|
||||
if err := s.sendSecurityTypes(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Read chosen security type.
|
||||
var secType [1]byte
|
||||
if _, err := io.ReadFull(s.conn, secType[:]); err != nil {
|
||||
return fmt.Errorf("read security type: %w", err)
|
||||
}
|
||||
|
||||
if err := s.handleSecurity(secType[0]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Read ClientInit.
|
||||
var clientInit [1]byte
|
||||
if _, err := io.ReadFull(s.conn, clientInit[:]); err != nil {
|
||||
return fmt.Errorf("read ClientInit: %w", err)
|
||||
}
|
||||
|
||||
return s.sendServerInit()
|
||||
}
|
||||
|
||||
// sendSecurityTypes advertises only secNone. Authentication and access
|
||||
// control happen in the NetBird connection header (Noise_IK handshake,
|
||||
// mode, username) that precedes the RFB handshake; the protocol-level
|
||||
// password scheme is not supported.
|
||||
func (s *session) sendSecurityTypes() error {
|
||||
_, err := s.conn.Write([]byte{1, secNone})
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *session) handleSecurity(secType byte) error {
|
||||
if secType != secNone {
|
||||
return fmt.Errorf("unsupported security type: %d", secType)
|
||||
}
|
||||
return binary.Write(s.conn, binary.BigEndian, uint32(0))
|
||||
}
|
||||
|
||||
func (s *session) sendServerInit() error {
|
||||
desktop := s.desktopName
|
||||
if desktop == "" {
|
||||
desktop = "NetBird VNC"
|
||||
}
|
||||
name := []byte(desktop)
|
||||
buf := make([]byte, 0, 4+16+4+len(name))
|
||||
|
||||
// Framebuffer width and height.
|
||||
buf = append(buf, byte(s.serverW>>8), byte(s.serverW))
|
||||
buf = append(buf, byte(s.serverH>>8), byte(s.serverH))
|
||||
|
||||
// Server pixel format.
|
||||
buf = append(buf, serverPixelFormat[:]...)
|
||||
|
||||
// Desktop name.
|
||||
buf = append(buf,
|
||||
byte(len(name)>>24), byte(len(name)>>16),
|
||||
byte(len(name)>>8), byte(len(name)),
|
||||
)
|
||||
buf = append(buf, name...)
|
||||
|
||||
_, err := s.conn.Write(buf)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *session) messageLoop() error {
|
||||
for {
|
||||
var msgType [1]byte
|
||||
if err := s.conn.SetDeadline(time.Now().Add(readDeadline)); err != nil {
|
||||
return fmt.Errorf("set deadline: %w", err)
|
||||
}
|
||||
if _, err := io.ReadFull(s.conn, msgType[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var err error
|
||||
switch msgType[0] {
|
||||
case clientSetPixelFormat:
|
||||
err = s.handleSetPixelFormat()
|
||||
case clientSetEncodings:
|
||||
err = s.handleSetEncodings()
|
||||
case clientFramebufferUpdateRequest:
|
||||
err = s.handleFBUpdateRequest()
|
||||
case clientKeyEvent:
|
||||
err = s.handleKeyEvent()
|
||||
case clientPointerEvent:
|
||||
err = s.handlePointerEvent()
|
||||
case clientCutText:
|
||||
err = s.handleCutText()
|
||||
case clientQEMUMessage:
|
||||
err = s.handleQEMUMessage()
|
||||
case clientNetbirdTypeText:
|
||||
err = s.handleTypeText()
|
||||
case clientNetbirdShowRemoteCursor:
|
||||
err = s.handleShowRemoteCursor()
|
||||
default:
|
||||
return fmt.Errorf("unknown client message type: %d", msgType[0])
|
||||
}
|
||||
// Clear the deadline only after the full message has been read and
|
||||
// processed so payload reads in the handlers stay bounded.
|
||||
_ = s.conn.SetDeadline(time.Time{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) handleSetPixelFormat() error {
|
||||
var buf [19]byte // 3 padding + 16 pixel format
|
||||
if _, err := io.ReadFull(s.conn, buf[:]); err != nil {
|
||||
return fmt.Errorf("read SetPixelFormat: %w", err)
|
||||
}
|
||||
pf, err := parsePixelFormat(buf[3:19])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.encMu.Lock()
|
||||
s.pf = pf
|
||||
s.encMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) handleSetEncodings() error {
|
||||
var header [3]byte // 1 padding + 2 number-of-encodings
|
||||
if _, err := io.ReadFull(s.conn, header[:]); err != nil {
|
||||
return fmt.Errorf("read SetEncodings header: %w", err)
|
||||
}
|
||||
numEnc := binary.BigEndian.Uint16(header[1:3])
|
||||
// RFB clients advertise a handful of real encodings plus pseudo-encodings.
|
||||
// Cap to keep a malicious client from forcing a 256 KiB allocation per
|
||||
// SetEncodings message.
|
||||
const maxEncodings = 64
|
||||
if numEnc > maxEncodings {
|
||||
return fmt.Errorf("SetEncodings: too many encodings (%d)", numEnc)
|
||||
}
|
||||
buf := make([]byte, int(numEnc)*4)
|
||||
if _, err := io.ReadFull(s.conn, buf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
encs, sendExtClipCaps, sendExtMouseAck := s.applyEncodings(buf, int(numEnc))
|
||||
if len(encs) > 0 {
|
||||
s.log.Debugf("client supports encodings: %s", strings.Join(encs, ", "))
|
||||
}
|
||||
if sendExtClipCaps {
|
||||
if err := s.writeExtClipMessage(buildExtClipCaps()); err != nil {
|
||||
return fmt.Errorf("send ext clipboard caps: %w", err)
|
||||
}
|
||||
}
|
||||
if sendExtMouseAck {
|
||||
if err := s.sendExtMouseAck(); err != nil {
|
||||
return fmt.Errorf("send ext mouse ack: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyEncodings parses the SetEncodings body, updates capability flags,
|
||||
// rebuilds the tight state if quality/level changed, and reports which
|
||||
// one-shot acknowledgements still need to be sent.
|
||||
func (s *session) applyEncodings(buf []byte, numEnc int) (names []string, sendExtClipCaps, sendExtMouseAck bool) {
|
||||
s.encMu.Lock()
|
||||
defer s.encMu.Unlock()
|
||||
// Per RFC 6143 §7.5.3 each SetEncodings replaces the previous list, so
|
||||
// reset all flags before re-applying. extClipCapsSent stays sticky so
|
||||
// we don't re-emit Caps every refresh.
|
||||
s.resetEncodingCaps()
|
||||
for i := range numEnc {
|
||||
enc := int32(binary.BigEndian.Uint32(buf[i*4 : i*4+4]))
|
||||
if name := s.applyEncoding(enc); name != "" {
|
||||
names = append(names, name)
|
||||
}
|
||||
}
|
||||
s.refreshTightStateLocked()
|
||||
sendExtClipCaps = s.clientSupportsExtClipboard && !s.extClipCapsSent
|
||||
if sendExtClipCaps {
|
||||
s.extClipCapsSent = true
|
||||
}
|
||||
sendExtMouseAck = s.clientSupportsExtMouseButtons && !s.extMouseAckSent
|
||||
if sendExtMouseAck {
|
||||
s.extMouseAckSent = true
|
||||
}
|
||||
return names, sendExtClipCaps, sendExtMouseAck
|
||||
}
|
||||
|
||||
// refreshTightStateLocked reallocates s.tight when the requested quality
|
||||
// or compression level no longer matches the cached state. Caller holds
|
||||
// s.encMu.
|
||||
func (s *session) refreshTightStateLocked() {
|
||||
if !s.useTight {
|
||||
return
|
||||
}
|
||||
if s.tight != nil &&
|
||||
s.tight.qualityLevel == s.clientJPEGQuality &&
|
||||
s.tight.compressLevel == s.clientZlibLevel {
|
||||
return
|
||||
}
|
||||
// When we replace an in-use tightState the client's stream-0
|
||||
// inflater carries dictionary state from the old deflater. Carry
|
||||
// the pending-reset flag so the next Basic rect tells the client
|
||||
// to reset its inflater before decoding.
|
||||
replacing := s.tight != nil
|
||||
s.tight = newTightStateWithLevels(s.clientJPEGQuality, s.clientZlibLevel)
|
||||
if replacing {
|
||||
s.tight.pendingZlibReset = true
|
||||
}
|
||||
}
|
||||
|
||||
// resetEncodingCaps zeroes the encoding capability flags so the next pass
|
||||
// through applyEncoding reflects exactly what the client just advertised.
|
||||
// Caller holds s.encMu. tight / copyRectDet allocations are kept; their
|
||||
// runtime use is gated by the boolean flags here.
|
||||
func (s *session) resetEncodingCaps() {
|
||||
s.useTight = false
|
||||
s.useCopyRect = false
|
||||
s.useZlib = false
|
||||
s.useHextile = false
|
||||
s.clientSupportsDesktopSize = false
|
||||
s.clientSupportsExtendedDesktopSize = false
|
||||
s.clientSupportsDesktopName = false
|
||||
s.clientSupportsLastRect = false
|
||||
s.clientSupportsQEMUKey = false
|
||||
s.clientSupportsExtClipboard = false
|
||||
s.clientSupportsCursor = false
|
||||
s.clientSupportsExtMouseButtons = false
|
||||
s.cursorSourceFailed = false
|
||||
s.clientJPEGQuality = -1
|
||||
s.clientZlibLevel = -1
|
||||
}
|
||||
|
||||
// applyEncoding records a single encoding/pseudo-encoding from a SetEncodings
|
||||
// message. Returns the short name used in the debug log, or "" if the value
|
||||
// is one we don't recognise. Caller holds s.encMu.
|
||||
func (s *session) applyEncoding(enc int32) string {
|
||||
switch enc {
|
||||
case encCopyRect:
|
||||
s.useCopyRect = true
|
||||
if s.copyRectDet == nil {
|
||||
s.copyRectDet = newCopyRectDetector(tileSize)
|
||||
}
|
||||
return "copyrect"
|
||||
case pseudoEncDesktopSize:
|
||||
s.clientSupportsDesktopSize = true
|
||||
return "desktop-size"
|
||||
case pseudoEncExtendedDesktopSize:
|
||||
s.clientSupportsExtendedDesktopSize = true
|
||||
return "ext-desktop-size"
|
||||
case pseudoEncDesktopName:
|
||||
s.clientSupportsDesktopName = true
|
||||
return "desktop-name"
|
||||
case pseudoEncLastRect:
|
||||
s.clientSupportsLastRect = true
|
||||
return "last-rect"
|
||||
case pseudoEncQEMUExtendedKeyEvent:
|
||||
s.clientSupportsQEMUKey = true
|
||||
return "qemu-key"
|
||||
case pseudoEncExtendedClipboard:
|
||||
s.clientSupportsExtClipboard = true
|
||||
return "ext-clipboard"
|
||||
case pseudoEncCursor:
|
||||
s.clientSupportsCursor = true
|
||||
return "cursor"
|
||||
case pseudoEncExtendedMouseButtons:
|
||||
s.clientSupportsExtMouseButtons = true
|
||||
return "ext-mouse-buttons"
|
||||
case encTight:
|
||||
s.useTight = true
|
||||
return "tight"
|
||||
case encZlib:
|
||||
s.useZlib = true
|
||||
if s.zlib == nil {
|
||||
s.zlib = newZlibStateLevel(zlibLevelFor(-1))
|
||||
}
|
||||
return "zlib"
|
||||
case encHextile:
|
||||
s.useHextile = true
|
||||
return "hextile"
|
||||
}
|
||||
if enc >= pseudoEncQualityLevelMin && enc <= pseudoEncQualityLevelMax {
|
||||
s.clientJPEGQuality = int(enc - pseudoEncQualityLevelMin)
|
||||
return fmt.Sprintf("quality=%d", s.clientJPEGQuality)
|
||||
}
|
||||
if enc >= pseudoEncCompressLevelMin && enc <= pseudoEncCompressLevelMax {
|
||||
s.clientZlibLevel = int(enc - pseudoEncCompressLevelMin)
|
||||
return fmt.Sprintf("compress=%d", s.clientZlibLevel)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// handleFBUpdateRequest parses the request and hands it to the encoder
|
||||
// goroutine. It never blocks on capture/encode, so the input dispatch loop
|
||||
// stays responsive even when a previous frame is still being encoded.
|
||||
func (s *session) handleFBUpdateRequest() error {
|
||||
var req [9]byte
|
||||
if _, err := io.ReadFull(s.conn, req[:]); err != nil {
|
||||
return fmt.Errorf("read FBUpdateRequest: %w", err)
|
||||
}
|
||||
r := fbRequest{incremental: req[0] == 1}
|
||||
// Channel is size 1. If a request is already pending, replace it with
|
||||
// this fresher one so the encoder always works on the latest ask.
|
||||
select {
|
||||
case s.encodeCh <- r:
|
||||
default:
|
||||
select {
|
||||
case <-s.encodeCh:
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case s.encodeCh <- r:
|
||||
default:
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendDesktopName pushes a DesktopName pseudo-encoded update to the
|
||||
// client if it advertised support. Lets the client keep its window title
|
||||
// in sync with the active session (e.g. username changes after login on
|
||||
// a virtual session).
|
||||
func (s *session) SendDesktopName(name string) error {
|
||||
s.encMu.RLock()
|
||||
supported := s.clientSupportsDesktopName
|
||||
s.encMu.RUnlock()
|
||||
if !supported {
|
||||
s.desktopName = name
|
||||
return nil
|
||||
}
|
||||
s.desktopName = name
|
||||
header := make([]byte, 4)
|
||||
header[0] = serverFramebufferUpdate
|
||||
binary.BigEndian.PutUint16(header[2:4], 1)
|
||||
|
||||
body := encodeDesktopNameBody(name)
|
||||
s.writeMu.Lock()
|
||||
defer s.writeMu.Unlock()
|
||||
if _, err := s.conn.Write(header); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := s.conn.Write(body)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *session) handleKeyEvent() error {
|
||||
var data [7]byte
|
||||
if _, err := io.ReadFull(s.conn, data[:]); err != nil {
|
||||
return fmt.Errorf("read KeyEvent: %w", err)
|
||||
}
|
||||
down := data[0] == 1
|
||||
keysym := binary.BigEndian.Uint32(data[3:7])
|
||||
s.injector.InjectKey(keysym, down)
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleQEMUMessage parses one QEMU vendor message. Today we only handle
|
||||
// subtype 0 (Extended Key Event); the message itself is 12 bytes total so
|
||||
// reading 11 more after the type byte covers the layout regardless of
|
||||
// subtype, and unknown subtypes are dropped without aborting the session.
|
||||
func (s *session) handleQEMUMessage() error {
|
||||
var data [11]byte // subtype(1) + down(2) + keysym(4) + keycode(4)
|
||||
if _, err := io.ReadFull(s.conn, data[:]); err != nil {
|
||||
return fmt.Errorf("read QEMU message: %w", err)
|
||||
}
|
||||
subtype := data[0]
|
||||
if subtype != qemuSubtypeExtendedKeyEvent {
|
||||
s.log.Tracef("ignoring QEMU subtype %d", subtype)
|
||||
return nil
|
||||
}
|
||||
down := binary.BigEndian.Uint16(data[1:3]) != 0
|
||||
keysym := binary.BigEndian.Uint32(data[3:7])
|
||||
scancode := binary.BigEndian.Uint32(data[7:11])
|
||||
s.injector.InjectKeyScancode(scancode, keysym, down)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) handlePointerEvent() error {
|
||||
var data [5]byte
|
||||
if _, err := io.ReadFull(s.conn, data[:]); err != nil {
|
||||
return fmt.Errorf("read PointerEvent: %w", err)
|
||||
}
|
||||
mask := uint16(data[0])
|
||||
x := int(binary.BigEndian.Uint16(data[1:3]))
|
||||
y := int(binary.BigEndian.Uint16(data[3:5]))
|
||||
|
||||
s.encMu.RLock()
|
||||
extended := s.clientSupportsExtMouseButtons && s.extMouseAckSent
|
||||
s.encMu.RUnlock()
|
||||
if extended && mask&0x80 != 0 {
|
||||
var hi [1]byte
|
||||
if _, err := io.ReadFull(s.conn, hi[:]); err != nil {
|
||||
return fmt.Errorf("read ExtendedPointerEvent tail: %w", err)
|
||||
}
|
||||
// Strip the marker bit; bits 0..6 are the low part of the mask,
|
||||
// hi byte holds bits 7..14 (back at bit 7, forward at bit 8).
|
||||
mask = (mask & 0x7f) | uint16(hi[0])<<7
|
||||
}
|
||||
|
||||
s.pointerMu.Lock()
|
||||
s.lastPointerX = x
|
||||
s.lastPointerY = y
|
||||
s.pointerMu.Unlock()
|
||||
s.injector.InjectPointer(mask, x, y, s.serverW, s.serverH)
|
||||
return nil
|
||||
}
|
||||
|
||||
// stickyModifierKeysyms are the X11 keysyms we send "up" events for on
|
||||
// disconnect. Modifier-up while not held is a no-op on every supported
|
||||
// platform, so we can blanket-release without per-key tracking. This
|
||||
// covers the practical sticky-state bug: client drops while user is
|
||||
// holding Shift / Ctrl / Alt / Meta / Super.
|
||||
var stickyModifierKeysyms = [...]uint32{
|
||||
0xffe1, 0xffe2, // Shift_L, Shift_R
|
||||
0xffe3, 0xffe4, // Control_L, Control_R
|
||||
0xffe9, 0xffea, // Alt_L, Alt_R
|
||||
0xffe7, 0xffe8, // Meta_L, Meta_R
|
||||
0xffeb, 0xffec, // Super_L, Super_R
|
||||
0xff7e, // Mode_switch
|
||||
0xfe03, // ISO_Level3_Shift (AltGr)
|
||||
0xffe5, // Caps_Lock (release if user dropped mid-press)
|
||||
}
|
||||
|
||||
// releaseStickyInput synthesizes key-up for modifier keysyms and a
|
||||
// zero-button PointerEvent so the host doesn't end up with stuck input
|
||||
// when the client disconnects mid-press. Mouse coordinates are reused
|
||||
// from the last PointerEvent so we don't warp the cursor.
|
||||
func (s *session) releaseStickyInput() {
|
||||
for _, ks := range stickyModifierKeysyms {
|
||||
s.injector.InjectKey(ks, false)
|
||||
}
|
||||
s.pointerMu.Lock()
|
||||
x, y := s.lastPointerX, s.lastPointerY
|
||||
s.pointerMu.Unlock()
|
||||
s.injector.InjectPointer(0, x, y, s.serverW, s.serverH)
|
||||
}
|
||||
260
client/vnc/server/session_clipboard.go
Normal file
260
client/vnc/server/session_clipboard.go
Normal file
@@ -0,0 +1,260 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
||||
// clipboardPoll periodically checks the server-side clipboard and sends
|
||||
// changes to the VNC client. Only runs during active sessions.
|
||||
func (s *session) clipboardPoll(done <-chan struct{}) {
|
||||
ticker := time.NewTicker(2 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
var lastClip string
|
||||
for {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case <-ticker.C:
|
||||
text := s.injector.GetClipboard()
|
||||
if len(text) > maxCutTextBytes {
|
||||
text = text[:maxCutTextBytes]
|
||||
}
|
||||
if text == "" || text == lastClip {
|
||||
continue
|
||||
}
|
||||
lastClip = text
|
||||
s.encMu.RLock()
|
||||
ext := s.clientSupportsExtClipboard
|
||||
s.encMu.RUnlock()
|
||||
if ext {
|
||||
if err := s.writeExtClipMessage(buildExtClipNotify(extClipFormatText)); err != nil {
|
||||
s.log.Debugf("send ext clipboard notify: %v", err)
|
||||
return
|
||||
}
|
||||
} else if err := s.sendServerCutText(text); err != nil {
|
||||
s.log.Debugf("send clipboard to client: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) handleCutText() error {
|
||||
var header [7]byte // 3 padding + 4 length
|
||||
if _, err := io.ReadFull(s.conn, header[:]); err != nil {
|
||||
return fmt.Errorf("read CutText header: %w", err)
|
||||
}
|
||||
rawLen := int32(binary.BigEndian.Uint32(header[3:7]))
|
||||
if rawLen < 0 {
|
||||
// Negative length signals ExtendedClipboard; absolute value is the
|
||||
// payload size. Guard against MinInt32 overflow before negating.
|
||||
if rawLen == -2147483648 {
|
||||
return fmt.Errorf("ext clipboard payload too large")
|
||||
}
|
||||
return s.handleExtCutText(uint32(-rawLen))
|
||||
}
|
||||
length := uint32(rawLen)
|
||||
if length > maxCutTextBytes {
|
||||
return fmt.Errorf("cut text too large: %d bytes", length)
|
||||
}
|
||||
buf := make([]byte, length)
|
||||
if _, err := io.ReadFull(s.conn, buf); err != nil {
|
||||
return fmt.Errorf("read CutText payload: %w", err)
|
||||
}
|
||||
s.injector.SetClipboard(latin1ToUTF8(buf))
|
||||
return nil
|
||||
}
|
||||
|
||||
// drainBytes consumes and discards n bytes from the connection. Used to
|
||||
// skip the payload of a malformed clipboard message after we've decided
|
||||
// not to honour it, so the next message stays aligned.
|
||||
func (s *session) drainBytes(n uint32) error {
|
||||
if n == 0 {
|
||||
return nil
|
||||
}
|
||||
if _, err := io.CopyN(io.Discard, s.conn, int64(n)); err != nil {
|
||||
return fmt.Errorf("drain %d bytes: %w", n, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// latin1ToUTF8 converts an RFB ClientCutText payload (ISO 8859-1 per
|
||||
// RFC 6143 §7.5.6) into a UTF-8 string. Bytes 0x80..0xFF map to the
|
||||
// matching U+0080..U+00FF code points; passing them through Go's
|
||||
// `string([]byte)` instead would produce invalid UTF-8 that downstream
|
||||
// clipboard backends mangle.
|
||||
func latin1ToUTF8(b []byte) string {
|
||||
runes := make([]rune, len(b))
|
||||
for i, c := range b {
|
||||
runes[i] = rune(c)
|
||||
}
|
||||
return string(runes)
|
||||
}
|
||||
|
||||
// utf8ToLatin1 converts a UTF-8 string into the Latin-1 byte sequence
|
||||
// required by legacy ServerCutText (RFC 6143 §7.6.4). Runes outside
|
||||
// U+0000..U+00FF are not representable in Latin-1; we substitute '?' so the
|
||||
// peer still receives a coherent message instead of a truncated or
|
||||
// silently mojibake'd payload. ExtendedClipboard clients take a separate
|
||||
// path that preserves full UTF-8.
|
||||
func utf8ToLatin1(s string) []byte {
|
||||
out := make([]byte, 0, len(s))
|
||||
for _, r := range s {
|
||||
if r > 0xFF {
|
||||
out = append(out, '?')
|
||||
continue
|
||||
}
|
||||
out = append(out, byte(r))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// handleExtCutText parses an ExtendedClipboard message (any of Caps,
|
||||
// Notify, Request, Peek, Provide) carried as a negative-length CutText.
|
||||
// Unknown actions, oversized payloads, and formats we don't handle
|
||||
// (RTF/HTML/DIB/Files) are logged and dropped instead of aborting the
|
||||
// session: a malformed clipboard message must never cost the user their
|
||||
// VNC connection. Read errors on the socket itself still propagate.
|
||||
func (s *session) handleExtCutText(payloadLen uint32) error {
|
||||
if payloadLen < 4 {
|
||||
s.log.Debugf("ext clipboard payload too short: %d", payloadLen)
|
||||
return s.drainBytes(payloadLen)
|
||||
}
|
||||
if payloadLen > extClipMaxPayload {
|
||||
s.log.Debugf("ext clipboard payload too large: %d", payloadLen)
|
||||
return s.drainBytes(payloadLen)
|
||||
}
|
||||
buf := make([]byte, payloadLen)
|
||||
if _, err := io.ReadFull(s.conn, buf); err != nil {
|
||||
return fmt.Errorf("read ext clipboard payload: %w", err)
|
||||
}
|
||||
flags := binary.BigEndian.Uint32(buf[0:4])
|
||||
action := flags & extClipActionMask
|
||||
formats := flags & extClipFormatMask
|
||||
rest := buf[4:]
|
||||
|
||||
// A Caps message sets the Caps bit alongside one bit per action the
|
||||
// peer supports, so the action byte is multi-bit. Detect it first; the
|
||||
// remaining actions are single-bit and are dispatched after.
|
||||
if action&extClipActionCaps != 0 {
|
||||
// Client max sizes are informational for us today: we only emit
|
||||
// text and already cap it at extClipMaxText.
|
||||
return nil
|
||||
}
|
||||
|
||||
switch action {
|
||||
case extClipActionRequest:
|
||||
if formats&extClipFormatText != 0 {
|
||||
return s.sendExtClipProvideText()
|
||||
}
|
||||
return nil
|
||||
case extClipActionPeek:
|
||||
return s.writeExtClipMessage(buildExtClipNotify(extClipFormatText))
|
||||
case extClipActionNotify:
|
||||
if formats&extClipFormatText != 0 {
|
||||
return s.writeExtClipMessage(buildExtClipRequest(extClipFormatText))
|
||||
}
|
||||
return nil
|
||||
case extClipActionProvide:
|
||||
s.handleExtClipProvide(flags, rest)
|
||||
return nil
|
||||
default:
|
||||
s.log.Debugf("unknown ext clipboard action 0x%x", action)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// handleExtClipProvide decodes a Provide payload and pushes the recovered
|
||||
// text into the host clipboard. Decode errors and unsupported formats (RTF,
|
||||
// HTML, etc.) are logged and dropped so a malformed message doesn't tear
|
||||
// down the session.
|
||||
func (s *session) handleExtClipProvide(flags uint32, payload []byte) {
|
||||
if len(payload) == 0 {
|
||||
return
|
||||
}
|
||||
text, err := parseExtClipProvideText(flags, payload)
|
||||
if err != nil {
|
||||
s.log.Debugf("parse ext clipboard provide: %v", err)
|
||||
return
|
||||
}
|
||||
if text != "" {
|
||||
s.injector.SetClipboard(text)
|
||||
}
|
||||
}
|
||||
|
||||
// sendExtClipProvideText answers an inbound Request(text) with the current
|
||||
// host clipboard contents, capped to extClipMaxText.
|
||||
func (s *session) sendExtClipProvideText() error {
|
||||
text := s.injector.GetClipboard()
|
||||
if len(text) > extClipMaxText {
|
||||
text = text[:extClipMaxText]
|
||||
}
|
||||
payload, err := buildExtClipProvideText(text)
|
||||
if err != nil {
|
||||
return fmt.Errorf("build provide: %w", err)
|
||||
}
|
||||
return s.writeExtClipMessage(payload)
|
||||
}
|
||||
|
||||
// writeExtClipMessage frames an ExtendedClipboard payload as a ServerCutText
|
||||
// message with a negative length, then writes it under writeMu.
|
||||
func (s *session) writeExtClipMessage(payload []byte) error {
|
||||
if len(payload) == 0 {
|
||||
return nil
|
||||
}
|
||||
buf := make([]byte, 8+len(payload))
|
||||
buf[0] = serverCutText
|
||||
// buf[1:4] = padding (zero)
|
||||
binary.BigEndian.PutUint32(buf[4:8], uint32(-int32(len(payload))))
|
||||
copy(buf[8:], payload)
|
||||
|
||||
s.writeMu.Lock()
|
||||
_, err := s.conn.Write(buf)
|
||||
s.writeMu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
// handleTypeText handles the NetBird-specific PasteAndType message that
|
||||
// pushes host clipboard content as synthesized keystrokes, used to reach
|
||||
// secure desktops where the clipboard is isolated. Wire format mirrors
|
||||
// CutText: 3-byte padding + 4-byte length + text bytes.
|
||||
func (s *session) handleTypeText() error {
|
||||
var header [7]byte
|
||||
if _, err := io.ReadFull(s.conn, header[:]); err != nil {
|
||||
return fmt.Errorf("read TypeText header: %w", err)
|
||||
}
|
||||
length := binary.BigEndian.Uint32(header[3:7])
|
||||
if length > maxCutTextBytes {
|
||||
return fmt.Errorf("type text too large: %d bytes", length)
|
||||
}
|
||||
buf := make([]byte, length)
|
||||
if _, err := io.ReadFull(s.conn, buf); err != nil {
|
||||
return fmt.Errorf("read TypeText payload: %w", err)
|
||||
}
|
||||
s.injector.TypeText(string(buf))
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendServerCutText sends clipboard text from the server to the legacy
|
||||
// (non-ExtendedClipboard) client. The wire encoding is Latin-1; runes that
|
||||
// fall outside U+0000..U+00FF are best-effort replaced with '?' since the
|
||||
// peer cannot represent them.
|
||||
func (s *session) sendServerCutText(text string) error {
|
||||
data := utf8ToLatin1(text)
|
||||
buf := make([]byte, 8+len(data))
|
||||
buf[0] = serverCutText
|
||||
// buf[1:4] = padding (zero)
|
||||
binary.BigEndian.PutUint32(buf[4:8], uint32(len(data)))
|
||||
copy(buf[8:], data)
|
||||
|
||||
s.writeMu.Lock()
|
||||
_, err := s.conn.Write(buf)
|
||||
s.writeMu.Unlock()
|
||||
return err
|
||||
}
|
||||
88
client/vnc/server/session_cursor.go
Normal file
88
client/vnc/server/session_cursor.go
Normal file
@@ -0,0 +1,88 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"image"
|
||||
)
|
||||
|
||||
// pendingCursorRect returns the Cursor pseudo-rect for the current sprite
|
||||
// when the client negotiated the encoding and the platform exposes a
|
||||
// cursor source whose serial has changed since the last emission. A nil
|
||||
// return means "do not include a cursor rect in this FramebufferUpdate".
|
||||
func (s *session) pendingCursorRect() []byte {
|
||||
s.encMu.RLock()
|
||||
supported := s.clientSupportsCursor
|
||||
failed := s.cursorSourceFailed
|
||||
composite := s.showRemoteCursor
|
||||
lastSerial := s.lastCursorSerial
|
||||
s.encMu.RUnlock()
|
||||
if !supported || failed || composite {
|
||||
return nil
|
||||
}
|
||||
src, ok := s.capturer.(cursorSource)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
img, hotX, hotY, serial, err := src.Cursor()
|
||||
if err != nil {
|
||||
s.encMu.Lock()
|
||||
s.cursorSourceFailed = true
|
||||
s.encMu.Unlock()
|
||||
s.log.Debugf("cursor source unavailable: %v", err)
|
||||
return nil
|
||||
}
|
||||
if img == nil || serial == lastSerial {
|
||||
return nil
|
||||
}
|
||||
buf := encodeCursorPseudoRect(img, hotX, hotY)
|
||||
s.encMu.Lock()
|
||||
s.lastCursorSerial = serial
|
||||
s.encMu.Unlock()
|
||||
return buf
|
||||
}
|
||||
|
||||
// encodeCursorPseudoRect packs the cursor sprite into a Cursor pseudo
|
||||
// rectangle (RFB 7.7.4, pseudo-encoding -239). Layout: 12-byte rect header
|
||||
// followed by w*h*4 BGRX pixel bytes and a 1-bit mask of (w+7)/8 bytes per
|
||||
// row, MSB-first, with each row independently padded.
|
||||
func encodeCursorPseudoRect(img *image.RGBA, hotX, hotY int) []byte {
|
||||
w, h := img.Rect.Dx(), img.Rect.Dy()
|
||||
pixelBytes := w * h * 4
|
||||
maskStride := (w + 7) / 8
|
||||
maskBytes := maskStride * h
|
||||
buf := make([]byte, 12+pixelBytes+maskBytes)
|
||||
|
||||
binary.BigEndian.PutUint16(buf[0:2], uint16(hotX))
|
||||
binary.BigEndian.PutUint16(buf[2:4], uint16(hotY))
|
||||
binary.BigEndian.PutUint16(buf[4:6], uint16(w))
|
||||
binary.BigEndian.PutUint16(buf[6:8], uint16(h))
|
||||
enc := int32(pseudoEncCursor)
|
||||
binary.BigEndian.PutUint32(buf[8:12], uint32(enc))
|
||||
|
||||
pix := buf[12 : 12+pixelBytes]
|
||||
mask := buf[12+pixelBytes:]
|
||||
src := img.Pix
|
||||
stride := img.Stride
|
||||
for y := 0; y < h; y++ {
|
||||
row := y * stride
|
||||
dstRow := y * w * 4
|
||||
maskRow := y * maskStride
|
||||
for x := 0; x < w; x++ {
|
||||
r := src[row+x*4+0]
|
||||
g := src[row+x*4+1]
|
||||
b := src[row+x*4+2]
|
||||
a := src[row+x*4+3]
|
||||
off := dstRow + x*4
|
||||
pix[off+0] = b
|
||||
pix[off+1] = g
|
||||
pix[off+2] = r
|
||||
pix[off+3] = 0
|
||||
if a >= 0x80 {
|
||||
mask[maskRow+x/8] |= 0x80 >> (x % 8)
|
||||
}
|
||||
}
|
||||
}
|
||||
return buf
|
||||
}
|
||||
586
client/vnc/server/session_encode.go
Normal file
586
client/vnc/server/session_encode.go
Normal file
@@ -0,0 +1,586 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"image"
|
||||
"time"
|
||||
)
|
||||
|
||||
// encoderLoop owns the capture → diff → encode → write pipeline. Running it
|
||||
// off the read loop prevents a slow encode (zlib full-frame, many dirty
|
||||
// tiles) from blocking inbound input events.
|
||||
func (s *session) encoderLoop(done chan<- struct{}) {
|
||||
defer close(done)
|
||||
for req := range s.encodeCh {
|
||||
if err := s.processFBRequest(req); err != nil {
|
||||
s.log.Debugf("encode: %v", err)
|
||||
// On write/capture error, close the connection so messageLoop
|
||||
// exits and the session terminates cleanly.
|
||||
s.conn.Close()
|
||||
drainRequests(s.encodeCh)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) processFBRequest(req fbRequest) error {
|
||||
// Watch for resolution changes between cycles. When the capturer
|
||||
// reports a new size, tell the client via DesktopSize so it can
|
||||
// reallocate its backing buffer; the next full update will then fill
|
||||
// the new dimensions. Clients that didn't advertise support are stuck
|
||||
// with the original handshake size and just see clipping on resize.
|
||||
if err := s.handleResize(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
busy := s.applyBackpressure()
|
||||
if busy >= backpressureSkipThreshold {
|
||||
return s.sendEmptyUpdate()
|
||||
}
|
||||
|
||||
img, err := s.captureFrame()
|
||||
if errors.Is(err, errFrameUnchanged) {
|
||||
// macOS hashes the raw capture bytes and short-circuits when the
|
||||
// screen is byte-identical. Treat as "no dirty rects" to skip the
|
||||
// diff and send an empty update.
|
||||
s.idleFrames++
|
||||
delay := min(s.idleFrames*5, 100)
|
||||
time.Sleep(time.Duration(delay) * time.Millisecond)
|
||||
return s.sendEmptyUpdate()
|
||||
}
|
||||
if err != nil {
|
||||
// Capture failures are transient on Windows: a Ctrl+Alt+Del or
|
||||
// sign-out switches the OS to the secure desktop, and the DXGI
|
||||
// duplicator on the previous desktop returns an error until the
|
||||
// capturer reattaches on the new desktop. On Linux the X server
|
||||
// behind a virtual session may exit and the capturer reports
|
||||
// "unavailable" on every retry tick. Don't tear down the session
|
||||
// and don't spam the log: emit one line on the first failure, then
|
||||
// throttle further "still failing" lines to once per 5 s.
|
||||
s.captureErrorLog(err)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
return s.sendEmptyUpdate()
|
||||
}
|
||||
s.captureRecovered()
|
||||
|
||||
s.maybeCompositeCursor(img)
|
||||
|
||||
if req.incremental && s.prevFrame != nil {
|
||||
return s.processIncremental(img)
|
||||
}
|
||||
|
||||
// Full update.
|
||||
s.idleFrames = 0
|
||||
if err := s.sendFullUpdate(img); err != nil {
|
||||
return err
|
||||
}
|
||||
s.swapPrevCur()
|
||||
s.refreshCopyRectIndex()
|
||||
return nil
|
||||
}
|
||||
|
||||
// processIncremental handles the diff/encode path for a non-initial frame.
|
||||
// Returns nil after writing either an empty update (no changes) or a mix of
|
||||
// CopyRect moves and pixel-encoded dirty rects.
|
||||
func (s *session) processIncremental(img *image.RGBA) error {
|
||||
tiles := diffTiles(s.prevFrame, img, s.serverW, s.serverH, tileSize)
|
||||
if len(tiles) == 0 {
|
||||
// Nothing changed. Back off briefly before responding to reduce
|
||||
// CPU usage when the screen is static. The client re-requests
|
||||
// immediately after receiving our empty response, so without
|
||||
// this delay we'd spin at ~1000fps checking for changes.
|
||||
s.idleFrames++
|
||||
delay := min(s.idleFrames*5, 100) // 5ms → 100ms adaptive backoff
|
||||
time.Sleep(time.Duration(delay) * time.Millisecond)
|
||||
s.swapPrevCur()
|
||||
return s.sendEmptyUpdate()
|
||||
}
|
||||
s.idleFrames = 0
|
||||
|
||||
// Snapshot the dirty set before extractCopyRectTiles consumes it.
|
||||
// extract mutates in place, so without the copy we lose the
|
||||
// move-destination positions needed to incrementally update the
|
||||
// CopyRect index after the swap.
|
||||
dirty := make([][4]int, len(tiles))
|
||||
copy(dirty, tiles)
|
||||
|
||||
var moves []copyRectMove
|
||||
if s.useCopyRect && s.copyRectDet != nil {
|
||||
moves, tiles = s.copyRectDet.extractCopyRectTiles(img, tiles)
|
||||
}
|
||||
|
||||
rects := coalesceRects(tiles)
|
||||
if s.shouldPromoteToFullFrame(rects) && len(moves) == 0 {
|
||||
if err := s.sendFullUpdate(img); err != nil {
|
||||
return err
|
||||
}
|
||||
s.swapPrevCur()
|
||||
s.refreshCopyRectIndex()
|
||||
return nil
|
||||
}
|
||||
if len(moves) == 0 {
|
||||
if bb, ok := promoteToBoundingBox(rects); ok {
|
||||
rects = bb
|
||||
}
|
||||
}
|
||||
if err := s.sendDirtyAndMoves(img, moves, rects); err != nil {
|
||||
return err
|
||||
}
|
||||
s.swapPrevCur()
|
||||
s.updateCopyRectIndex(dirty)
|
||||
return nil
|
||||
}
|
||||
|
||||
// backpressureSkipThreshold is the BusyFraction at and above which we drop
|
||||
// the next encode entirely and respond with an empty FramebufferUpdate.
|
||||
// Above this level the encoder would only stack more bytes behind a socket
|
||||
// that is already write-blocked, raising end-to-end latency.
|
||||
const backpressureSkipThreshold = 0.65
|
||||
|
||||
// backpressureRampStart is where adaptive quality begins clipping. Below
|
||||
// this fraction the honoured client quality is used as-is.
|
||||
const backpressureRampStart = 0.2
|
||||
|
||||
// backpressureMinQuality is the floor JPEG quality picked when the socket
|
||||
// is fully saturated short of the skip threshold.
|
||||
const backpressureMinQuality = 25
|
||||
|
||||
// applyBackpressure samples the socket BusyFraction (if available) and, if
|
||||
// Tight is in use, ramps the active JPEG quality from the client-honoured
|
||||
// value down to backpressureMinQuality as the fraction climbs from
|
||||
// backpressureRampStart toward backpressureSkipThreshold. Returns the
|
||||
// observed fraction so the caller can decide whether to skip the frame.
|
||||
func (s *session) applyBackpressure() float64 {
|
||||
type busyReporter interface{ BusyFraction() float64 }
|
||||
bs, ok := s.conn.(busyReporter)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
frac := bs.BusyFraction()
|
||||
|
||||
s.encMu.RLock()
|
||||
tight := s.tight
|
||||
s.encMu.RUnlock()
|
||||
if tight == nil {
|
||||
return frac
|
||||
}
|
||||
|
||||
base := jpegQualityForLevel(tight.qualityLevel)
|
||||
if base == 0 {
|
||||
// No client-negotiated quality; let tightQualityFor pick the
|
||||
// area-based default and skip backpressure adjustments that
|
||||
// would otherwise lock in a wrong starting point.
|
||||
tight.jpegQualityOverride = 0
|
||||
return frac
|
||||
}
|
||||
q := base
|
||||
if frac > backpressureRampStart {
|
||||
span := backpressureSkipThreshold - backpressureRampStart
|
||||
t := (frac - backpressureRampStart) / span
|
||||
if t > 1 {
|
||||
t = 1
|
||||
}
|
||||
q = base - int(float64(base-backpressureMinQuality)*t)
|
||||
if q < backpressureMinQuality {
|
||||
q = backpressureMinQuality
|
||||
}
|
||||
}
|
||||
tight.jpegQualityOverride = q
|
||||
return frac
|
||||
}
|
||||
|
||||
// captureErrorLog emits one log line on the first failure after success,
|
||||
// then at most once every captureErrThrottle while the capturer keeps
|
||||
// failing. The "recovered" transition is logged once when err is nil and
|
||||
// captureErrSeen was set.
|
||||
func (s *session) captureErrorLog(err error) {
|
||||
const captureErrThrottle = 5 * time.Second
|
||||
now := time.Now()
|
||||
if !s.captureErrSeen || now.Sub(s.captureErrLast) >= captureErrThrottle {
|
||||
s.log.Debugf("capture (transient): %v", err)
|
||||
s.captureErrLast = now
|
||||
}
|
||||
s.captureErrSeen = true
|
||||
}
|
||||
|
||||
// captureRecovered emits a one-shot debug line when capture works again
|
||||
// after a failure streak. Called by the success paths.
|
||||
func (s *session) captureRecovered() {
|
||||
if s.captureErrSeen {
|
||||
s.log.Debugf("capture recovered")
|
||||
s.captureErrSeen = false
|
||||
}
|
||||
}
|
||||
|
||||
// handleResize detects framebuffer-size changes between encode cycles and
|
||||
// notifies the client via the DesktopSize pseudo-encoding. Returns an
|
||||
// error only on write failure; capturers that don't expose Width/Height
|
||||
// yet (zero values during early startup) are silently ignored.
|
||||
func (s *session) handleResize() error {
|
||||
w, h := s.capturer.Width(), s.capturer.Height()
|
||||
if w <= 0 || h <= 0 {
|
||||
return nil
|
||||
}
|
||||
if w > maxFramebufferDim || h > maxFramebufferDim {
|
||||
s.log.Warnf("ignoring resize: %dx%d exceeds cap %d", w, h, maxFramebufferDim)
|
||||
return nil
|
||||
}
|
||||
if w == s.serverW && h == s.serverH {
|
||||
return nil
|
||||
}
|
||||
s.log.Debugf("framebuffer resized: %dx%d -> %dx%d", s.serverW, s.serverH, w, h)
|
||||
s.serverW = w
|
||||
s.serverH = h
|
||||
// Drop the prev frame so the next encode produces a full update at
|
||||
// the new dimensions rather than diffing against a stale-sized buffer.
|
||||
s.prevFrame = nil
|
||||
s.curFrame = nil
|
||||
if s.copyRectDet != nil {
|
||||
// Tile geometry changed; let updateDirty rebuild from scratch on
|
||||
// the next pass instead of reusing stale hashes keyed on old
|
||||
// (cols, rows).
|
||||
s.copyRectDet.prevTiles = nil
|
||||
s.copyRectDet.tileHash = nil
|
||||
}
|
||||
if err := s.sendDesktopSize(w, h); err != nil {
|
||||
return fmt.Errorf("send desktop size: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendDesktopSize emits a single-rect FramebufferUpdate carrying the
|
||||
// DesktopSize pseudo-encoding. No-op if the client did not negotiate it,
|
||||
// in which case the client just sees the new dimensions on the next full
|
||||
// update and will likely clip or scale.
|
||||
func (s *session) sendDesktopSize(w, h int) error {
|
||||
s.encMu.RLock()
|
||||
supported := s.clientSupportsDesktopSize || s.clientSupportsExtendedDesktopSize
|
||||
s.encMu.RUnlock()
|
||||
if !supported {
|
||||
return nil
|
||||
}
|
||||
header := make([]byte, 4)
|
||||
header[0] = serverFramebufferUpdate
|
||||
binary.BigEndian.PutUint16(header[2:4], 1)
|
||||
|
||||
body := encodeDesktopSizeBody(w, h)
|
||||
s.writeMu.Lock()
|
||||
defer s.writeMu.Unlock()
|
||||
if _, err := s.conn.Write(header); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := s.conn.Write(body)
|
||||
return err
|
||||
}
|
||||
|
||||
// sendExtMouseAck emits the pseudo-rect that flips the client into
|
||||
// ExtendedMouseButtons mode, where mouse-back and mouse-forward are
|
||||
// carried in a second mask byte. The rect has zero geometry and no
|
||||
// body; the encoding number alone is the signal.
|
||||
func (s *session) sendExtMouseAck() error {
|
||||
header := make([]byte, 4)
|
||||
header[0] = serverFramebufferUpdate
|
||||
binary.BigEndian.PutUint16(header[2:4], 1)
|
||||
|
||||
rect := make([]byte, 12)
|
||||
enc := int32(pseudoEncExtendedMouseButtons)
|
||||
binary.BigEndian.PutUint32(rect[8:12], uint32(enc))
|
||||
|
||||
s.writeMu.Lock()
|
||||
defer s.writeMu.Unlock()
|
||||
if _, err := s.conn.Write(header); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := s.conn.Write(rect)
|
||||
return err
|
||||
}
|
||||
|
||||
// refreshCopyRectIndex does a full hash sweep of the just-swapped prevFrame.
|
||||
// Used after full-frame sends, where we don't have a per-tile dirty list to
|
||||
// drive an incremental update.
|
||||
func (s *session) refreshCopyRectIndex() {
|
||||
if s.copyRectDet == nil || s.prevFrame == nil {
|
||||
return
|
||||
}
|
||||
s.copyRectDet.rebuild(s.prevFrame, s.serverW, s.serverH)
|
||||
}
|
||||
|
||||
// updateCopyRectIndex incrementally updates the CopyRect detector's hash
|
||||
// tables for the tiles that just changed. On first use (or after resize)
|
||||
// updateDirty internally falls back to a full rebuild.
|
||||
func (s *session) updateCopyRectIndex(dirty [][4]int) {
|
||||
if s.copyRectDet == nil || s.prevFrame == nil {
|
||||
return
|
||||
}
|
||||
s.copyRectDet.updateDirty(s.prevFrame, s.serverW, s.serverH, dirty)
|
||||
}
|
||||
|
||||
// captureFrame returns a session-owned frame for this encode cycle.
|
||||
// Capturers that implement captureIntoer (Linux X11, macOS) write directly
|
||||
// into curFrame, saving a per-frame full-screen memcpy. Capturers that
|
||||
// don't (Windows DXGI) return their own buffer which we copy into curFrame
|
||||
// to keep the encoder's prevFrame stable across the next capture cycle.
|
||||
func (s *session) captureFrame() (*image.RGBA, error) {
|
||||
w, h := s.serverW, s.serverH
|
||||
if s.curFrame == nil || s.curFrame.Rect.Dx() != w || s.curFrame.Rect.Dy() != h {
|
||||
s.curFrame = image.NewRGBA(image.Rect(0, 0, w, h))
|
||||
}
|
||||
|
||||
if ci, ok := s.capturer.(captureIntoer); ok {
|
||||
if err := ci.CaptureInto(s.curFrame); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.curFrame, nil
|
||||
}
|
||||
|
||||
src, err := s.capturer.Capture()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.curFrame.Rect != src.Rect {
|
||||
s.curFrame = image.NewRGBA(src.Rect)
|
||||
}
|
||||
copy(s.curFrame.Pix, src.Pix)
|
||||
return s.curFrame, nil
|
||||
}
|
||||
|
||||
// promoteToBoundingBox replaces the rect list with a single rect covering
|
||||
// the bounding box of all inputs, provided the bbox is at least
|
||||
// bboxPromoteMinArea and the dirty pixels fill at least
|
||||
// bboxPromoteDensityPct of it. Returns the new rect list and true when the
|
||||
// promotion fires; otherwise returns nil, false and the caller keeps the
|
||||
// original list.
|
||||
func promoteToBoundingBox(rects [][4]int) ([][4]int, bool) {
|
||||
if len(rects) < 2 {
|
||||
return nil, false
|
||||
}
|
||||
x0, y0 := rects[0][0], rects[0][1]
|
||||
x1, y1 := x0+rects[0][2], y0+rects[0][3]
|
||||
dirty := 0
|
||||
for _, r := range rects {
|
||||
if r[0] < x0 {
|
||||
x0 = r[0]
|
||||
}
|
||||
if r[1] < y0 {
|
||||
y0 = r[1]
|
||||
}
|
||||
if r[0]+r[2] > x1 {
|
||||
x1 = r[0] + r[2]
|
||||
}
|
||||
if r[1]+r[3] > y1 {
|
||||
y1 = r[1] + r[3]
|
||||
}
|
||||
dirty += r[2] * r[3]
|
||||
}
|
||||
w, h := x1-x0, y1-y0
|
||||
bbox := w * h
|
||||
if bbox < bboxPromoteMinArea {
|
||||
return nil, false
|
||||
}
|
||||
if dirty*100 < bbox*bboxPromoteDensityPct {
|
||||
return nil, false
|
||||
}
|
||||
return [][4]int{{x0, y0, w, h}}, true
|
||||
}
|
||||
|
||||
// shouldPromoteToFullFrame returns true when the dirty rect set covers a
|
||||
// large enough fraction of the screen that a single full-frame zlib rect
|
||||
// beats per-tile encoding on both CPU time and wire bytes. The crossover
|
||||
// is measured via BenchmarkEncodeManyTilesVsFullFrame.
|
||||
func (s *session) shouldPromoteToFullFrame(rects [][4]int) bool {
|
||||
if s.serverW == 0 || s.serverH == 0 {
|
||||
return false
|
||||
}
|
||||
var dirty int
|
||||
for _, r := range rects {
|
||||
dirty += r[2] * r[3]
|
||||
}
|
||||
return dirty*fullFramePromoteDen > s.serverW*s.serverH*fullFramePromoteNum
|
||||
}
|
||||
|
||||
// swapPrevCur makes the just-encoded frame the new prevFrame (for the next
|
||||
// diff) and lets the old prevFrame buffer become the next curFrame. Avoids
|
||||
// an 8 MB copy per frame compared to the old savePrevFrame path.
|
||||
func (s *session) swapPrevCur() {
|
||||
s.prevFrame, s.curFrame = s.curFrame, s.prevFrame
|
||||
}
|
||||
|
||||
// sendEmptyUpdate sends a FramebufferUpdate with zero pixel rectangles.
|
||||
// When the cursor source reports a fresh sprite we still slip the Cursor
|
||||
// pseudo-rect into the same message so a shape change (e.g. hovering onto
|
||||
// a resize handle) reaches the client without waiting for a dirty frame.
|
||||
func (s *session) sendEmptyUpdate() error {
|
||||
cursorRect := s.pendingCursorRect()
|
||||
if cursorRect == nil {
|
||||
var buf [4]byte
|
||||
buf[0] = serverFramebufferUpdate
|
||||
return s.writeFramed(buf[:])
|
||||
}
|
||||
buf := make([]byte, 4+len(cursorRect))
|
||||
buf[0] = serverFramebufferUpdate
|
||||
binary.BigEndian.PutUint16(buf[2:4], 1)
|
||||
copy(buf[4:], cursorRect)
|
||||
return s.writeFramed(buf)
|
||||
}
|
||||
|
||||
func (s *session) sendFullUpdate(img *image.RGBA) error {
|
||||
w, h := s.serverW, s.serverH
|
||||
|
||||
s.encMu.RLock()
|
||||
pf := s.pf
|
||||
useTight := s.useTight
|
||||
tight := s.tight
|
||||
useZlib := s.useZlib
|
||||
zlib := s.zlib
|
||||
s.encMu.RUnlock()
|
||||
|
||||
cursorRect := s.pendingCursorRect()
|
||||
rectCount := uint16(1)
|
||||
if cursorRect != nil {
|
||||
rectCount++
|
||||
}
|
||||
|
||||
var rectBuf []byte
|
||||
switch {
|
||||
case useTight && tight != nil && pfIsTightCompatible(pf):
|
||||
rectBuf = encodeTightRect(img, pf, 0, 0, w, h, tight)
|
||||
case useZlib && zlib != nil:
|
||||
// encodeZlibRect bakes in its own FBU header; reuse it for the
|
||||
// single-rect path when there is no cursor to prepend.
|
||||
if cursorRect == nil {
|
||||
return s.writeFramed(encodeZlibRect(img, pf, 0, 0, w, h, zlib))
|
||||
}
|
||||
rectBuf = encodeZlibRect(img, pf, 0, 0, w, h, zlib)[4:]
|
||||
default:
|
||||
if cursorRect == nil {
|
||||
return s.writeFramed(encodeRawRect(img, pf, 0, 0, w, h))
|
||||
}
|
||||
rectBuf = encodeRawRect(img, pf, 0, 0, w, h)[4:]
|
||||
}
|
||||
|
||||
buf := make([]byte, 4+len(cursorRect)+len(rectBuf))
|
||||
buf[0] = serverFramebufferUpdate
|
||||
binary.BigEndian.PutUint16(buf[2:4], rectCount)
|
||||
off := 4
|
||||
off += copy(buf[off:], cursorRect)
|
||||
copy(buf[off:], rectBuf)
|
||||
return s.writeFramed(buf)
|
||||
}
|
||||
|
||||
func (s *session) writeFramed(buf []byte) error {
|
||||
s.writeMu.Lock()
|
||||
_, err := s.conn.Write(buf)
|
||||
s.writeMu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
// sendDirtyAndMoves writes one FramebufferUpdate combining CopyRect moves
|
||||
// (cheap, 16 bytes each) and pixel-encoded dirty rects. Moves come first so
|
||||
// their source tiles are read from the client's pre-update framebuffer state,
|
||||
// before any subsequent rect overwrites them.
|
||||
func (s *session) sendDirtyAndMoves(img *image.RGBA, moves []copyRectMove, rects [][4]int) error {
|
||||
cursorRect := s.pendingCursorRect()
|
||||
if len(moves) == 0 && len(rects) == 0 && cursorRect == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
total := len(moves) + len(rects)
|
||||
if cursorRect != nil {
|
||||
total++
|
||||
}
|
||||
header := make([]byte, 4)
|
||||
header[0] = serverFramebufferUpdate
|
||||
binary.BigEndian.PutUint16(header[2:4], uint16(total))
|
||||
|
||||
s.writeMu.Lock()
|
||||
defer s.writeMu.Unlock()
|
||||
|
||||
if _, err := s.conn.Write(header); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if cursorRect != nil {
|
||||
if _, err := s.conn.Write(cursorRect); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
ts := tileSize
|
||||
for _, m := range moves {
|
||||
body := encodeCopyRectBody(m.srcX, m.srcY, m.dstX, m.dstY, ts, ts)
|
||||
if _, err := s.conn.Write(body); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for _, r := range rects {
|
||||
x, y, w, h := r[0], r[1], r[2], r[3]
|
||||
rectBuf := s.encodeTile(img, x, y, w, h)
|
||||
if _, err := s.conn.Write(rectBuf); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// encodeTile produces the on-wire rect bytes for a single dirty tile. Tight
|
||||
// is the only non-Raw encoding we negotiate: uniform tiles collapse to its
|
||||
// Fill subencoding (~16 bytes), photo-like rects route to JPEG, and the
|
||||
// rest take the Basic+zlib path. Raw is the fallback when Tight is not
|
||||
// negotiated or the negotiated pixel format is incompatible with Tight's
|
||||
// mandatory 24-bit RGB TPIXEL encoding.
|
||||
//
|
||||
// Output omits the 4-byte FramebufferUpdate header; callers combine multiple
|
||||
// tiles into one message.
|
||||
func (s *session) encodeTile(img *image.RGBA, x, y, w, h int) []byte {
|
||||
s.encMu.RLock()
|
||||
pf := s.pf
|
||||
useHextile := s.useHextile
|
||||
useTight := s.useTight
|
||||
tight := s.tight
|
||||
useZlib := s.useZlib
|
||||
zlib := s.zlib
|
||||
s.encMu.RUnlock()
|
||||
|
||||
if useHextile {
|
||||
if pixel, uniform := tileIsUniform(img, x, y, w, h); uniform {
|
||||
r := byte(pixel)
|
||||
g := byte(pixel >> 8)
|
||||
b := byte(pixel >> 16)
|
||||
return encodeHextileSolidRect(r, g, b, pf, rect{x, y, w, h})
|
||||
}
|
||||
}
|
||||
if useTight && tight != nil && pfIsTightCompatible(pf) {
|
||||
return encodeTightRect(img, pf, x, y, w, h, tight)
|
||||
}
|
||||
if useZlib && zlib != nil {
|
||||
return encodeZlibRect(img, pf, x, y, w, h, zlib)[4:]
|
||||
}
|
||||
return encodeRawRect(img, pf, x, y, w, h)[4:]
|
||||
}
|
||||
|
||||
// drainRequests consumes any pending requests so the sender's close completes
|
||||
// cleanly after the encoder loop has decided to exit on error. Returns the
|
||||
// number of drained requests to defeat empty-block lints; callers ignore it.
|
||||
func drainRequests(ch chan fbRequest) int {
|
||||
var drained int
|
||||
for range ch {
|
||||
drained++
|
||||
}
|
||||
return drained
|
||||
}
|
||||
|
||||
// pfIsTightCompatible reports whether the negotiated client pixel format
|
||||
// satisfies Tight's TPIXEL constraint (RFB 7.7.6): the three RGB shifts form
|
||||
// a permutation of {0, 8, 16} so the colour values live in the low 24 bits.
|
||||
// bpp, endianness, and 8-bit channels are already enforced at SetPixelFormat
|
||||
// time. Any permutation works because Tight always emits a three-byte R, G,
|
||||
// B triple regardless of where the client stores each channel.
|
||||
func pfIsTightCompatible(pf clientPixelFormat) bool {
|
||||
shifts := uint32(1)<<pf.rShift | uint32(1)<<pf.gShift | uint32(1)<<pf.bShift
|
||||
return shifts == 1<<0|1<<8|1<<16
|
||||
}
|
||||
120
client/vnc/server/session_remote_cursor.go
Normal file
120
client/vnc/server/session_remote_cursor.go
Normal file
@@ -0,0 +1,120 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
"io"
|
||||
)
|
||||
|
||||
// handleShowRemoteCursor handles the NetBird-specific RFB message that
|
||||
// toggles "show remote cursor" mode. Wire format: 1-byte enable flag
|
||||
// (0/1) plus 6 padding bytes reserved for future arguments.
|
||||
func (s *session) handleShowRemoteCursor() error {
|
||||
var data [7]byte
|
||||
if _, err := io.ReadFull(s.conn, data[:]); err != nil {
|
||||
return fmt.Errorf("read showRemoteCursor: %w", err)
|
||||
}
|
||||
enable := data[0] != 0
|
||||
s.encMu.Lock()
|
||||
s.showRemoteCursor = enable
|
||||
s.encMu.Unlock()
|
||||
s.log.Debugf("show remote cursor: %v", enable)
|
||||
return nil
|
||||
}
|
||||
|
||||
// maybeCompositeCursor blends the current server cursor into img when the
|
||||
// client has enabled "show remote cursor" mode. Returns silently in every
|
||||
// error path: a failed compositing must not stop the regular encode flow.
|
||||
func (s *session) maybeCompositeCursor(img *image.RGBA) {
|
||||
s.encMu.RLock()
|
||||
enabled := s.showRemoteCursor
|
||||
s.encMu.RUnlock()
|
||||
if !enabled || img == nil {
|
||||
return
|
||||
}
|
||||
src, ok := s.capturer.(cursorSource)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
pos, ok := s.capturer.(cursorPositionSource)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
cursorImg, hotX, hotY, _, err := src.Cursor()
|
||||
if err != nil || cursorImg == nil {
|
||||
s.cursorWarnOnce.Do(func() {
|
||||
s.log.Warnf("remote cursor unavailable: %v", err)
|
||||
})
|
||||
return
|
||||
}
|
||||
posX, posY, err := pos.CursorPos()
|
||||
if err != nil {
|
||||
s.cursorWarnOnce.Do(func() {
|
||||
s.log.Warnf("remote cursor position unavailable: %v", err)
|
||||
})
|
||||
return
|
||||
}
|
||||
compositeCursor(img, cursorImg, posX-hotX, posY-hotY)
|
||||
}
|
||||
|
||||
// compositeCursor alpha-blends sprite onto frame at (dstX, dstY).
|
||||
// sprite is assumed to use premultiplied RGBA, which is what every
|
||||
// cursorSource implementation in this package produces (X11 XFixes and
|
||||
// macOS CG return premultiplied bytes natively; the Windows path
|
||||
// premultiplies during decodeColorCursor). Out-of-bounds destinations are
|
||||
// clipped.
|
||||
func compositeCursor(frame, sprite *image.RGBA, dstX, dstY int) {
|
||||
fw, fh := frame.Rect.Dx(), frame.Rect.Dy()
|
||||
sw, sh := sprite.Rect.Dx(), sprite.Rect.Dy()
|
||||
if sw == 0 || sh == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
x0, y0 := dstX, dstY
|
||||
x1, y1 := dstX+sw, dstY+sh
|
||||
if x0 < 0 {
|
||||
x0 = 0
|
||||
}
|
||||
if y0 < 0 {
|
||||
y0 = 0
|
||||
}
|
||||
if x1 > fw {
|
||||
x1 = fw
|
||||
}
|
||||
if y1 > fh {
|
||||
y1 = fh
|
||||
}
|
||||
if x0 >= x1 || y0 >= y1 {
|
||||
return
|
||||
}
|
||||
|
||||
fStride := frame.Stride
|
||||
sStride := sprite.Stride
|
||||
for y := y0; y < y1; y++ {
|
||||
sy := y - dstY
|
||||
fbRow := y * fStride
|
||||
sRow := sy * sStride
|
||||
for x := x0; x < x1; x++ {
|
||||
sx := x - dstX
|
||||
fbOff := fbRow + x*4
|
||||
sOff := sRow + sx*4
|
||||
a := uint32(sprite.Pix[sOff+3])
|
||||
if a == 0 {
|
||||
continue
|
||||
}
|
||||
if a == 255 {
|
||||
frame.Pix[fbOff+0] = sprite.Pix[sOff+0]
|
||||
frame.Pix[fbOff+1] = sprite.Pix[sOff+1]
|
||||
frame.Pix[fbOff+2] = sprite.Pix[sOff+2]
|
||||
continue
|
||||
}
|
||||
// Premultiplied compositing: dst = src + dst*(1-srcA).
|
||||
inv := 255 - a
|
||||
frame.Pix[fbOff+0] = sprite.Pix[sOff+0] + byte((uint32(frame.Pix[fbOff+0])*inv)/255)
|
||||
frame.Pix[fbOff+1] = sprite.Pix[sOff+1] + byte((uint32(frame.Pix[fbOff+1])*inv)/255)
|
||||
frame.Pix[fbOff+2] = sprite.Pix[sOff+2] + byte((uint32(frame.Pix[fbOff+2])*inv)/255)
|
||||
}
|
||||
}
|
||||
}
|
||||
80
client/vnc/server/shutdown_state.go
Normal file
80
client/vnc/server/shutdown_state.go
Normal file
@@ -0,0 +1,80 @@
|
||||
//go:build unix
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ShutdownState tracks VNC virtual session processes for crash recovery.
|
||||
// Persisted by the state manager; on restart, residual processes are killed.
|
||||
type ShutdownState struct {
|
||||
// Processes maps a description to its PID (e.g., "xvfb:50" -> 1234).
|
||||
Processes map[string]int `json:"processes,omitempty"`
|
||||
}
|
||||
|
||||
// Name returns the state name for the state manager.
|
||||
func (s *ShutdownState) Name() string {
|
||||
return "vnc_sessions_state"
|
||||
}
|
||||
|
||||
// Cleanup kills any residual VNC session processes left from a crash.
|
||||
func (s *ShutdownState) Cleanup() error {
|
||||
if len(s.Processes) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for desc, pid := range s.Processes {
|
||||
if pid <= 0 {
|
||||
continue
|
||||
}
|
||||
if !isOurProcess(pid, desc) {
|
||||
log.Debugf("cleanup:skipping PID %d (%s), not ours", pid, desc)
|
||||
continue
|
||||
}
|
||||
log.Infof("cleanup:killing residual process %d (%s)", pid, desc)
|
||||
// Kill the process group (negative PID) to get children too.
|
||||
if err := syscall.Kill(-pid, syscall.SIGTERM); err != nil {
|
||||
// Try individual process if group kill fails.
|
||||
if killErr := syscall.Kill(pid, syscall.SIGKILL); killErr != nil {
|
||||
log.Debugf("cleanup: kill pid %d (%s): group kill: %v, single kill: %v", pid, desc, err, killErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.Processes = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// isOurProcess verifies the PID still belongs to a VNC-related process
|
||||
// by checking /proc/<pid>/cmdline (Linux) or the process name.
|
||||
func isOurProcess(pid int, desc string) bool {
|
||||
// Check if the process exists at all.
|
||||
if err := syscall.Kill(pid, 0); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// On Linux, verify via /proc cmdline.
|
||||
cmdline, err := os.ReadFile(fmt.Sprintf("/proc/%d/cmdline", pid))
|
||||
if err != nil {
|
||||
log.Debugf("cleanup: cannot read /proc/%d/cmdline: %v, treating PID as foreign", pid, err)
|
||||
return false
|
||||
}
|
||||
|
||||
cmd := string(cmdline)
|
||||
// Match against expected process types.
|
||||
if strings.Contains(desc, "xvfb") || strings.Contains(desc, "xorg") {
|
||||
return strings.Contains(cmd, "Xvfb") || strings.Contains(cmd, "Xorg")
|
||||
}
|
||||
if strings.Contains(desc, "desktop") {
|
||||
return strings.Contains(cmd, "session") || strings.Contains(cmd, "plasma") ||
|
||||
strings.Contains(cmd, "gnome") || strings.Contains(cmd, "xfce") ||
|
||||
strings.Contains(cmd, "dbus-launch")
|
||||
}
|
||||
return false
|
||||
}
|
||||
53
client/vnc/server/stubs.go
Normal file
53
client/vnc/server/stubs.go
Normal file
@@ -0,0 +1,53 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
)
|
||||
|
||||
// StubCapturer is a placeholder for platforms without screen capture support.
|
||||
type StubCapturer struct{}
|
||||
|
||||
// Width returns 0 on unsupported platforms.
|
||||
func (c *StubCapturer) Width() int { return 0 }
|
||||
|
||||
// Height returns 0 on unsupported platforms.
|
||||
func (c *StubCapturer) Height() int { return 0 }
|
||||
|
||||
// Capture returns an error on unsupported platforms.
|
||||
func (c *StubCapturer) Capture() (*image.RGBA, error) {
|
||||
return nil, fmt.Errorf("screen capture not supported on this platform")
|
||||
}
|
||||
|
||||
// StubInputInjector is a placeholder for platforms without input injection support.
|
||||
type StubInputInjector struct{}
|
||||
|
||||
// InjectKey is a no-op on unsupported platforms.
|
||||
func (s *StubInputInjector) InjectKey(_ uint32, _ bool) {
|
||||
// no-op
|
||||
}
|
||||
|
||||
// InjectKeyScancode is a no-op on unsupported platforms.
|
||||
func (s *StubInputInjector) InjectKeyScancode(_ uint32, _ uint32, _ bool) {
|
||||
// no-op
|
||||
}
|
||||
|
||||
// InjectPointer is a no-op on unsupported platforms.
|
||||
func (s *StubInputInjector) InjectPointer(_ uint16, _, _, _, _ int) {
|
||||
// no-op
|
||||
}
|
||||
|
||||
// SetClipboard is a no-op on unsupported platforms.
|
||||
func (s *StubInputInjector) SetClipboard(_ string) {
|
||||
// no-op
|
||||
}
|
||||
|
||||
// GetClipboard returns empty on unsupported platforms.
|
||||
func (s *StubInputInjector) GetClipboard() string { return "" }
|
||||
|
||||
// TypeText is a no-op on unsupported platforms.
|
||||
func (s *StubInputInjector) TypeText(_ string) {
|
||||
// no-op
|
||||
}
|
||||
30
client/vnc/server/swizzle.go
Normal file
30
client/vnc/server/swizzle.go
Normal file
@@ -0,0 +1,30 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import "unsafe"
|
||||
|
||||
// swizzleBGRAtoRGBA swaps B and R channels in a BGRA pixel buffer and copies
|
||||
// into dst in-place (dst and src may alias). Operates on uint32 words: one
|
||||
// read-modify-write per pixel, which is meaningfully faster than the naive
|
||||
// three-byte-store per pixel for large buffers like framebuffers.
|
||||
//
|
||||
// The alpha byte is forced to 0xff so callers that capture from X11 GetImage
|
||||
// (where the X server leaves the pad byte as zero) still get an opaque image.
|
||||
func swizzleBGRAtoRGBA(dst, src []byte) {
|
||||
n := len(dst) / 4
|
||||
if len(src)/4 < n {
|
||||
n = len(src) / 4
|
||||
}
|
||||
if n == 0 {
|
||||
return
|
||||
}
|
||||
dp := unsafe.Slice((*uint32)(unsafe.Pointer(&dst[0])), n)
|
||||
sp := unsafe.Slice((*uint32)(unsafe.Pointer(&src[0])), n)
|
||||
for i := range n {
|
||||
p := sp[i]
|
||||
// p in memory: B, G, R, A -> as uint32 little-endian: 0xAARRGGBB
|
||||
// Want memory: R, G, B, 0xFF -> uint32 little-endian: 0xFFBBGGRR
|
||||
dp[i] = 0xFF000000 | (p & 0x0000FF00) | ((p & 0x00FF0000) >> 16) | ((p & 0x000000FF) << 16)
|
||||
}
|
||||
}
|
||||
111
client/vnc/server/tight_test.go
Normal file
111
client/vnc/server/tight_test.go
Normal file
@@ -0,0 +1,111 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"image"
|
||||
"image/jpeg"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func makeUniformImage(w, h int, r, g, b byte) *image.RGBA {
|
||||
img := image.NewRGBA(image.Rect(0, 0, w, h))
|
||||
for i := 0; i < len(img.Pix); i += 4 {
|
||||
img.Pix[i+0] = r
|
||||
img.Pix[i+1] = g
|
||||
img.Pix[i+2] = b
|
||||
img.Pix[i+3] = 0xff
|
||||
}
|
||||
return img
|
||||
}
|
||||
|
||||
func makeTwoColorImage(w, h int) *image.RGBA {
|
||||
img := makeUniformImage(w, h, 0x10, 0x20, 0x30)
|
||||
fg := [3]byte{0xa0, 0xb0, 0xc0}
|
||||
for y := 0; y < h; y++ {
|
||||
for x := w / 4; x < w/2; x++ {
|
||||
i := y*img.Stride + x*4
|
||||
img.Pix[i+0] = fg[0]
|
||||
img.Pix[i+1] = fg[1]
|
||||
img.Pix[i+2] = fg[2]
|
||||
}
|
||||
}
|
||||
return img
|
||||
}
|
||||
|
||||
func decodeTightLength(buf []byte) (n, consumed int) {
|
||||
b0 := buf[0]
|
||||
n = int(b0 & 0x7f)
|
||||
if b0&0x80 == 0 {
|
||||
return n, 1
|
||||
}
|
||||
b1 := buf[1]
|
||||
n |= int(b1&0x7f) << 7
|
||||
if b1&0x80 == 0 {
|
||||
return n, 2
|
||||
}
|
||||
b2 := buf[2]
|
||||
n |= int(b2) << 14
|
||||
return n, 3
|
||||
}
|
||||
|
||||
func TestEncodeTightFill(t *testing.T) {
|
||||
pf := defaultClientPixelFormat()
|
||||
img := makeUniformImage(64, 64, 0x12, 0x34, 0x56)
|
||||
tstate := newTightState()
|
||||
buf := encodeTightRect(img, pf, 0, 0, 64, 64, tstate)
|
||||
if len(buf) != 12+1+3 {
|
||||
t.Fatalf("fill rect should be 16 bytes, got %d", len(buf))
|
||||
}
|
||||
if buf[12] != tightFillSubenc {
|
||||
t.Fatalf("expected fill subenc, got 0x%02x", buf[12])
|
||||
}
|
||||
if buf[13] != 0x12 || buf[14] != 0x34 || buf[15] != 0x56 {
|
||||
t.Fatalf("wrong fill colour: %v", buf[13:16])
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeTightBasic(t *testing.T) {
|
||||
pf := defaultClientPixelFormat()
|
||||
img := makeTwoColorImage(64, 64)
|
||||
tstate := newTightState()
|
||||
buf := encodeTightRect(img, pf, 0, 0, 64, 64, tstate)
|
||||
if buf[12]&0xf0 != tightBasicFilter {
|
||||
t.Fatalf("expected basic+filter subenc, got 0x%02x", buf[12])
|
||||
}
|
||||
if buf[13] != tightFilterCopy {
|
||||
t.Fatalf("expected copy filter, got 0x%02x", buf[13])
|
||||
}
|
||||
// Length prefix and zlib stream follow.
|
||||
n, _ := decodeTightLength(buf[14:])
|
||||
if n == 0 {
|
||||
t.Fatalf("zero-length basic stream")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeTightJPEG(t *testing.T) {
|
||||
pf := defaultClientPixelFormat()
|
||||
img := makeBenchImage(128, 128, 7) // random → many colours
|
||||
tstate := newTightState()
|
||||
buf := encodeTightRect(img, pf, 0, 0, 128, 128, tstate)
|
||||
if buf[12] != tightJPEGSubenc {
|
||||
t.Fatalf("expected JPEG subenc, got 0x%02x", buf[12])
|
||||
}
|
||||
n, consumed := decodeTightLength(buf[13:])
|
||||
jpegBytes := buf[13+consumed : 13+consumed+n]
|
||||
if _, err := jpeg.Decode(bytes.NewReader(jpegBytes)); err != nil {
|
||||
t.Fatalf("emitted JPEG bytes do not decode: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSampledColorCount(t *testing.T) {
|
||||
uniform := makeUniformImage(64, 64, 0x10, 0x20, 0x30)
|
||||
if c := sampledColorCountInto(map[uint32]struct{}{}, uniform, 0, 0, 64, 64, 32); c != 1 {
|
||||
t.Fatalf("uniform should be 1 colour, got %d", c)
|
||||
}
|
||||
rnd := makeBenchImage(128, 128, 1)
|
||||
if c := sampledColorCountInto(map[uint32]struct{}{}, rnd, 0, 0, 128, 128, 16); c <= 16 {
|
||||
t.Fatalf("random image should exceed colour cap, got %d", c)
|
||||
}
|
||||
}
|
||||
736
client/vnc/server/virtual_x11.go
Normal file
736
client/vnc/server/virtual_x11.go
Normal file
@@ -0,0 +1,736 @@
|
||||
//go:build unix && !darwin && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// VirtualSession manages a virtual X11 display (Xvfb) with a desktop session
|
||||
// running as a target user. It implements ScreenCapturer and InputInjector by
|
||||
// delegating to an X11Capturer/X11InputInjector pointed at the virtual display.
|
||||
const (
|
||||
sessionIdleTimeout = 5 * time.Minute
|
||||
|
||||
defaultSessionWidth uint16 = 1280
|
||||
defaultSessionHeight uint16 = 800
|
||||
)
|
||||
|
||||
type VirtualSession struct {
|
||||
mu sync.Mutex
|
||||
display string
|
||||
user *user.User
|
||||
uid uint32
|
||||
gid uint32
|
||||
groups []uint32
|
||||
width uint16
|
||||
height uint16
|
||||
xvfb *exec.Cmd
|
||||
desktop *exec.Cmd
|
||||
poller *X11Poller
|
||||
injector *X11InputInjector
|
||||
log *log.Entry
|
||||
stopped bool
|
||||
clients int
|
||||
idleTimer *time.Timer
|
||||
onIdle func() // called when idle timeout fires or Xvfb dies
|
||||
}
|
||||
|
||||
// StartVirtualSession creates and starts a virtual X11 session for the given
|
||||
// user. Requires root privileges to create sessions as other users. width and
|
||||
// height request the virtual display geometry; 0 values fall back to the
|
||||
// defaults.
|
||||
func StartVirtualSession(username string, width, height uint16, logger *log.Entry) (*VirtualSession, error) {
|
||||
if os.Getuid() != 0 {
|
||||
return nil, fmt.Errorf("virtual sessions require root privileges")
|
||||
}
|
||||
|
||||
if _, err := exec.LookPath("Xvfb"); err != nil {
|
||||
if _, err := exec.LookPath("Xorg"); err != nil {
|
||||
return nil, fmt.Errorf("neither Xvfb nor Xorg found (install xvfb or xserver-xorg)")
|
||||
}
|
||||
if !hasDummyDriver() {
|
||||
return nil, fmt.Errorf("xvfb not found and xorg dummy driver not installed (install xvfb or xf86-video-dummy)")
|
||||
}
|
||||
}
|
||||
|
||||
u, err := user.Lookup(username)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("lookup user %s: %w", username, err)
|
||||
}
|
||||
|
||||
uid, err := strconv.ParseUint(u.Uid, 10, 32)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse uid: %w", err)
|
||||
}
|
||||
gid, err := strconv.ParseUint(u.Gid, 10, 32)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse gid: %w", err)
|
||||
}
|
||||
|
||||
groups, err := supplementaryGroups(u)
|
||||
if err != nil {
|
||||
logger.Debugf("supplementary groups for %s: %v", username, err)
|
||||
}
|
||||
|
||||
if width == 0 {
|
||||
width = defaultSessionWidth
|
||||
}
|
||||
if height == 0 {
|
||||
height = defaultSessionHeight
|
||||
}
|
||||
|
||||
vs := &VirtualSession{
|
||||
user: u,
|
||||
uid: uint32(uid),
|
||||
gid: uint32(gid),
|
||||
groups: groups,
|
||||
width: width,
|
||||
height: height,
|
||||
log: logger.WithField("vnc_user", username),
|
||||
}
|
||||
|
||||
if err := vs.start(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return vs, nil
|
||||
}
|
||||
|
||||
func (vs *VirtualSession) start() error {
|
||||
display, err := findFreeDisplay()
|
||||
if err != nil {
|
||||
return fmt.Errorf("find free display: %w", err)
|
||||
}
|
||||
vs.display = display
|
||||
|
||||
if err := vs.startXvfb(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
socketPath := fmt.Sprintf("%s/X%s", x11SocketDir, vs.display[1:])
|
||||
if err := waitForPath(socketPath, 5*time.Second); err != nil {
|
||||
vs.stopXvfb()
|
||||
return fmt.Errorf("wait for X11 socket %s: %w", socketPath, err)
|
||||
}
|
||||
|
||||
// Grant the target user access to the display via xhost.
|
||||
xhostCmd := exec.Command("xhost", "+SI:localuser:"+vs.user.Username)
|
||||
xhostCmd.Env = []string{envDisplay + "=" + vs.display}
|
||||
if out, err := xhostCmd.CombinedOutput(); err != nil {
|
||||
vs.log.Debugf("xhost: %s (%v)", strings.TrimSpace(string(out)), err)
|
||||
}
|
||||
|
||||
vs.poller = NewX11Poller(vs.display)
|
||||
|
||||
injector, err := NewX11InputInjector(vs.display)
|
||||
if err != nil {
|
||||
vs.stopXvfb()
|
||||
return fmt.Errorf("create X11 injector for %s: %w", vs.display, err)
|
||||
}
|
||||
vs.injector = injector
|
||||
|
||||
if err := vs.startDesktop(); err != nil {
|
||||
vs.injector.Close()
|
||||
vs.stopXvfb()
|
||||
return fmt.Errorf("start desktop: %w", err)
|
||||
}
|
||||
|
||||
vs.log.Infof("virtual session started: display=%s user=%s", vs.display, vs.user.Username)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClientConnect increments the client count and cancels any idle timer.
|
||||
func (vs *VirtualSession) ClientConnect() {
|
||||
vs.mu.Lock()
|
||||
defer vs.mu.Unlock()
|
||||
vs.clients++
|
||||
if vs.idleTimer != nil {
|
||||
vs.idleTimer.Stop()
|
||||
vs.idleTimer = nil
|
||||
}
|
||||
}
|
||||
|
||||
// ClientDisconnect decrements the client count. When the last client
|
||||
// disconnects, starts an idle timer that destroys the session.
|
||||
func (vs *VirtualSession) ClientDisconnect() {
|
||||
vs.mu.Lock()
|
||||
defer vs.mu.Unlock()
|
||||
vs.clients--
|
||||
if vs.clients <= 0 {
|
||||
vs.clients = 0
|
||||
vs.log.Infof("no VNC clients connected, session will be destroyed in %s", sessionIdleTimeout)
|
||||
vs.idleTimer = time.AfterFunc(sessionIdleTimeout, vs.idleExpired)
|
||||
}
|
||||
}
|
||||
|
||||
// idleExpired is called by the idle timer. It stops the session and
|
||||
// notifies the session manager via onIdle so it removes us from the map.
|
||||
// Bails out early if a client reconnected before the timer callback won
|
||||
// the race (Stop() doesn't cancel an already-firing AfterFunc, so the
|
||||
// state check has to happen here under vs.mu).
|
||||
func (vs *VirtualSession) idleExpired() {
|
||||
vs.mu.Lock()
|
||||
if vs.stopped || vs.clients > 0 {
|
||||
vs.mu.Unlock()
|
||||
return
|
||||
}
|
||||
vs.mu.Unlock()
|
||||
|
||||
vs.log.Info("idle timeout reached, destroying virtual session")
|
||||
vs.Stop()
|
||||
if vs.onIdle != nil {
|
||||
vs.onIdle()
|
||||
}
|
||||
}
|
||||
|
||||
// isAlive returns true if the session is running and its X server socket exists.
|
||||
func (vs *VirtualSession) isAlive() bool {
|
||||
vs.mu.Lock()
|
||||
stopped := vs.stopped
|
||||
display := vs.display
|
||||
vs.mu.Unlock()
|
||||
|
||||
if stopped {
|
||||
return false
|
||||
}
|
||||
// Verify the X socket still exists on disk.
|
||||
socketPath := fmt.Sprintf("%s/X%s", x11SocketDir, display[1:])
|
||||
if _, err := os.Stat(socketPath); err != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Capturer returns the screen capturer for this virtual session.
|
||||
func (vs *VirtualSession) Capturer() ScreenCapturer {
|
||||
return vs.poller
|
||||
}
|
||||
|
||||
// Injector returns the input injector for this virtual session.
|
||||
func (vs *VirtualSession) Injector() InputInjector {
|
||||
return vs.injector
|
||||
}
|
||||
|
||||
// Display returns the X11 display string (e.g., ":99").
|
||||
func (vs *VirtualSession) Display() string {
|
||||
return vs.display
|
||||
}
|
||||
|
||||
// Stop terminates the virtual session, killing the desktop and Xvfb.
|
||||
func (vs *VirtualSession) Stop() {
|
||||
vs.mu.Lock()
|
||||
defer vs.mu.Unlock()
|
||||
|
||||
if vs.stopped {
|
||||
return
|
||||
}
|
||||
vs.stopped = true
|
||||
|
||||
if vs.injector != nil {
|
||||
vs.injector.Close()
|
||||
}
|
||||
if vs.poller != nil {
|
||||
vs.poller.Close()
|
||||
vs.poller = nil
|
||||
}
|
||||
|
||||
vs.stopDesktop()
|
||||
vs.stopXvfb()
|
||||
|
||||
vs.log.Info("virtual session stopped")
|
||||
}
|
||||
|
||||
func (vs *VirtualSession) startXvfb() error {
|
||||
if _, err := exec.LookPath("Xvfb"); err == nil {
|
||||
return vs.startXvfbDirect()
|
||||
}
|
||||
return vs.startXorgDummy()
|
||||
}
|
||||
|
||||
func (vs *VirtualSession) startXvfbDirect() error {
|
||||
geom := fmt.Sprintf("%dx%dx24", vs.width, vs.height)
|
||||
vs.xvfb = exec.Command("Xvfb", vs.display,
|
||||
"-screen", "0", geom,
|
||||
"-nolisten", "tcp",
|
||||
)
|
||||
vs.xvfb.SysProcAttr = &syscall.SysProcAttr{Setsid: true, Pdeathsig: syscall.SIGTERM}
|
||||
|
||||
if err := vs.xvfb.Start(); err != nil {
|
||||
return fmt.Errorf("start Xvfb on %s: %w", vs.display, err)
|
||||
}
|
||||
vs.log.Infof("Xvfb started on %s (pid=%d)", vs.display, vs.xvfb.Process.Pid)
|
||||
|
||||
go vs.monitorXvfb()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// startXorgDummy starts Xorg with the dummy video driver as a fallback when
|
||||
// Xvfb is not installed. Most systems with a desktop have Xorg available.
|
||||
func (vs *VirtualSession) startXorgDummy() error {
|
||||
conf := fmt.Sprintf(`Section "Device"
|
||||
Identifier "dummy"
|
||||
Driver "dummy"
|
||||
VideoRam 256000
|
||||
EndSection
|
||||
Section "Screen"
|
||||
Identifier "screen"
|
||||
Device "dummy"
|
||||
DefaultDepth 24
|
||||
SubSection "Display"
|
||||
Depth 24
|
||||
Modes "%dx%d"
|
||||
EndSubSection
|
||||
EndSection
|
||||
`, vs.width, vs.height)
|
||||
f, err := os.CreateTemp("", fmt.Sprintf("nbvnc-dummy-%s-*.conf", vs.display[1:]))
|
||||
if err != nil {
|
||||
return fmt.Errorf("create Xorg dummy config: %w", err)
|
||||
}
|
||||
confPath := f.Name()
|
||||
if _, err := f.WriteString(conf); err != nil {
|
||||
f.Close()
|
||||
os.Remove(confPath)
|
||||
return fmt.Errorf("write Xorg dummy config: %w", err)
|
||||
}
|
||||
if err := f.Chmod(0600); err != nil {
|
||||
f.Close()
|
||||
os.Remove(confPath)
|
||||
return fmt.Errorf("chmod Xorg dummy config: %w", err)
|
||||
}
|
||||
if err := f.Close(); err != nil {
|
||||
os.Remove(confPath)
|
||||
return fmt.Errorf("close Xorg dummy config: %w", err)
|
||||
}
|
||||
|
||||
vs.xvfb = exec.Command("Xorg", vs.display,
|
||||
"-config", confPath,
|
||||
"-noreset",
|
||||
"-nolisten", "tcp",
|
||||
)
|
||||
vs.xvfb.SysProcAttr = &syscall.SysProcAttr{Setsid: true, Pdeathsig: syscall.SIGTERM}
|
||||
|
||||
if err := vs.xvfb.Start(); err != nil {
|
||||
os.Remove(confPath)
|
||||
return fmt.Errorf("start Xorg dummy on %s: %w", vs.display, err)
|
||||
}
|
||||
vs.log.Infof("Xorg (dummy driver) started on %s (pid=%d)", vs.display, vs.xvfb.Process.Pid)
|
||||
|
||||
go func() {
|
||||
vs.monitorXvfb()
|
||||
os.Remove(confPath)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// monitorXvfb waits for the Xvfb/Xorg process to exit. If it exits
|
||||
// unexpectedly (not via Stop), the session is marked as dead and the
|
||||
// onIdle callback fires so the session manager removes it from the map.
|
||||
// The next GetOrCreate call for this user will create a fresh session.
|
||||
func (vs *VirtualSession) monitorXvfb() {
|
||||
if err := vs.xvfb.Wait(); err != nil {
|
||||
vs.log.Debugf("X server exited: %v", err)
|
||||
}
|
||||
|
||||
vs.mu.Lock()
|
||||
alreadyStopped := vs.stopped
|
||||
if !alreadyStopped {
|
||||
vs.log.Warn("X server exited unexpectedly, marking session as dead")
|
||||
vs.stopped = true
|
||||
if vs.idleTimer != nil {
|
||||
vs.idleTimer.Stop()
|
||||
vs.idleTimer = nil
|
||||
}
|
||||
if vs.injector != nil {
|
||||
vs.injector.Close()
|
||||
}
|
||||
vs.stopDesktop()
|
||||
}
|
||||
onIdle := vs.onIdle
|
||||
vs.mu.Unlock()
|
||||
|
||||
if !alreadyStopped && onIdle != nil {
|
||||
onIdle()
|
||||
}
|
||||
}
|
||||
|
||||
func (vs *VirtualSession) stopXvfb() {
|
||||
if vs.xvfb == nil || vs.xvfb.Process == nil {
|
||||
return
|
||||
}
|
||||
if err := syscall.Kill(-vs.xvfb.Process.Pid, syscall.SIGTERM); err != nil {
|
||||
vs.log.Debugf("SIGTERM xvfb group: %v", err)
|
||||
}
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
if err := syscall.Kill(-vs.xvfb.Process.Pid, syscall.SIGKILL); err != nil {
|
||||
vs.log.Debugf("SIGKILL xvfb group: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (vs *VirtualSession) startDesktop() error {
|
||||
session := detectDesktopSession()
|
||||
|
||||
// Wrap the desktop command with dbus-launch to provide a session bus.
|
||||
// Without this, most desktop environments (XFCE, MATE, etc.) fail immediately.
|
||||
var args []string
|
||||
if _, err := exec.LookPath("dbus-launch"); err == nil {
|
||||
args = append([]string{"dbus-launch", "--exit-with-session"}, session...)
|
||||
} else {
|
||||
args = session
|
||||
}
|
||||
|
||||
vs.desktop = exec.Command(args[0], args[1:]...)
|
||||
vs.desktop.Dir = vs.user.HomeDir
|
||||
vs.desktop.Env = vs.buildUserEnv()
|
||||
vs.desktop.SysProcAttr = &syscall.SysProcAttr{
|
||||
Credential: &syscall.Credential{
|
||||
Uid: vs.uid,
|
||||
Gid: vs.gid,
|
||||
Groups: vs.groups,
|
||||
},
|
||||
Setsid: true,
|
||||
Pdeathsig: syscall.SIGTERM,
|
||||
}
|
||||
|
||||
if err := vs.desktop.Start(); err != nil {
|
||||
return fmt.Errorf("start desktop session (%v): %w", args, err)
|
||||
}
|
||||
vs.log.Infof("desktop session started: %v (pid=%d)", args, vs.desktop.Process.Pid)
|
||||
|
||||
go vs.monitorDesktop()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// monitorDesktop waits for the desktop-session process to exit. When the user
|
||||
// logs out of GNOME/KDE/XFCE/etc., the session process terminates while Xvfb
|
||||
// keeps running, leaving a blank root window. Tear the whole virtual session
|
||||
// down so the next connect starts fresh with a login.
|
||||
func (vs *VirtualSession) monitorDesktop() {
|
||||
if err := vs.desktop.Wait(); err != nil {
|
||||
vs.log.Debugf("desktop session exited: %v", err)
|
||||
}
|
||||
|
||||
vs.mu.Lock()
|
||||
alreadyStopped := vs.stopped
|
||||
if !alreadyStopped {
|
||||
vs.log.Info("desktop session exited (logout), tearing down virtual session")
|
||||
vs.stopped = true
|
||||
if vs.idleTimer != nil {
|
||||
vs.idleTimer.Stop()
|
||||
vs.idleTimer = nil
|
||||
}
|
||||
if vs.injector != nil {
|
||||
vs.injector.Close()
|
||||
}
|
||||
vs.stopXvfb()
|
||||
}
|
||||
onIdle := vs.onIdle
|
||||
vs.mu.Unlock()
|
||||
|
||||
if !alreadyStopped && onIdle != nil {
|
||||
onIdle()
|
||||
}
|
||||
}
|
||||
|
||||
func (vs *VirtualSession) stopDesktop() {
|
||||
if vs.desktop == nil || vs.desktop.Process == nil {
|
||||
return
|
||||
}
|
||||
if err := syscall.Kill(-vs.desktop.Process.Pid, syscall.SIGTERM); err != nil {
|
||||
vs.log.Debugf("SIGTERM desktop group: %v", err)
|
||||
}
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
if err := syscall.Kill(-vs.desktop.Process.Pid, syscall.SIGKILL); err != nil {
|
||||
vs.log.Debugf("SIGKILL desktop group: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (vs *VirtualSession) buildUserEnv() []string {
|
||||
return []string{
|
||||
envDisplay + "=" + vs.display,
|
||||
"HOME=" + vs.user.HomeDir,
|
||||
"USER=" + vs.user.Username,
|
||||
"LOGNAME=" + vs.user.Username,
|
||||
"SHELL=" + getUserShell(vs.user.Uid),
|
||||
"PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin",
|
||||
"XDG_RUNTIME_DIR=/run/user/" + vs.user.Uid,
|
||||
"DBUS_SESSION_BUS_ADDRESS=unix:path=/run/user/" + vs.user.Uid + "/bus",
|
||||
}
|
||||
}
|
||||
|
||||
// detectDesktopSession discovers available desktop sessions from the standard
|
||||
// /usr/share/xsessions/*.desktop files (FreeDesktop standard, used by all
|
||||
// display managers). Falls back to a hardcoded list if no .desktop files found.
|
||||
func detectDesktopSession() []string {
|
||||
// Scan xsessions directories (Linux: /usr/share, FreeBSD: /usr/local/share).
|
||||
for _, dir := range []string{"/usr/share/xsessions", "/usr/local/share/xsessions"} {
|
||||
if cmd := findXSession(dir); cmd != nil {
|
||||
return cmd
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: try common session commands directly.
|
||||
fallbacks := [][]string{
|
||||
{"startplasma-x11"},
|
||||
{"gnome-session"},
|
||||
{"xfce4-session"},
|
||||
{"mate-session"},
|
||||
{"cinnamon-session"},
|
||||
{"openbox-session"},
|
||||
{"xterm"},
|
||||
}
|
||||
for _, s := range fallbacks {
|
||||
if _, err := exec.LookPath(s[0]); err == nil {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return []string{"xterm"}
|
||||
}
|
||||
|
||||
// sessionPriority defines preference order for desktop environments.
|
||||
// Lower number = higher priority. Unknown sessions get 100.
|
||||
var sessionPriority = map[string]int{
|
||||
"plasma": 1, // KDE
|
||||
"gnome": 2,
|
||||
"xfce": 3,
|
||||
"mate": 4,
|
||||
"cinnamon": 5,
|
||||
"lxqt": 6,
|
||||
"lxde": 7,
|
||||
"budgie": 8,
|
||||
"openbox": 20,
|
||||
"fluxbox": 21,
|
||||
"i3": 22,
|
||||
"xinit": 50, // generic user session
|
||||
"lightdm": 50,
|
||||
"default": 50,
|
||||
}
|
||||
|
||||
func findXSession(dir string) []string {
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
candidates := collectSessionCandidates(dir, entries)
|
||||
if len(candidates) == 0 {
|
||||
return nil
|
||||
}
|
||||
best := bestSessionCandidate(candidates)
|
||||
parts := strings.Fields(best.cmd)
|
||||
if _, err := exec.LookPath(parts[0]); err != nil {
|
||||
return nil
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
type sessionCandidate struct {
|
||||
cmd string
|
||||
priority int
|
||||
}
|
||||
|
||||
func collectSessionCandidates(dir string, entries []os.DirEntry) []sessionCandidate {
|
||||
var out []sessionCandidate
|
||||
for _, e := range entries {
|
||||
c, ok := parseSessionEntry(dir, e)
|
||||
if ok {
|
||||
out = append(out, c)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// parseSessionEntry reads a single .desktop file and extracts its Exec
|
||||
// command plus the priority hint to be used when picking the best session.
|
||||
func parseSessionEntry(dir string, e os.DirEntry) (sessionCandidate, bool) {
|
||||
if !strings.HasSuffix(e.Name(), ".desktop") {
|
||||
return sessionCandidate{}, false
|
||||
}
|
||||
data, err := os.ReadFile(filepath.Join(dir, e.Name()))
|
||||
if err != nil {
|
||||
return sessionCandidate{}, false
|
||||
}
|
||||
execCmd := extractExecLine(data)
|
||||
if execCmd == "" || execCmd == "default" {
|
||||
return sessionCandidate{}, false
|
||||
}
|
||||
return sessionCandidate{cmd: execCmd, priority: sessionPriorityFor(e.Name(), execCmd)}, true
|
||||
}
|
||||
|
||||
func extractExecLine(data []byte) string {
|
||||
for _, line := range strings.Split(string(data), "\n") {
|
||||
if strings.HasPrefix(line, "Exec=") {
|
||||
return strings.TrimSpace(strings.TrimPrefix(line, "Exec="))
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func sessionPriorityFor(name, execCmd string) int {
|
||||
pri := 100
|
||||
lower := strings.ToLower(name + " " + execCmd)
|
||||
for keyword, p := range sessionPriority {
|
||||
if strings.Contains(lower, keyword) && p < pri {
|
||||
pri = p
|
||||
}
|
||||
}
|
||||
return pri
|
||||
}
|
||||
|
||||
func bestSessionCandidate(candidates []sessionCandidate) sessionCandidate {
|
||||
best := candidates[0]
|
||||
for _, c := range candidates[1:] {
|
||||
if c.priority < best.priority {
|
||||
best = c
|
||||
}
|
||||
}
|
||||
return best
|
||||
}
|
||||
|
||||
// findFreeDisplay scans for an unused X11 display number.
|
||||
func findFreeDisplay() (string, error) {
|
||||
for n := 50; n < 200; n++ {
|
||||
lockFile := fmt.Sprintf("/tmp/.X%d-lock", n)
|
||||
socketFile := fmt.Sprintf("%s/X%d", x11SocketDir, n)
|
||||
if _, err := os.Stat(lockFile); err == nil {
|
||||
continue
|
||||
}
|
||||
if _, err := os.Stat(socketFile); err == nil {
|
||||
continue
|
||||
}
|
||||
return fmt.Sprintf(":%d", n), nil
|
||||
}
|
||||
return "", fmt.Errorf("no free X11 display found (checked :50-:199)")
|
||||
}
|
||||
|
||||
// waitForPath polls until a filesystem path exists or the timeout expires.
|
||||
func waitForPath(path string, timeout time.Duration) error {
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
return nil
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
return fmt.Errorf("timeout waiting for %s", path)
|
||||
}
|
||||
|
||||
// getUserShell returns the login shell for the given UID.
|
||||
func getUserShell(uid string) string {
|
||||
data, err := os.ReadFile("/etc/passwd")
|
||||
if err != nil {
|
||||
return "/bin/sh"
|
||||
}
|
||||
for _, line := range strings.Split(string(data), "\n") {
|
||||
fields := strings.Split(line, ":")
|
||||
if len(fields) >= 7 && fields[2] == uid {
|
||||
return fields[6]
|
||||
}
|
||||
}
|
||||
return "/bin/sh"
|
||||
}
|
||||
|
||||
// supplementaryGroups returns the supplementary group IDs for a user.
|
||||
func supplementaryGroups(u *user.User) ([]uint32, error) {
|
||||
gids, err := u.GroupIds()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var groups []uint32
|
||||
for _, g := range gids {
|
||||
id, err := strconv.ParseUint(g, 10, 32)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
groups = append(groups, uint32(id))
|
||||
}
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
// sessionManager tracks active virtual sessions by username.
|
||||
type sessionManager struct {
|
||||
mu sync.Mutex
|
||||
sessions map[string]*VirtualSession
|
||||
log *log.Entry
|
||||
}
|
||||
|
||||
func newSessionManager(logger *log.Entry) *sessionManager {
|
||||
return &sessionManager{
|
||||
sessions: make(map[string]*VirtualSession),
|
||||
log: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// GetOrCreate returns an existing virtual session or creates a new one with
|
||||
// the requested geometry. If a previous session for this user is alive it is
|
||||
// reused regardless of the requested geometry; the first caller's size wins
|
||||
// until the session idles out. If a previous session is stopped or its X
|
||||
// server died, it is replaced.
|
||||
func (sm *sessionManager) GetOrCreate(username string, width, height uint16) (vncSession, error) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
if vs, ok := sm.sessions[username]; ok {
|
||||
if vs.isAlive() {
|
||||
return vs, nil
|
||||
}
|
||||
sm.log.Infof("replacing dead virtual session for %s", username)
|
||||
vs.Stop()
|
||||
delete(sm.sessions, username)
|
||||
}
|
||||
|
||||
vs, err := StartVirtualSession(username, width, height, sm.log)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
vs.onIdle = func() {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
if cur, ok := sm.sessions[username]; ok && cur == vs {
|
||||
delete(sm.sessions, username)
|
||||
sm.log.Infof("removed idle virtual session for %s", username)
|
||||
}
|
||||
}
|
||||
sm.sessions[username] = vs
|
||||
return vs, nil
|
||||
}
|
||||
|
||||
// hasDummyDriver checks common paths for the Xorg dummy video driver.
|
||||
func hasDummyDriver() bool {
|
||||
paths := []string{
|
||||
"/usr/lib/xorg/modules/drivers/dummy_drv.so", // Debian/Ubuntu
|
||||
"/usr/lib64/xorg/modules/drivers/dummy_drv.so", // RHEL/Fedora
|
||||
"/usr/local/lib/xorg/modules/drivers/dummy_drv.so", // FreeBSD
|
||||
"/usr/lib/x86_64-linux-gnu/xorg/modules/drivers/dummy_drv.so", // Debian multiarch
|
||||
}
|
||||
for _, p := range paths {
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// StopAll terminates all active virtual sessions.
|
||||
func (sm *sessionManager) StopAll() {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
for username, vs := range sm.sessions {
|
||||
vs.Stop()
|
||||
delete(sm.sessions, username)
|
||||
sm.log.Infof("stopped virtual session for %s", username)
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,7 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
@@ -21,6 +22,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/http"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/rdp"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/ssh"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/vnc"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -38,6 +40,7 @@ const (
|
||||
|
||||
func main() {
|
||||
js.Global().Set("NetBirdClient", js.FuncOf(netBirdClientConstructor))
|
||||
js.Global().Set("netbirdGenerateVNCSessionKey", createGenerateVNCSessionKeyMethod())
|
||||
|
||||
select {}
|
||||
}
|
||||
@@ -387,6 +390,156 @@ func createRDPProxyMethod(client *netbird.Client) js.Func {
|
||||
})
|
||||
}
|
||||
|
||||
// createGenerateVNCSessionKeyMethod returns a JS func that mints a fresh
|
||||
// X25519 keypair, stashes the private half inside wasm under a random
|
||||
// session id, and returns { publicKey, sessionId } to JS. The private
|
||||
// key never leaves the wasm heap.
|
||||
func createGenerateVNCSessionKeyMethod() js.Func {
|
||||
return js.FuncOf(func(_ js.Value, _ []js.Value) any {
|
||||
id, pub, err := vnc.NewSessionKey()
|
||||
if err != nil {
|
||||
return js.ValueOf(err.Error())
|
||||
}
|
||||
out := js.Global().Get("Object").New()
|
||||
out.Set("sessionId", id)
|
||||
out.Set("publicKey", base64.StdEncoding.EncodeToString(pub))
|
||||
return out
|
||||
})
|
||||
}
|
||||
|
||||
// createVNCProxyMethod creates the VNC proxy method for raw TCP-over-WebSocket bridging.
|
||||
// JS signature: createVNCProxy(hostname, port, mode?, username?, keySessionID?, sessionID?, width?, height?, peerPublicKey?)
|
||||
// mode: "attach" (default) or "session"
|
||||
// username: required when mode is "session"
|
||||
// keySessionID: handle for the wasm-resident session keypair minted by netbirdGenerateVNCSessionKey
|
||||
// sessionID: Windows session ID (0 = console/auto)
|
||||
// width/height: requested viewport size for session mode (0 = server default)
|
||||
// peerPublicKey: base64 X25519 static pubkey of the destination peer (required for auth)
|
||||
func createVNCProxyMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
params, err := parseVNCProxyArgs(args)
|
||||
if err != nil {
|
||||
if params.rejectViaPromise {
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
reject.Invoke(js.ValueOf(err.Error()))
|
||||
})
|
||||
}
|
||||
return js.ValueOf(err.Error())
|
||||
}
|
||||
proxy := vnc.NewVNCProxy(client)
|
||||
return proxy.CreateProxy(vnc.ProxyRequest{
|
||||
Hostname: params.hostname,
|
||||
Port: params.port,
|
||||
Mode: params.mode,
|
||||
Username: params.username,
|
||||
SessionID: params.sessionID,
|
||||
Width: params.width,
|
||||
Height: params.height,
|
||||
PeerPublicKey: params.peerPublicKey,
|
||||
KeySessionID: params.keySessionID,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
type vncProxyParams struct {
|
||||
hostname string
|
||||
port string
|
||||
mode string
|
||||
username string
|
||||
keySessionID string
|
||||
sessionID uint32
|
||||
width uint16
|
||||
height uint16
|
||||
peerPublicKey string
|
||||
rejectViaPromise bool
|
||||
}
|
||||
|
||||
// parseVNCProxyArgs validates JS args for createVNCProxyMethod and returns
|
||||
// the parsed params plus the first validation error (nil on success).
|
||||
// vncProxyParams.rejectViaPromise tells the caller which JS-side response
|
||||
// path to use for the returned error.
|
||||
func parseVNCProxyArgs(args []js.Value) (vncProxyParams, error) {
|
||||
var p vncProxyParams
|
||||
if err := parseVNCProxyRequiredArgs(args, &p); err != nil {
|
||||
return p, err
|
||||
}
|
||||
if err := parseVNCProxyOptionalStrings(args, &p); err != nil {
|
||||
return p, err
|
||||
}
|
||||
if err := parseVNCProxyOptionalNumbers(args, &p); err != nil {
|
||||
return p, err
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func parseVNCProxyRequiredArgs(args []js.Value, p *vncProxyParams) error {
|
||||
if len(args) < 2 {
|
||||
return fmt.Errorf("hostname and port required")
|
||||
}
|
||||
if args[0].Type() != js.TypeString {
|
||||
p.rejectViaPromise = true
|
||||
return fmt.Errorf("hostname parameter must be a string")
|
||||
}
|
||||
if args[1].Type() != js.TypeString {
|
||||
p.rejectViaPromise = true
|
||||
return fmt.Errorf("port parameter must be a string")
|
||||
}
|
||||
p.hostname = args[0].String()
|
||||
p.port = args[1].String()
|
||||
p.mode = "attach"
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseVNCProxyOptionalStrings(args []js.Value, p *vncProxyParams) error {
|
||||
if len(args) > 2 && args[2].Type() == js.TypeString {
|
||||
p.mode = args[2].String()
|
||||
}
|
||||
if p.mode != "attach" && p.mode != "session" {
|
||||
p.rejectViaPromise = true
|
||||
return fmt.Errorf("invalid mode %q: expected \"attach\" or \"session\"", p.mode)
|
||||
}
|
||||
if len(args) > 3 && args[3].Type() == js.TypeString {
|
||||
p.username = args[3].String()
|
||||
}
|
||||
if len(args) > 4 && args[4].Type() == js.TypeString {
|
||||
p.keySessionID = args[4].String()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseVNCProxyOptionalNumbers(args []js.Value, p *vncProxyParams) error {
|
||||
if len(args) > 5 && args[5].Type() == js.TypeNumber {
|
||||
v := args[5].Int()
|
||||
if v < 0 || v > 0xFFFFFFFF {
|
||||
p.rejectViaPromise = true
|
||||
return fmt.Errorf("invalid sessionID %d: must be 0..0xFFFFFFFF", v)
|
||||
}
|
||||
p.sessionID = uint32(v)
|
||||
}
|
||||
// width=0 / height=0 mean "use server default"; reject only out-of-range
|
||||
// non-zero values so attach mode (which omits width/height) still works.
|
||||
if len(args) > 6 && args[6].Type() == js.TypeNumber {
|
||||
v := args[6].Int()
|
||||
if v < 0 || v > 0xFFFF {
|
||||
p.rejectViaPromise = true
|
||||
return fmt.Errorf("invalid width %d: must be 0..65535", v)
|
||||
}
|
||||
p.width = uint16(v)
|
||||
}
|
||||
if len(args) > 7 && args[7].Type() == js.TypeNumber {
|
||||
v := args[7].Int()
|
||||
if v < 0 || v > 0xFFFF {
|
||||
p.rejectViaPromise = true
|
||||
return fmt.Errorf("invalid height %d: must be 0..65535", v)
|
||||
}
|
||||
p.height = uint16(v)
|
||||
}
|
||||
if len(args) > 8 && args[8].Type() == js.TypeString {
|
||||
p.peerPublicKey = args[8].String()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getStatusOverview is a helper to get the status overview
|
||||
func getStatusOverview(client *netbird.Client) (nbstatus.OutputOverview, error) {
|
||||
fullStatus, err := client.Status()
|
||||
@@ -563,10 +716,10 @@ func createStartCaptureMethod(client *netbird.Client) js.Func {
|
||||
//
|
||||
// Usage from browser devtools console:
|
||||
//
|
||||
// await client.capture() // capture all packets
|
||||
// await client.capture("tcp") // capture with filter
|
||||
// await client.capture({filter: "host 10.0.0.1", verbose: true})
|
||||
// client.stopCapture() // stop and print stats
|
||||
// await netbird.capture() // capture all packets
|
||||
// await netbird.capture("tcp") // capture with filter
|
||||
// await netbird.capture({filter: "host 10.0.0.1", verbose: true})
|
||||
// netbird.stopCapture() // stop and print stats
|
||||
func captureMethods(client *netbird.Client) (startFn, stopFn js.Func) {
|
||||
var mu sync.Mutex
|
||||
var active *wasmcapture.Handle
|
||||
@@ -594,7 +747,7 @@ func captureMethods(client *netbird.Client) (startFn, stopFn js.Func) {
|
||||
active = h
|
||||
|
||||
console := js.Global().Get("console")
|
||||
console.Call("log", "[capture] started, call client.stopCapture() to stop")
|
||||
console.Call("log", "[capture] started, call netbird.stopCapture() to stop")
|
||||
resolve.Invoke(js.Undefined())
|
||||
})
|
||||
})
|
||||
@@ -677,6 +830,7 @@ func createClientObject(client *netbird.Client) js.Value {
|
||||
obj["createSSHConnection"] = createSSHMethod(client)
|
||||
obj["proxyRequest"] = createProxyRequestMethod(client)
|
||||
obj["createRDPProxy"] = createRDPProxyMethod(client)
|
||||
obj["createVNCProxy"] = createVNCProxyMethod(client)
|
||||
obj["status"] = createStatusMethod(client)
|
||||
obj["statusSummary"] = createStatusSummaryMethod(client)
|
||||
obj["statusDetail"] = createStatusDetailMethod(client)
|
||||
|
||||
586
client/wasm/internal/vnc/proxy.go
Normal file
586
client/wasm/internal/vnc/proxy.go
Normal file
@@ -0,0 +1,586 @@
|
||||
//go:build js
|
||||
|
||||
package vnc
|
||||
|
||||
import (
|
||||
"context"
|
||||
crand "crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall/js"
|
||||
"time"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var cryptoRandRead = crand.Read
|
||||
|
||||
// vncIdentityMagic mirrors the server side in client/vnc/server/server.go.
|
||||
var vncIdentityMagic = []byte("NBV3")
|
||||
|
||||
// Noise_IK_25519_ChaChaPoly_SHA256 message sizes (with empty payloads).
|
||||
const (
|
||||
noiseInitiatorMsgLen = 96
|
||||
noiseResponderMsgLen = 48
|
||||
)
|
||||
|
||||
var vncNoiseSuite = noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
|
||||
|
||||
// sessionKeyStore retains per-session X25519 keypairs so the JS layer
|
||||
// only sees an opaque session id + the public key; the private key never
|
||||
// leaves wasm.
|
||||
var sessionKeyStore = struct {
|
||||
mu sync.Mutex
|
||||
keys map[string]noise.DHKey
|
||||
}{keys: map[string]noise.DHKey{}}
|
||||
|
||||
// NewSessionKey mints an X25519 keypair, stores the private half under a
|
||||
// fresh random session id, and returns (id, pubkey).
|
||||
func NewSessionKey() (string, []byte, error) {
|
||||
kp, err := noise.DH25519.GenerateKeypair(nil)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("generate keypair: %w", err)
|
||||
}
|
||||
idBytes := make([]byte, 16)
|
||||
if _, err := cryptoRandRead(idBytes); err != nil {
|
||||
return "", nil, fmt.Errorf("session id randomness: %w", err)
|
||||
}
|
||||
id := base64.RawURLEncoding.EncodeToString(idBytes)
|
||||
sessionKeyStore.mu.Lock()
|
||||
sessionKeyStore.keys[id] = kp
|
||||
sessionKeyStore.mu.Unlock()
|
||||
return id, kp.Public, nil
|
||||
}
|
||||
|
||||
// consumeSessionKey atomically retrieves and removes the keypair for id.
|
||||
// A session handle is single-use; combining lookup and delete under one
|
||||
// critical section prevents concurrent callers from observing the same key.
|
||||
func consumeSessionKey(id string) (noise.DHKey, bool) {
|
||||
sessionKeyStore.mu.Lock()
|
||||
defer sessionKeyStore.mu.Unlock()
|
||||
kp, ok := sessionKeyStore.keys[id]
|
||||
if ok {
|
||||
delete(sessionKeyStore.keys, id)
|
||||
}
|
||||
return kp, ok
|
||||
}
|
||||
|
||||
const (
|
||||
vncProxyHost = "vnc.proxy.local"
|
||||
vncProxyScheme = "ws"
|
||||
vncDialTimeout = 15 * time.Second
|
||||
|
||||
// Connection modes matching server/server.go constants.
|
||||
modeAttach byte = 0
|
||||
modeSession byte = 1
|
||||
|
||||
// WebSocket close codes the dashboard branches on. Codes 1000-1015
|
||||
// are reserved by RFC 6455; 4000-4999 are application-defined.
|
||||
wsCodeNormal = 1000
|
||||
wsCodeAbnormal = 1006
|
||||
wsCodeDialTimeout = 4001
|
||||
wsCodeDialFailure = 4002
|
||||
wsCodeSessionSetup = 4003
|
||||
wsCodeTransport = 4004
|
||||
)
|
||||
|
||||
// VNCProxy bridges WebSocket connections from noVNC in the browser
|
||||
// to TCP VNC server connections through the NetBird tunnel.
|
||||
type vncNBClient interface {
|
||||
Dial(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
||||
type VNCProxy struct {
|
||||
nbClient vncNBClient
|
||||
activeConnections map[string]*vncConnection
|
||||
destinations map[string]vncDestination
|
||||
// pendingHandlers holds the js.Func for handleVNCWebSocket_<id> between
|
||||
// CreateProxy and handleWebSocketConnection so we can move it onto the
|
||||
// vncConnection for later release.
|
||||
pendingHandlers map[string]js.Func
|
||||
mu sync.Mutex
|
||||
nextID atomic.Uint64
|
||||
}
|
||||
|
||||
type vncDestination struct {
|
||||
address string
|
||||
mode byte
|
||||
username string
|
||||
sessionPriv []byte
|
||||
sessionPub []byte
|
||||
sessionID uint32
|
||||
width uint16
|
||||
height uint16
|
||||
peerPubKey []byte
|
||||
}
|
||||
|
||||
type vncConnection struct {
|
||||
id string
|
||||
destination vncDestination
|
||||
mu sync.Mutex
|
||||
vncConn net.Conn
|
||||
wsHandlers js.Value
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
// Go-side callbacks exposed to JS. js.FuncOf pins the Go closure in a
|
||||
// global handle map and MUST be released, otherwise every connection
|
||||
// leaks the Go memory the closure captures.
|
||||
wsHandlerFn js.Func
|
||||
onMessageFn js.Func
|
||||
onCloseFn js.Func
|
||||
}
|
||||
|
||||
// NewVNCProxy creates a new VNC proxy.
|
||||
func NewVNCProxy(client vncNBClient) *VNCProxy {
|
||||
return &VNCProxy{
|
||||
nbClient: client,
|
||||
activeConnections: make(map[string]*vncConnection),
|
||||
}
|
||||
}
|
||||
|
||||
// ProxyRequest bundles the per-call parameters for CreateProxy so the JS
|
||||
// boundary doesn't drown callers in a wide positional argument list.
|
||||
type ProxyRequest struct {
|
||||
Hostname string
|
||||
Port string
|
||||
Mode string
|
||||
Username string
|
||||
SessionID uint32
|
||||
Width uint16
|
||||
Height uint16
|
||||
// PeerPublicKey is the destination peer's base64 X25519 public key,
|
||||
// used as the responder static in the Noise_IK handshake.
|
||||
PeerPublicKey string
|
||||
// KeySessionID is the handle returned by generateVNCSessionKey. The
|
||||
// matching private key is looked up inside wasm and never crosses
|
||||
// the JS boundary.
|
||||
KeySessionID string
|
||||
}
|
||||
|
||||
// CreateProxy creates a new proxy endpoint for the given VNC destination.
|
||||
// req.Mode is "attach" (capture current display) or "session" (virtual session).
|
||||
// req.Username is required for session mode. req.Width/Height request the
|
||||
// virtual display geometry for session mode; 0 means use the server default.
|
||||
// Returns a JS Promise that resolves to the WebSocket proxy URL.
|
||||
func (p *VNCProxy) CreateProxy(req ProxyRequest) js.Value {
|
||||
hostname, port, mode, username := req.Hostname, req.Port, req.Mode, req.Username
|
||||
sessionID, width, height := req.SessionID, req.Width, req.Height
|
||||
address := net.JoinHostPort(hostname, port)
|
||||
|
||||
var m byte
|
||||
if mode == "session" {
|
||||
m = modeSession
|
||||
}
|
||||
|
||||
dest := vncDestination{
|
||||
address: address,
|
||||
mode: m,
|
||||
username: username,
|
||||
sessionID: sessionID,
|
||||
width: width,
|
||||
height: height,
|
||||
}
|
||||
if req.KeySessionID != "" {
|
||||
kp, ok := consumeSessionKey(req.KeySessionID)
|
||||
if !ok {
|
||||
return rejectedPromise("unknown VNC session id")
|
||||
}
|
||||
dest.sessionPriv = kp.Private
|
||||
dest.sessionPub = kp.Public
|
||||
pub, err := decodePeerPubKey(req.PeerPublicKey)
|
||||
if err != nil {
|
||||
return rejectedPromise(fmt.Sprintf("invalid peer public key: %v", err))
|
||||
}
|
||||
dest.peerPubKey = pub
|
||||
}
|
||||
return p.newProxyPromise(address, mode, username, dest)
|
||||
}
|
||||
|
||||
// decodePeerPubKey parses a base64-encoded 32-byte X25519 public key.
|
||||
func decodePeerPubKey(b64 string) ([]byte, error) {
|
||||
if b64 == "" {
|
||||
return nil, errors.New("peer public key missing")
|
||||
}
|
||||
raw, err := base64.StdEncoding.DecodeString(b64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("base64 decode: %w", err)
|
||||
}
|
||||
if len(raw) != 32 {
|
||||
return nil, fmt.Errorf("expected 32 bytes, got %d", len(raw))
|
||||
}
|
||||
return raw, nil
|
||||
}
|
||||
|
||||
// rejectedPromise returns a resolved Promise carrying msg as an error
|
||||
// string, mirroring how CreateProxy reports earlier validation failures.
|
||||
func rejectedPromise(msg string) js.Value {
|
||||
promise := js.Global().Get("Promise")
|
||||
return promise.Call("resolve", js.ValueOf(msg))
|
||||
}
|
||||
|
||||
// newProxyPromise wraps the JS Promise creation + executor lifecycle so
|
||||
// CreateProxy stays a thin parameter-bundling entrypoint.
|
||||
func (p *VNCProxy) newProxyPromise(address, mode, username string, dest vncDestination) js.Value {
|
||||
|
||||
var executor js.Func
|
||||
executor = js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
resolve := args[0]
|
||||
|
||||
go func() {
|
||||
defer executor.Release()
|
||||
|
||||
proxyID := fmt.Sprintf("vnc_proxy_%d", p.nextID.Add(1))
|
||||
|
||||
p.mu.Lock()
|
||||
if p.destinations == nil {
|
||||
p.destinations = make(map[string]vncDestination)
|
||||
}
|
||||
p.destinations[proxyID] = dest
|
||||
p.mu.Unlock()
|
||||
|
||||
proxyURL := fmt.Sprintf("%s://%s/%s", vncProxyScheme, vncProxyHost, proxyID)
|
||||
|
||||
handlerFn := js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
if len(args) < 1 {
|
||||
return js.ValueOf("error: requires WebSocket argument")
|
||||
}
|
||||
p.handleWebSocketConnection(args[0], proxyID)
|
||||
return nil
|
||||
})
|
||||
p.mu.Lock()
|
||||
if p.pendingHandlers == nil {
|
||||
p.pendingHandlers = make(map[string]js.Func)
|
||||
}
|
||||
p.pendingHandlers[proxyID] = handlerFn
|
||||
p.mu.Unlock()
|
||||
js.Global().Set(fmt.Sprintf("handleVNCWebSocket_%s", proxyID), handlerFn)
|
||||
|
||||
log.Infof("created VNC proxy: %s -> %s (mode=%s, user=%s)", proxyURL, address, mode, username)
|
||||
resolve.Invoke(proxyURL)
|
||||
}()
|
||||
|
||||
return nil
|
||||
})
|
||||
return js.Global().Get("Promise").New(executor)
|
||||
}
|
||||
|
||||
func (p *VNCProxy) handleWebSocketConnection(ws js.Value, proxyID string) {
|
||||
p.mu.Lock()
|
||||
dest, ok := p.destinations[proxyID]
|
||||
handlerFn := p.pendingHandlers[proxyID]
|
||||
delete(p.pendingHandlers, proxyID)
|
||||
p.mu.Unlock()
|
||||
|
||||
if !ok {
|
||||
log.Errorf("no destination for VNC proxy %s", proxyID)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
conn := &vncConnection{
|
||||
id: proxyID,
|
||||
destination: dest,
|
||||
wsHandlers: ws,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
wsHandlerFn: handlerFn,
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
p.activeConnections[proxyID] = conn
|
||||
p.mu.Unlock()
|
||||
|
||||
p.setupWebSocketHandlers(ws, conn)
|
||||
go p.connectToVNC(conn)
|
||||
|
||||
log.Infof("VNC proxy WebSocket connection established for %s", proxyID)
|
||||
}
|
||||
|
||||
func (p *VNCProxy) setupWebSocketHandlers(ws js.Value, conn *vncConnection) {
|
||||
conn.onMessageFn = js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
if len(args) < 1 {
|
||||
return nil
|
||||
}
|
||||
data := args[0]
|
||||
go p.handleWebSocketMessage(conn, data)
|
||||
return nil
|
||||
})
|
||||
ws.Set("onGoMessage", conn.onMessageFn)
|
||||
|
||||
conn.onCloseFn = js.FuncOf(func(_ js.Value, _ []js.Value) any {
|
||||
log.Debug("VNC WebSocket closed by JavaScript")
|
||||
conn.cancel()
|
||||
return nil
|
||||
})
|
||||
ws.Set("onGoClose", conn.onCloseFn)
|
||||
}
|
||||
|
||||
func (p *VNCProxy) handleWebSocketMessage(conn *vncConnection, data js.Value) {
|
||||
if !data.InstanceOf(js.Global().Get("Uint8Array")) {
|
||||
return
|
||||
}
|
||||
|
||||
length := data.Get("length").Int()
|
||||
buf := make([]byte, length)
|
||||
js.CopyBytesToGo(buf, data)
|
||||
|
||||
conn.mu.Lock()
|
||||
vncConn := conn.vncConn
|
||||
conn.mu.Unlock()
|
||||
|
||||
if vncConn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := vncConn.Write(buf); err != nil {
|
||||
log.Debugf("write to VNC server: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *VNCProxy) connectToVNC(conn *vncConnection) {
|
||||
ctx, cancel := context.WithTimeout(conn.ctx, vncDialTimeout)
|
||||
defer cancel()
|
||||
|
||||
vncConn, err := p.nbClient.Dial(ctx, "tcp", conn.destination.address)
|
||||
if err != nil {
|
||||
log.Errorf("VNC connect to %s: %v", conn.destination.address, err)
|
||||
// Close the WebSocket so noVNC fires a disconnect event.
|
||||
code := wsCodeDialFailure
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
code = wsCodeDialTimeout
|
||||
}
|
||||
if conn.wsHandlers.Get("close").Truthy() {
|
||||
conn.wsHandlers.Call("close", code, fmt.Sprintf("connect to peer: %v", err))
|
||||
}
|
||||
p.cleanupConnection(conn)
|
||||
return
|
||||
}
|
||||
conn.mu.Lock()
|
||||
conn.vncConn = vncConn
|
||||
conn.mu.Unlock()
|
||||
|
||||
// Send the NetBird VNC session header before the RFB handshake.
|
||||
if err := p.sendSessionHeader(vncConn, conn.destination); err != nil {
|
||||
log.Errorf("send VNC session header: %v", err)
|
||||
if conn.wsHandlers.Get("close").Truthy() {
|
||||
conn.wsHandlers.Call("close", wsCodeSessionSetup, fmt.Sprintf("send session header: %v", err))
|
||||
}
|
||||
p.cleanupConnection(conn)
|
||||
return
|
||||
}
|
||||
|
||||
// WS→TCP is handled by the onGoMessage handler set in setupWebSocketHandlers,
|
||||
// which writes directly to the VNC connection as data arrives from JS.
|
||||
// Only the TCP→WS direction needs a read loop here.
|
||||
go p.forwardConnToWS(conn)
|
||||
|
||||
<-conn.ctx.Done()
|
||||
p.cleanupConnection(conn)
|
||||
}
|
||||
|
||||
// sendSessionHeader writes the NetBird VNC connection header: mode +
|
||||
// username prefix, an optional Noise_IK handshake that authenticates the
|
||||
// client and the server, then the trailing sessionID / width / height
|
||||
// fields the daemon needs once auth is settled.
|
||||
func (p *VNCProxy) sendSessionHeader(conn net.Conn, dest vncDestination) error {
|
||||
usernameBytes := []byte(dest.username)
|
||||
if len(usernameBytes) > 0xFFFF {
|
||||
return fmt.Errorf("username too long: %d bytes (max %d)", len(usernameBytes), 0xFFFF)
|
||||
}
|
||||
prefix := make([]byte, 3+len(usernameBytes))
|
||||
prefix[0] = dest.mode
|
||||
prefix[1] = byte(len(usernameBytes) >> 8)
|
||||
prefix[2] = byte(len(usernameBytes))
|
||||
copy(prefix[3:], usernameBytes)
|
||||
if err := writeAll(conn, prefix); err != nil {
|
||||
return fmt.Errorf("write header prefix: %w", err)
|
||||
}
|
||||
|
||||
if dest.sessionPriv == nil {
|
||||
return p.writeHeaderTail(conn, dest)
|
||||
}
|
||||
if err := p.runNoiseHandshake(conn, dest); err != nil {
|
||||
return fmt.Errorf("noise handshake: %w", err)
|
||||
}
|
||||
return p.writeHeaderTail(conn, dest)
|
||||
}
|
||||
|
||||
// writeHeaderTail writes the post-auth trailing fields (sessionID,
|
||||
// width, height) the daemon reads regardless of whether the Noise
|
||||
// handshake was performed.
|
||||
func (p *VNCProxy) writeHeaderTail(conn net.Conn, dest vncDestination) error {
|
||||
tail := make([]byte, 4+4)
|
||||
tail[0] = byte(dest.sessionID >> 24)
|
||||
tail[1] = byte(dest.sessionID >> 16)
|
||||
tail[2] = byte(dest.sessionID >> 8)
|
||||
tail[3] = byte(dest.sessionID)
|
||||
tail[4] = byte(dest.width >> 8)
|
||||
tail[5] = byte(dest.width)
|
||||
tail[6] = byte(dest.height >> 8)
|
||||
tail[7] = byte(dest.height)
|
||||
if err := writeAll(conn, tail); err != nil {
|
||||
return fmt.Errorf("write header tail: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// runNoiseHandshake performs the initiator side of a Noise_IK handshake
|
||||
// against the destination daemon. The session keypair authenticates the
|
||||
// client; the daemon's pre-known peer pubkey authenticates the server.
|
||||
func (p *VNCProxy) runNoiseHandshake(conn net.Conn, dest vncDestination) error {
|
||||
state, err := noise.NewHandshakeState(noise.Config{
|
||||
CipherSuite: vncNoiseSuite,
|
||||
Pattern: noise.HandshakeIK,
|
||||
Initiator: true,
|
||||
StaticKeypair: noise.DHKey{Private: dest.sessionPriv, Public: dest.sessionPub},
|
||||
PeerStatic: dest.peerPubKey,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("noise initiator init: %w", err)
|
||||
}
|
||||
msg1, _, _, err := state.WriteMessage(nil, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("noise write msg1: %w", err)
|
||||
}
|
||||
out := make([]byte, 0, len(vncIdentityMagic)+len(msg1))
|
||||
out = append(out, vncIdentityMagic...)
|
||||
out = append(out, msg1...)
|
||||
if err := writeAll(conn, out); err != nil {
|
||||
return fmt.Errorf("send noise msg1: %w", err)
|
||||
}
|
||||
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
return fmt.Errorf("set noise deadline: %w", err)
|
||||
}
|
||||
defer conn.SetReadDeadline(time.Time{}) //nolint:errcheck
|
||||
msg2 := make([]byte, noiseResponderMsgLen)
|
||||
if _, err := io.ReadFull(conn, msg2); err != nil {
|
||||
return fmt.Errorf("read noise msg2: %w", err)
|
||||
}
|
||||
if _, _, _, err := state.ReadMessage(nil, msg2); err != nil {
|
||||
return fmt.Errorf("noise read msg2: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeAll(conn net.Conn, buf []byte) error {
|
||||
for off := 0; off < len(buf); {
|
||||
n, err := conn.Write(buf[off:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
off += n
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *VNCProxy) forwardConnToWS(conn *vncConnection) {
|
||||
buf := make([]byte, 32*1024)
|
||||
|
||||
for {
|
||||
if conn.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
vc, ok := conn.snapshotVNC()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := vc.SetReadDeadline(time.Now().Add(30 * time.Second)); err != nil {
|
||||
log.Debugf("set VNC read deadline: %v", err)
|
||||
}
|
||||
n, err := vc.Read(buf)
|
||||
if err != nil {
|
||||
if p.handleConnReadError(conn, err) {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
if n > 0 {
|
||||
p.sendToWebSocket(conn, buf[:n])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// snapshotVNC returns the current vncConn under conn.mu, with ok=false when
|
||||
// the connection has already been cleaned up.
|
||||
func (c *vncConnection) snapshotVNC() (net.Conn, bool) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.vncConn == nil {
|
||||
return nil, false
|
||||
}
|
||||
return c.vncConn, true
|
||||
}
|
||||
|
||||
// handleConnReadError classifies an error from the VNC read loop. Returns
|
||||
// true if the caller should exit and trigger the cleanup path. A read
|
||||
// timeout counts as a fatal error: in a healthy session the server emits
|
||||
// empty FramebufferUpdate responses several times per second, so a full
|
||||
// idleReadDeadline of silence means the peer is dead (process gone,
|
||||
// machine off, network partition) and the in-browser TCP stack will
|
||||
// never surface that on its own.
|
||||
func (p *VNCProxy) handleConnReadError(conn *vncConnection, err error) bool {
|
||||
if conn.ctx.Err() != nil {
|
||||
return true
|
||||
}
|
||||
if netErr, ok := err.(interface{ Timeout() bool }); ok && netErr.Timeout() {
|
||||
log.Debugf("VNC read deadline expired; treating peer as dead")
|
||||
} else if err != io.EOF {
|
||||
log.Debugf("read from VNC connection: %v", err)
|
||||
}
|
||||
if conn.wsHandlers.Get("close").Truthy() {
|
||||
conn.wsHandlers.Call("close", wsCodeTransport, "VNC connection lost")
|
||||
}
|
||||
conn.cancel()
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *VNCProxy) sendToWebSocket(conn *vncConnection, data []byte) {
|
||||
if conn.wsHandlers.Get("receiveFromGo").Truthy() {
|
||||
uint8Array := js.Global().Get("Uint8Array").New(len(data))
|
||||
js.CopyBytesToJS(uint8Array, data)
|
||||
conn.wsHandlers.Call("receiveFromGo", uint8Array.Get("buffer"))
|
||||
} else if conn.wsHandlers.Get("send").Truthy() {
|
||||
uint8Array := js.Global().Get("Uint8Array").New(len(data))
|
||||
js.CopyBytesToJS(uint8Array, data)
|
||||
conn.wsHandlers.Call("send", uint8Array.Get("buffer"))
|
||||
}
|
||||
}
|
||||
|
||||
func (p *VNCProxy) cleanupConnection(conn *vncConnection) {
|
||||
log.Debugf("cleaning up VNC connection %s", conn.id)
|
||||
conn.cancel()
|
||||
|
||||
conn.mu.Lock()
|
||||
vncConn := conn.vncConn
|
||||
conn.vncConn = nil
|
||||
conn.mu.Unlock()
|
||||
|
||||
if vncConn != nil {
|
||||
if err := vncConn.Close(); err != nil {
|
||||
log.Debugf("close VNC connection: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove the global JS handler registered in CreateProxy.
|
||||
globalName := fmt.Sprintf("handleVNCWebSocket_%s", conn.id)
|
||||
js.Global().Delete(globalName)
|
||||
|
||||
// Release all js.Func handles; js.FuncOf pins the Go closure and the
|
||||
// allocations it captures until Release is called.
|
||||
conn.wsHandlerFn.Release()
|
||||
conn.onMessageFn.Release()
|
||||
conn.onCloseFn.Release()
|
||||
|
||||
p.mu.Lock()
|
||||
delete(p.activeConnections, conn.id)
|
||||
delete(p.destinations, conn.id)
|
||||
delete(p.pendingHandlers, conn.id)
|
||||
p.mu.Unlock()
|
||||
}
|
||||
@@ -1,28 +1,2 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
if ! which realpath > /dev/null 2>&1
|
||||
then
|
||||
echo realpath is not installed
|
||||
echo run: brew install coreutils
|
||||
exit 1
|
||||
fi
|
||||
|
||||
old_pwd=$(pwd)
|
||||
script_path=$(dirname "$(realpath "$0")")
|
||||
cd "$script_path/.."
|
||||
|
||||
repo_root=$(git rev-parse --show-toplevel)
|
||||
# shellcheck source=/dev/null
|
||||
. "$repo_root/proto-tools.env"
|
||||
|
||||
actual_protoc=$(protoc --version | awk '{print $2}')
|
||||
if [[ "$actual_protoc" != "$PROTOC_VERSION" ]]; then
|
||||
echo "ERROR: protoc version $actual_protoc differs from pinned $PROTOC_VERSION" >&2
|
||||
echo "Install protoc $PROTOC_VERSION from https://github.com/protocolbuffers/protobuf/releases" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
go install "google.golang.org/protobuf/cmd/protoc-gen-go@${PROTOC_GEN_GO_VERSION}"
|
||||
protoc -I testprotos/ testprotos/testproto.proto --go_out=.
|
||||
cd "$old_pwd"
|
||||
protoc -I testprotos/ testprotos/testproto.proto --go_out=.
|
||||
@@ -1,7 +1,7 @@
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.36.6
|
||||
// protoc v6.33.1
|
||||
// protoc-gen-go v1.26.0
|
||||
// protoc v3.12.4
|
||||
// source: testproto.proto
|
||||
|
||||
package testprotos
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||
reflect "reflect"
|
||||
sync "sync"
|
||||
unsafe "unsafe"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -22,17 +21,20 @@ const (
|
||||
)
|
||||
|
||||
type TestMessage struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Body string `protobuf:"bytes,1,opt,name=body,proto3" json:"body,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
Body string `protobuf:"bytes,1,opt,name=body,proto3" json:"body,omitempty"`
|
||||
}
|
||||
|
||||
func (x *TestMessage) Reset() {
|
||||
*x = TestMessage{}
|
||||
mi := &file_testproto_proto_msgTypes[0]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_testproto_proto_msgTypes[0]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *TestMessage) String() string {
|
||||
@@ -43,7 +45,7 @@ func (*TestMessage) ProtoMessage() {}
|
||||
|
||||
func (x *TestMessage) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_testproto_proto_msgTypes[0]
|
||||
if x != nil {
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
@@ -67,27 +69,29 @@ func (x *TestMessage) GetBody() string {
|
||||
|
||||
var File_testproto_proto protoreflect.FileDescriptor
|
||||
|
||||
const file_testproto_proto_rawDesc = "" +
|
||||
"\n" +
|
||||
"\x0ftestproto.proto\x12\n" +
|
||||
"testprotos\"!\n" +
|
||||
"\vTestMessage\x12\x12\n" +
|
||||
"\x04body\x18\x01 \x01(\tR\x04bodyB\rZ\v/testprotosb\x06proto3"
|
||||
var file_testproto_proto_rawDesc = []byte{
|
||||
0x0a, 0x0f, 0x74, 0x65, 0x73, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x74,
|
||||
0x6f, 0x12, 0x0a, 0x74, 0x65, 0x73, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x73, 0x22, 0x21, 0x0a,
|
||||
0x0b, 0x54, 0x65, 0x73, 0x74, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x12, 0x0a, 0x04,
|
||||
0x62, 0x6f, 0x64, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x62, 0x6f, 0x64, 0x79,
|
||||
0x42, 0x0d, 0x5a, 0x0b, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x73, 0x62,
|
||||
0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
}
|
||||
|
||||
var (
|
||||
file_testproto_proto_rawDescOnce sync.Once
|
||||
file_testproto_proto_rawDescData []byte
|
||||
file_testproto_proto_rawDescData = file_testproto_proto_rawDesc
|
||||
)
|
||||
|
||||
func file_testproto_proto_rawDescGZIP() []byte {
|
||||
file_testproto_proto_rawDescOnce.Do(func() {
|
||||
file_testproto_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_testproto_proto_rawDesc), len(file_testproto_proto_rawDesc)))
|
||||
file_testproto_proto_rawDescData = protoimpl.X.CompressGZIP(file_testproto_proto_rawDescData)
|
||||
})
|
||||
return file_testproto_proto_rawDescData
|
||||
}
|
||||
|
||||
var file_testproto_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
|
||||
var file_testproto_proto_goTypes = []any{
|
||||
var file_testproto_proto_goTypes = []interface{}{
|
||||
(*TestMessage)(nil), // 0: testprotos.TestMessage
|
||||
}
|
||||
var file_testproto_proto_depIdxs = []int32{
|
||||
@@ -103,11 +107,25 @@ func file_testproto_proto_init() {
|
||||
if File_testproto_proto != nil {
|
||||
return
|
||||
}
|
||||
if !protoimpl.UnsafeEnabled {
|
||||
file_testproto_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*TestMessage); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
type x struct{}
|
||||
out := protoimpl.TypeBuilder{
|
||||
File: protoimpl.DescBuilder{
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: unsafe.Slice(unsafe.StringData(file_testproto_proto_rawDesc), len(file_testproto_proto_rawDesc)),
|
||||
RawDescriptor: file_testproto_proto_rawDesc,
|
||||
NumEnums: 0,
|
||||
NumMessages: 1,
|
||||
NumExtensions: 0,
|
||||
@@ -118,6 +136,7 @@ func file_testproto_proto_init() {
|
||||
MessageInfos: file_testproto_proto_msgTypes,
|
||||
}.Build()
|
||||
File_testproto_proto = out.File
|
||||
file_testproto_proto_rawDesc = nil
|
||||
file_testproto_proto_goTypes = nil
|
||||
file_testproto_proto_depIdxs = nil
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.36.6
|
||||
// protoc v6.33.1
|
||||
// protoc-gen-go v1.26.0
|
||||
// protoc v3.21.9
|
||||
// source: flow.proto
|
||||
|
||||
package proto
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
|
||||
reflect "reflect"
|
||||
sync "sync"
|
||||
unsafe "unsafe"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -126,24 +125,27 @@ func (Direction) EnumDescriptor() ([]byte, []int) {
|
||||
}
|
||||
|
||||
type FlowEvent struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
// Unique client event identifier
|
||||
EventId []byte `protobuf:"bytes,1,opt,name=event_id,json=eventId,proto3" json:"event_id,omitempty"`
|
||||
// When the event occurred
|
||||
Timestamp *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=timestamp,proto3" json:"timestamp,omitempty"`
|
||||
// Public key of the sending peer
|
||||
PublicKey []byte `protobuf:"bytes,3,opt,name=public_key,json=publicKey,proto3" json:"public_key,omitempty"`
|
||||
FlowFields *FlowFields `protobuf:"bytes,4,opt,name=flow_fields,json=flowFields,proto3" json:"flow_fields,omitempty"`
|
||||
IsInitiator bool `protobuf:"varint,5,opt,name=isInitiator,proto3" json:"isInitiator,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
PublicKey []byte `protobuf:"bytes,3,opt,name=public_key,json=publicKey,proto3" json:"public_key,omitempty"`
|
||||
FlowFields *FlowFields `protobuf:"bytes,4,opt,name=flow_fields,json=flowFields,proto3" json:"flow_fields,omitempty"`
|
||||
IsInitiator bool `protobuf:"varint,5,opt,name=isInitiator,proto3" json:"isInitiator,omitempty"`
|
||||
}
|
||||
|
||||
func (x *FlowEvent) Reset() {
|
||||
*x = FlowEvent{}
|
||||
mi := &file_flow_proto_msgTypes[0]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_flow_proto_msgTypes[0]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *FlowEvent) String() string {
|
||||
@@ -154,7 +156,7 @@ func (*FlowEvent) ProtoMessage() {}
|
||||
|
||||
func (x *FlowEvent) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_flow_proto_msgTypes[0]
|
||||
if x != nil {
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
@@ -205,19 +207,22 @@ func (x *FlowEvent) GetIsInitiator() bool {
|
||||
}
|
||||
|
||||
type FlowEventAck struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
// Unique client event identifier that has been ack'ed
|
||||
EventId []byte `protobuf:"bytes,1,opt,name=event_id,json=eventId,proto3" json:"event_id,omitempty"`
|
||||
IsInitiator bool `protobuf:"varint,2,opt,name=isInitiator,proto3" json:"isInitiator,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
// Unique client event identifier that has been ack'ed
|
||||
EventId []byte `protobuf:"bytes,1,opt,name=event_id,json=eventId,proto3" json:"event_id,omitempty"`
|
||||
IsInitiator bool `protobuf:"varint,2,opt,name=isInitiator,proto3" json:"isInitiator,omitempty"`
|
||||
}
|
||||
|
||||
func (x *FlowEventAck) Reset() {
|
||||
*x = FlowEventAck{}
|
||||
mi := &file_flow_proto_msgTypes[1]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_flow_proto_msgTypes[1]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *FlowEventAck) String() string {
|
||||
@@ -228,7 +233,7 @@ func (*FlowEventAck) ProtoMessage() {}
|
||||
|
||||
func (x *FlowEventAck) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_flow_proto_msgTypes[1]
|
||||
if x != nil {
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
@@ -258,7 +263,10 @@ func (x *FlowEventAck) GetIsInitiator() bool {
|
||||
}
|
||||
|
||||
type FlowFields struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
// Unique client flow session identifier
|
||||
FlowId []byte `protobuf:"bytes,1,opt,name=flow_id,json=flowId,proto3" json:"flow_id,omitempty"`
|
||||
// Flow type
|
||||
@@ -275,7 +283,7 @@ type FlowFields struct {
|
||||
DestIp []byte `protobuf:"bytes,7,opt,name=dest_ip,json=destIp,proto3" json:"dest_ip,omitempty"`
|
||||
// Layer 4 -specific information
|
||||
//
|
||||
// Types that are valid to be assigned to ConnectionInfo:
|
||||
// Types that are assignable to ConnectionInfo:
|
||||
//
|
||||
// *FlowFields_PortInfo
|
||||
// *FlowFields_IcmpInfo
|
||||
@@ -289,15 +297,15 @@ type FlowFields struct {
|
||||
// Resource ID
|
||||
SourceResourceId []byte `protobuf:"bytes,14,opt,name=source_resource_id,json=sourceResourceId,proto3" json:"source_resource_id,omitempty"`
|
||||
DestResourceId []byte `protobuf:"bytes,15,opt,name=dest_resource_id,json=destResourceId,proto3" json:"dest_resource_id,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *FlowFields) Reset() {
|
||||
*x = FlowFields{}
|
||||
mi := &file_flow_proto_msgTypes[2]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_flow_proto_msgTypes[2]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *FlowFields) String() string {
|
||||
@@ -308,7 +316,7 @@ func (*FlowFields) ProtoMessage() {}
|
||||
|
||||
func (x *FlowFields) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_flow_proto_msgTypes[2]
|
||||
if x != nil {
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
@@ -372,27 +380,23 @@ func (x *FlowFields) GetDestIp() []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *FlowFields) GetConnectionInfo() isFlowFields_ConnectionInfo {
|
||||
if x != nil {
|
||||
return x.ConnectionInfo
|
||||
func (m *FlowFields) GetConnectionInfo() isFlowFields_ConnectionInfo {
|
||||
if m != nil {
|
||||
return m.ConnectionInfo
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *FlowFields) GetPortInfo() *PortInfo {
|
||||
if x != nil {
|
||||
if x, ok := x.ConnectionInfo.(*FlowFields_PortInfo); ok {
|
||||
return x.PortInfo
|
||||
}
|
||||
if x, ok := x.GetConnectionInfo().(*FlowFields_PortInfo); ok {
|
||||
return x.PortInfo
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *FlowFields) GetIcmpInfo() *ICMPInfo {
|
||||
if x != nil {
|
||||
if x, ok := x.ConnectionInfo.(*FlowFields_IcmpInfo); ok {
|
||||
return x.IcmpInfo
|
||||
}
|
||||
if x, ok := x.GetConnectionInfo().(*FlowFields_IcmpInfo); ok {
|
||||
return x.IcmpInfo
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -459,18 +463,21 @@ func (*FlowFields_IcmpInfo) isFlowFields_ConnectionInfo() {}
|
||||
|
||||
// TCP/UDP port information
|
||||
type PortInfo struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
SourcePort uint32 `protobuf:"varint,1,opt,name=source_port,json=sourcePort,proto3" json:"source_port,omitempty"`
|
||||
DestPort uint32 `protobuf:"varint,2,opt,name=dest_port,json=destPort,proto3" json:"dest_port,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
SourcePort uint32 `protobuf:"varint,1,opt,name=source_port,json=sourcePort,proto3" json:"source_port,omitempty"`
|
||||
DestPort uint32 `protobuf:"varint,2,opt,name=dest_port,json=destPort,proto3" json:"dest_port,omitempty"`
|
||||
}
|
||||
|
||||
func (x *PortInfo) Reset() {
|
||||
*x = PortInfo{}
|
||||
mi := &file_flow_proto_msgTypes[3]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_flow_proto_msgTypes[3]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *PortInfo) String() string {
|
||||
@@ -481,7 +488,7 @@ func (*PortInfo) ProtoMessage() {}
|
||||
|
||||
func (x *PortInfo) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_flow_proto_msgTypes[3]
|
||||
if x != nil {
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
@@ -512,18 +519,21 @@ func (x *PortInfo) GetDestPort() uint32 {
|
||||
|
||||
// ICMP message information
|
||||
type ICMPInfo struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
IcmpType uint32 `protobuf:"varint,1,opt,name=icmp_type,json=icmpType,proto3" json:"icmp_type,omitempty"`
|
||||
IcmpCode uint32 `protobuf:"varint,2,opt,name=icmp_code,json=icmpCode,proto3" json:"icmp_code,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
IcmpType uint32 `protobuf:"varint,1,opt,name=icmp_type,json=icmpType,proto3" json:"icmp_type,omitempty"`
|
||||
IcmpCode uint32 `protobuf:"varint,2,opt,name=icmp_code,json=icmpCode,proto3" json:"icmp_code,omitempty"`
|
||||
}
|
||||
|
||||
func (x *ICMPInfo) Reset() {
|
||||
*x = ICMPInfo{}
|
||||
mi := &file_flow_proto_msgTypes[4]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_flow_proto_msgTypes[4]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *ICMPInfo) String() string {
|
||||
@@ -534,7 +544,7 @@ func (*ICMPInfo) ProtoMessage() {}
|
||||
|
||||
func (x *ICMPInfo) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_flow_proto_msgTypes[4]
|
||||
if x != nil {
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
@@ -565,79 +575,102 @@ func (x *ICMPInfo) GetIcmpCode() uint32 {
|
||||
|
||||
var File_flow_proto protoreflect.FileDescriptor
|
||||
|
||||
const file_flow_proto_rawDesc = "" +
|
||||
"\n" +
|
||||
"\n" +
|
||||
"flow.proto\x12\x04flow\x1a\x1fgoogle/protobuf/timestamp.proto\"\xd4\x01\n" +
|
||||
"\tFlowEvent\x12\x19\n" +
|
||||
"\bevent_id\x18\x01 \x01(\fR\aeventId\x128\n" +
|
||||
"\ttimestamp\x18\x02 \x01(\v2\x1a.google.protobuf.TimestampR\ttimestamp\x12\x1d\n" +
|
||||
"\n" +
|
||||
"public_key\x18\x03 \x01(\fR\tpublicKey\x121\n" +
|
||||
"\vflow_fields\x18\x04 \x01(\v2\x10.flow.FlowFieldsR\n" +
|
||||
"flowFields\x12 \n" +
|
||||
"\visInitiator\x18\x05 \x01(\bR\visInitiator\"K\n" +
|
||||
"\fFlowEventAck\x12\x19\n" +
|
||||
"\bevent_id\x18\x01 \x01(\fR\aeventId\x12 \n" +
|
||||
"\visInitiator\x18\x02 \x01(\bR\visInitiator\"\x9c\x04\n" +
|
||||
"\n" +
|
||||
"FlowFields\x12\x17\n" +
|
||||
"\aflow_id\x18\x01 \x01(\fR\x06flowId\x12\x1e\n" +
|
||||
"\x04type\x18\x02 \x01(\x0e2\n" +
|
||||
".flow.TypeR\x04type\x12\x17\n" +
|
||||
"\arule_id\x18\x03 \x01(\fR\x06ruleId\x12-\n" +
|
||||
"\tdirection\x18\x04 \x01(\x0e2\x0f.flow.DirectionR\tdirection\x12\x1a\n" +
|
||||
"\bprotocol\x18\x05 \x01(\rR\bprotocol\x12\x1b\n" +
|
||||
"\tsource_ip\x18\x06 \x01(\fR\bsourceIp\x12\x17\n" +
|
||||
"\adest_ip\x18\a \x01(\fR\x06destIp\x12-\n" +
|
||||
"\tport_info\x18\b \x01(\v2\x0e.flow.PortInfoH\x00R\bportInfo\x12-\n" +
|
||||
"\ticmp_info\x18\t \x01(\v2\x0e.flow.ICMPInfoH\x00R\bicmpInfo\x12\x1d\n" +
|
||||
"\n" +
|
||||
"rx_packets\x18\n" +
|
||||
" \x01(\x04R\trxPackets\x12\x1d\n" +
|
||||
"\n" +
|
||||
"tx_packets\x18\v \x01(\x04R\ttxPackets\x12\x19\n" +
|
||||
"\brx_bytes\x18\f \x01(\x04R\arxBytes\x12\x19\n" +
|
||||
"\btx_bytes\x18\r \x01(\x04R\atxBytes\x12,\n" +
|
||||
"\x12source_resource_id\x18\x0e \x01(\fR\x10sourceResourceId\x12(\n" +
|
||||
"\x10dest_resource_id\x18\x0f \x01(\fR\x0edestResourceIdB\x11\n" +
|
||||
"\x0fconnection_info\"H\n" +
|
||||
"\bPortInfo\x12\x1f\n" +
|
||||
"\vsource_port\x18\x01 \x01(\rR\n" +
|
||||
"sourcePort\x12\x1b\n" +
|
||||
"\tdest_port\x18\x02 \x01(\rR\bdestPort\"D\n" +
|
||||
"\bICMPInfo\x12\x1b\n" +
|
||||
"\ticmp_type\x18\x01 \x01(\rR\bicmpType\x12\x1b\n" +
|
||||
"\ticmp_code\x18\x02 \x01(\rR\bicmpCode*E\n" +
|
||||
"\x04Type\x12\x10\n" +
|
||||
"\fTYPE_UNKNOWN\x10\x00\x12\x0e\n" +
|
||||
"\n" +
|
||||
"TYPE_START\x10\x01\x12\f\n" +
|
||||
"\bTYPE_END\x10\x02\x12\r\n" +
|
||||
"\tTYPE_DROP\x10\x03*;\n" +
|
||||
"\tDirection\x12\x15\n" +
|
||||
"\x11DIRECTION_UNKNOWN\x10\x00\x12\v\n" +
|
||||
"\aINGRESS\x10\x01\x12\n" +
|
||||
"\n" +
|
||||
"\x06EGRESS\x10\x022B\n" +
|
||||
"\vFlowService\x123\n" +
|
||||
"\x06Events\x12\x0f.flow.FlowEvent\x1a\x12.flow.FlowEventAck\"\x00(\x010\x01B\bZ\x06/protob\x06proto3"
|
||||
var file_flow_proto_rawDesc = []byte{
|
||||
0x0a, 0x0a, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x04, 0x66, 0x6c,
|
||||
0x6f, 0x77, 0x1a, 0x1f, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f,
|
||||
0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72,
|
||||
0x6f, 0x74, 0x6f, 0x22, 0xd4, 0x01, 0x0a, 0x09, 0x46, 0x6c, 0x6f, 0x77, 0x45, 0x76, 0x65, 0x6e,
|
||||
0x74, 0x12, 0x19, 0x0a, 0x08, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20,
|
||||
0x01, 0x28, 0x0c, 0x52, 0x07, 0x65, 0x76, 0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x38, 0x0a, 0x09,
|
||||
0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32,
|
||||
0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75,
|
||||
0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x74, 0x69, 0x6d,
|
||||
0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x12, 0x1d, 0x0a, 0x0a, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63,
|
||||
0x5f, 0x6b, 0x65, 0x79, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x70, 0x75, 0x62, 0x6c,
|
||||
0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x31, 0x0a, 0x0b, 0x66, 0x6c, 0x6f, 0x77, 0x5f, 0x66, 0x69,
|
||||
0x65, 0x6c, 0x64, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x10, 0x2e, 0x66, 0x6c, 0x6f,
|
||||
0x77, 0x2e, 0x46, 0x6c, 0x6f, 0x77, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x73, 0x52, 0x0a, 0x66, 0x6c,
|
||||
0x6f, 0x77, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x73, 0x12, 0x20, 0x0a, 0x0b, 0x69, 0x73, 0x49, 0x6e,
|
||||
0x69, 0x74, 0x69, 0x61, 0x74, 0x6f, 0x72, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0b, 0x69,
|
||||
0x73, 0x49, 0x6e, 0x69, 0x74, 0x69, 0x61, 0x74, 0x6f, 0x72, 0x22, 0x4b, 0x0a, 0x0c, 0x46, 0x6c,
|
||||
0x6f, 0x77, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x41, 0x63, 0x6b, 0x12, 0x19, 0x0a, 0x08, 0x65, 0x76,
|
||||
0x65, 0x6e, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x65, 0x76,
|
||||
0x65, 0x6e, 0x74, 0x49, 0x64, 0x12, 0x20, 0x0a, 0x0b, 0x69, 0x73, 0x49, 0x6e, 0x69, 0x74, 0x69,
|
||||
0x61, 0x74, 0x6f, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0b, 0x69, 0x73, 0x49, 0x6e,
|
||||
0x69, 0x74, 0x69, 0x61, 0x74, 0x6f, 0x72, 0x22, 0x9c, 0x04, 0x0a, 0x0a, 0x46, 0x6c, 0x6f, 0x77,
|
||||
0x46, 0x69, 0x65, 0x6c, 0x64, 0x73, 0x12, 0x17, 0x0a, 0x07, 0x66, 0x6c, 0x6f, 0x77, 0x5f, 0x69,
|
||||
0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x06, 0x66, 0x6c, 0x6f, 0x77, 0x49, 0x64, 0x12,
|
||||
0x1e, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x0a, 0x2e,
|
||||
0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12,
|
||||
0x17, 0x0a, 0x07, 0x72, 0x75, 0x6c, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c,
|
||||
0x52, 0x06, 0x72, 0x75, 0x6c, 0x65, 0x49, 0x64, 0x12, 0x2d, 0x0a, 0x09, 0x64, 0x69, 0x72, 0x65,
|
||||
0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x0f, 0x2e, 0x66, 0x6c,
|
||||
0x6f, 0x77, 0x2e, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x09, 0x64, 0x69,
|
||||
0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f,
|
||||
0x63, 0x6f, 0x6c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f,
|
||||
0x63, 0x6f, 0x6c, 0x12, 0x1b, 0x0a, 0x09, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69, 0x70,
|
||||
0x18, 0x06, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x49, 0x70,
|
||||
0x12, 0x17, 0x0a, 0x07, 0x64, 0x65, 0x73, 0x74, 0x5f, 0x69, 0x70, 0x18, 0x07, 0x20, 0x01, 0x28,
|
||||
0x0c, 0x52, 0x06, 0x64, 0x65, 0x73, 0x74, 0x49, 0x70, 0x12, 0x2d, 0x0a, 0x09, 0x70, 0x6f, 0x72,
|
||||
0x74, 0x5f, 0x69, 0x6e, 0x66, 0x6f, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x66,
|
||||
0x6c, 0x6f, 0x77, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x48, 0x00, 0x52, 0x08,
|
||||
0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x2d, 0x0a, 0x09, 0x69, 0x63, 0x6d, 0x70,
|
||||
0x5f, 0x69, 0x6e, 0x66, 0x6f, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x66, 0x6c,
|
||||
0x6f, 0x77, 0x2e, 0x49, 0x43, 0x4d, 0x50, 0x49, 0x6e, 0x66, 0x6f, 0x48, 0x00, 0x52, 0x08, 0x69,
|
||||
0x63, 0x6d, 0x70, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1d, 0x0a, 0x0a, 0x72, 0x78, 0x5f, 0x70, 0x61,
|
||||
0x63, 0x6b, 0x65, 0x74, 0x73, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x04, 0x52, 0x09, 0x72, 0x78, 0x50,
|
||||
0x61, 0x63, 0x6b, 0x65, 0x74, 0x73, 0x12, 0x1d, 0x0a, 0x0a, 0x74, 0x78, 0x5f, 0x70, 0x61, 0x63,
|
||||
0x6b, 0x65, 0x74, 0x73, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x04, 0x52, 0x09, 0x74, 0x78, 0x50, 0x61,
|
||||
0x63, 0x6b, 0x65, 0x74, 0x73, 0x12, 0x19, 0x0a, 0x08, 0x72, 0x78, 0x5f, 0x62, 0x79, 0x74, 0x65,
|
||||
0x73, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x04, 0x52, 0x07, 0x72, 0x78, 0x42, 0x79, 0x74, 0x65, 0x73,
|
||||
0x12, 0x19, 0x0a, 0x08, 0x74, 0x78, 0x5f, 0x62, 0x79, 0x74, 0x65, 0x73, 0x18, 0x0d, 0x20, 0x01,
|
||||
0x28, 0x04, 0x52, 0x07, 0x74, 0x78, 0x42, 0x79, 0x74, 0x65, 0x73, 0x12, 0x2c, 0x0a, 0x12, 0x73,
|
||||
0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69,
|
||||
0x64, 0x18, 0x0e, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x10, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52,
|
||||
0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x49, 0x64, 0x12, 0x28, 0x0a, 0x10, 0x64, 0x65, 0x73,
|
||||
0x74, 0x5f, 0x72, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x0f, 0x20,
|
||||
0x01, 0x28, 0x0c, 0x52, 0x0e, 0x64, 0x65, 0x73, 0x74, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63,
|
||||
0x65, 0x49, 0x64, 0x42, 0x11, 0x0a, 0x0f, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f,
|
||||
0x6e, 0x5f, 0x69, 0x6e, 0x66, 0x6f, 0x22, 0x48, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e,
|
||||
0x66, 0x6f, 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x70, 0x6f, 0x72,
|
||||
0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x50,
|
||||
0x6f, 0x72, 0x74, 0x12, 0x1b, 0x0a, 0x09, 0x64, 0x65, 0x73, 0x74, 0x5f, 0x70, 0x6f, 0x72, 0x74,
|
||||
0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x08, 0x64, 0x65, 0x73, 0x74, 0x50, 0x6f, 0x72, 0x74,
|
||||
0x22, 0x44, 0x0a, 0x08, 0x49, 0x43, 0x4d, 0x50, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1b, 0x0a, 0x09,
|
||||
0x69, 0x63, 0x6d, 0x70, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52,
|
||||
0x08, 0x69, 0x63, 0x6d, 0x70, 0x54, 0x79, 0x70, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x69, 0x63, 0x6d,
|
||||
0x70, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x08, 0x69, 0x63,
|
||||
0x6d, 0x70, 0x43, 0x6f, 0x64, 0x65, 0x2a, 0x45, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x10,
|
||||
0x0a, 0x0c, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00,
|
||||
0x12, 0x0e, 0x0a, 0x0a, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x53, 0x54, 0x41, 0x52, 0x54, 0x10, 0x01,
|
||||
0x12, 0x0c, 0x0a, 0x08, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x45, 0x4e, 0x44, 0x10, 0x02, 0x12, 0x0d,
|
||||
0x0a, 0x09, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x03, 0x2a, 0x3b, 0x0a,
|
||||
0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x15, 0x0a, 0x11, 0x44, 0x49,
|
||||
0x52, 0x45, 0x43, 0x54, 0x49, 0x4f, 0x4e, 0x5f, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10,
|
||||
0x00, 0x12, 0x0b, 0x0a, 0x07, 0x49, 0x4e, 0x47, 0x52, 0x45, 0x53, 0x53, 0x10, 0x01, 0x12, 0x0a,
|
||||
0x0a, 0x06, 0x45, 0x47, 0x52, 0x45, 0x53, 0x53, 0x10, 0x02, 0x32, 0x42, 0x0a, 0x0b, 0x46, 0x6c,
|
||||
0x6f, 0x77, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x33, 0x0a, 0x06, 0x45, 0x76, 0x65,
|
||||
0x6e, 0x74, 0x73, 0x12, 0x0f, 0x2e, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x46, 0x6c, 0x6f, 0x77, 0x45,
|
||||
0x76, 0x65, 0x6e, 0x74, 0x1a, 0x12, 0x2e, 0x66, 0x6c, 0x6f, 0x77, 0x2e, 0x46, 0x6c, 0x6f, 0x77,
|
||||
0x45, 0x76, 0x65, 0x6e, 0x74, 0x41, 0x63, 0x6b, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x08,
|
||||
0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
}
|
||||
|
||||
var (
|
||||
file_flow_proto_rawDescOnce sync.Once
|
||||
file_flow_proto_rawDescData []byte
|
||||
file_flow_proto_rawDescData = file_flow_proto_rawDesc
|
||||
)
|
||||
|
||||
func file_flow_proto_rawDescGZIP() []byte {
|
||||
file_flow_proto_rawDescOnce.Do(func() {
|
||||
file_flow_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_flow_proto_rawDesc), len(file_flow_proto_rawDesc)))
|
||||
file_flow_proto_rawDescData = protoimpl.X.CompressGZIP(file_flow_proto_rawDescData)
|
||||
})
|
||||
return file_flow_proto_rawDescData
|
||||
}
|
||||
|
||||
var file_flow_proto_enumTypes = make([]protoimpl.EnumInfo, 2)
|
||||
var file_flow_proto_msgTypes = make([]protoimpl.MessageInfo, 5)
|
||||
var file_flow_proto_goTypes = []any{
|
||||
var file_flow_proto_goTypes = []interface{}{
|
||||
(Type)(0), // 0: flow.Type
|
||||
(Direction)(0), // 1: flow.Direction
|
||||
(*FlowEvent)(nil), // 2: flow.FlowEvent
|
||||
@@ -668,7 +701,69 @@ func file_flow_proto_init() {
|
||||
if File_flow_proto != nil {
|
||||
return
|
||||
}
|
||||
file_flow_proto_msgTypes[2].OneofWrappers = []any{
|
||||
if !protoimpl.UnsafeEnabled {
|
||||
file_flow_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*FlowEvent); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_flow_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*FlowEventAck); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_flow_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*FlowFields); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_flow_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*PortInfo); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_flow_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*ICMPInfo); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
file_flow_proto_msgTypes[2].OneofWrappers = []interface{}{
|
||||
(*FlowFields_PortInfo)(nil),
|
||||
(*FlowFields_IcmpInfo)(nil),
|
||||
}
|
||||
@@ -676,7 +771,7 @@ func file_flow_proto_init() {
|
||||
out := protoimpl.TypeBuilder{
|
||||
File: protoimpl.DescBuilder{
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: unsafe.Slice(unsafe.StringData(file_flow_proto_rawDesc), len(file_flow_proto_rawDesc)),
|
||||
RawDescriptor: file_flow_proto_rawDesc,
|
||||
NumEnums: 2,
|
||||
NumMessages: 5,
|
||||
NumExtensions: 0,
|
||||
@@ -688,6 +783,7 @@ func file_flow_proto_init() {
|
||||
MessageInfos: file_flow_proto_msgTypes,
|
||||
}.Build()
|
||||
File_flow_proto = out.File
|
||||
file_flow_proto_rawDesc = nil
|
||||
file_flow_proto_goTypes = nil
|
||||
file_flow_proto_depIdxs = nil
|
||||
}
|
||||
|
||||
@@ -1,8 +1,4 @@
|
||||
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||
// versions:
|
||||
// - protoc-gen-go-grpc v1.6.1
|
||||
// - protoc v6.33.1
|
||||
// source: flow.proto
|
||||
|
||||
package proto
|
||||
|
||||
@@ -15,19 +11,15 @@ import (
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the grpc package it is being compiled against.
|
||||
// Requires gRPC-Go v1.64.0 or later.
|
||||
const _ = grpc.SupportPackageIsVersion9
|
||||
|
||||
const (
|
||||
FlowService_Events_FullMethodName = "/flow.FlowService/Events"
|
||||
)
|
||||
// Requires gRPC-Go v1.32.0 or later.
|
||||
const _ = grpc.SupportPackageIsVersion7
|
||||
|
||||
// FlowServiceClient is the client API for FlowService service.
|
||||
//
|
||||
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
|
||||
type FlowServiceClient interface {
|
||||
// Client to receiver streams of events and acknowledgements
|
||||
Events(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[FlowEvent, FlowEventAck], error)
|
||||
Events(ctx context.Context, opts ...grpc.CallOption) (FlowService_EventsClient, error)
|
||||
}
|
||||
|
||||
type flowServiceClient struct {
|
||||
@@ -38,40 +30,54 @@ func NewFlowServiceClient(cc grpc.ClientConnInterface) FlowServiceClient {
|
||||
return &flowServiceClient{cc}
|
||||
}
|
||||
|
||||
func (c *flowServiceClient) Events(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[FlowEvent, FlowEventAck], error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
stream, err := c.cc.NewStream(ctx, &FlowService_ServiceDesc.Streams[0], FlowService_Events_FullMethodName, cOpts...)
|
||||
func (c *flowServiceClient) Events(ctx context.Context, opts ...grpc.CallOption) (FlowService_EventsClient, error) {
|
||||
stream, err := c.cc.NewStream(ctx, &FlowService_ServiceDesc.Streams[0], "/flow.FlowService/Events", opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
x := &grpc.GenericClientStream[FlowEvent, FlowEventAck]{ClientStream: stream}
|
||||
x := &flowServiceEventsClient{stream}
|
||||
return x, nil
|
||||
}
|
||||
|
||||
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
|
||||
type FlowService_EventsClient = grpc.BidiStreamingClient[FlowEvent, FlowEventAck]
|
||||
type FlowService_EventsClient interface {
|
||||
Send(*FlowEvent) error
|
||||
Recv() (*FlowEventAck, error)
|
||||
grpc.ClientStream
|
||||
}
|
||||
|
||||
type flowServiceEventsClient struct {
|
||||
grpc.ClientStream
|
||||
}
|
||||
|
||||
func (x *flowServiceEventsClient) Send(m *FlowEvent) error {
|
||||
return x.ClientStream.SendMsg(m)
|
||||
}
|
||||
|
||||
func (x *flowServiceEventsClient) Recv() (*FlowEventAck, error) {
|
||||
m := new(FlowEventAck)
|
||||
if err := x.ClientStream.RecvMsg(m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// FlowServiceServer is the server API for FlowService service.
|
||||
// All implementations must embed UnimplementedFlowServiceServer
|
||||
// for forward compatibility.
|
||||
// for forward compatibility
|
||||
type FlowServiceServer interface {
|
||||
// Client to receiver streams of events and acknowledgements
|
||||
Events(grpc.BidiStreamingServer[FlowEvent, FlowEventAck]) error
|
||||
Events(FlowService_EventsServer) error
|
||||
mustEmbedUnimplementedFlowServiceServer()
|
||||
}
|
||||
|
||||
// UnimplementedFlowServiceServer must be embedded to have
|
||||
// forward compatible implementations.
|
||||
//
|
||||
// NOTE: this should be embedded by value instead of pointer to avoid a nil
|
||||
// pointer dereference when methods are called.
|
||||
type UnimplementedFlowServiceServer struct{}
|
||||
// UnimplementedFlowServiceServer must be embedded to have forward compatible implementations.
|
||||
type UnimplementedFlowServiceServer struct {
|
||||
}
|
||||
|
||||
func (UnimplementedFlowServiceServer) Events(grpc.BidiStreamingServer[FlowEvent, FlowEventAck]) error {
|
||||
return status.Error(codes.Unimplemented, "method Events not implemented")
|
||||
func (UnimplementedFlowServiceServer) Events(FlowService_EventsServer) error {
|
||||
return status.Errorf(codes.Unimplemented, "method Events not implemented")
|
||||
}
|
||||
func (UnimplementedFlowServiceServer) mustEmbedUnimplementedFlowServiceServer() {}
|
||||
func (UnimplementedFlowServiceServer) testEmbeddedByValue() {}
|
||||
|
||||
// UnsafeFlowServiceServer may be embedded to opt out of forward compatibility for this service.
|
||||
// Use of this interface is not recommended, as added methods to FlowServiceServer will
|
||||
@@ -81,22 +87,34 @@ type UnsafeFlowServiceServer interface {
|
||||
}
|
||||
|
||||
func RegisterFlowServiceServer(s grpc.ServiceRegistrar, srv FlowServiceServer) {
|
||||
// If the following call panics, it indicates UnimplementedFlowServiceServer was
|
||||
// embedded by pointer and is nil. This will cause panics if an
|
||||
// unimplemented method is ever invoked, so we test this at initialization
|
||||
// time to prevent it from happening at runtime later due to I/O.
|
||||
if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
|
||||
t.testEmbeddedByValue()
|
||||
}
|
||||
s.RegisterService(&FlowService_ServiceDesc, srv)
|
||||
}
|
||||
|
||||
func _FlowService_Events_Handler(srv interface{}, stream grpc.ServerStream) error {
|
||||
return srv.(FlowServiceServer).Events(&grpc.GenericServerStream[FlowEvent, FlowEventAck]{ServerStream: stream})
|
||||
return srv.(FlowServiceServer).Events(&flowServiceEventsServer{stream})
|
||||
}
|
||||
|
||||
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
|
||||
type FlowService_EventsServer = grpc.BidiStreamingServer[FlowEvent, FlowEventAck]
|
||||
type FlowService_EventsServer interface {
|
||||
Send(*FlowEventAck) error
|
||||
Recv() (*FlowEvent, error)
|
||||
grpc.ServerStream
|
||||
}
|
||||
|
||||
type flowServiceEventsServer struct {
|
||||
grpc.ServerStream
|
||||
}
|
||||
|
||||
func (x *flowServiceEventsServer) Send(m *FlowEventAck) error {
|
||||
return x.ServerStream.SendMsg(m)
|
||||
}
|
||||
|
||||
func (x *flowServiceEventsServer) Recv() (*FlowEvent, error) {
|
||||
m := new(FlowEvent)
|
||||
if err := x.ServerStream.RecvMsg(m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// FlowService_ServiceDesc is the grpc.ServiceDesc for FlowService service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
|
||||
@@ -9,21 +9,9 @@ then
|
||||
fi
|
||||
|
||||
old_pwd=$(pwd)
|
||||
script_path=$(dirname "$(realpath "$0")")
|
||||
script_path=$(dirname $(realpath "$0"))
|
||||
cd "$script_path"
|
||||
|
||||
repo_root=$(git rev-parse --show-toplevel)
|
||||
# shellcheck source=/dev/null
|
||||
. "$repo_root/proto-tools.env"
|
||||
|
||||
actual_protoc=$(protoc --version | awk '{print $2}')
|
||||
if [[ "$actual_protoc" != "$PROTOC_VERSION" ]]; then
|
||||
echo "ERROR: protoc version $actual_protoc differs from pinned $PROTOC_VERSION" >&2
|
||||
echo "Install protoc $PROTOC_VERSION from https://github.com/protocolbuffers/protobuf/releases" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
go install "google.golang.org/protobuf/cmd/protoc-gen-go@${PROTOC_GEN_GO_VERSION}"
|
||||
go install "google.golang.org/grpc/cmd/protoc-gen-go-grpc@${PROTOC_GEN_GO_GRPC_VERSION}"
|
||||
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.26
|
||||
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.1
|
||||
protoc -I ./ ./flow.proto --go_out=../ --go-grpc_out=../
|
||||
cd "$old_pwd"
|
||||
|
||||
3
go.mod
3
go.mod
@@ -51,6 +51,7 @@ require (
|
||||
github.com/eko/gocache/lib/v4 v4.2.0
|
||||
github.com/eko/gocache/store/go_cache/v4 v4.2.2
|
||||
github.com/eko/gocache/store/redis/v4 v4.2.2
|
||||
github.com/flynn/noise v1.1.0
|
||||
github.com/fsnotify/fsnotify v1.9.0
|
||||
github.com/gliderlabs/ssh v0.3.8
|
||||
github.com/go-jose/go-jose/v4 v4.1.4
|
||||
@@ -66,6 +67,8 @@ require (
|
||||
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
|
||||
github.com/hashicorp/go-version v1.7.0
|
||||
github.com/jackc/pgx/v5 v5.5.5
|
||||
github.com/jezek/xgb v1.3.0
|
||||
github.com/kirides/go-d3d v1.0.1
|
||||
github.com/libdns/route53 v1.5.0
|
||||
github.com/libp2p/go-nat v0.2.0
|
||||
github.com/libp2p/go-netroute v0.4.0
|
||||
|
||||
8
go.sum
8
go.sum
@@ -162,6 +162,8 @@ github.com/felixge/fgprof v0.9.3 h1:VvyZxILNuCiUCSXtPtYmmtGvb65nqXh2QFWc0Wpf2/g=
|
||||
github.com/felixge/fgprof v0.9.3/go.mod h1:RdbpDgzqYVh/T9fPELJyV7EYJuHB55UTEULNun8eiPw=
|
||||
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
|
||||
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||
github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
|
||||
github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
|
||||
github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw=
|
||||
github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g=
|
||||
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
||||
@@ -378,6 +380,8 @@ github.com/jcmturner/rpc/v2 v2.0.3 h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZ
|
||||
github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc=
|
||||
github.com/jeandeaual/go-locale v0.0.0-20250612000132-0ef82f21eade h1:FmusiCI1wHw+XQbvL9M+1r/C3SPqKrmBaIOYwVfQoDE=
|
||||
github.com/jeandeaual/go-locale v0.0.0-20250612000132-0ef82f21eade/go.mod h1:ZDXo8KHryOWSIqnsb/CiDq7hQUYryCgdVnxbj8tDG7o=
|
||||
github.com/jezek/xgb v1.3.0 h1:Wa1pn4GVtcmNVAVB6/pnQVJ7xPFZVZ/W1Tc27msDhgI=
|
||||
github.com/jezek/xgb v1.3.0/go.mod h1:nrhwO0FX/enq75I7Y7G8iN1ubpSGZEiA3v9e9GyRFlk=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
@@ -396,6 +400,8 @@ github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25/go.mod h1:kLgvv7o6U
|
||||
github.com/juju/gnuflag v0.0.0-20171113085948-2ce1bb71843d/go.mod h1:2PavIy+JPciBPrBUjwbNvtwB6RQlve+hkpll6QSNmOE=
|
||||
github.com/kelseyhightower/envconfig v1.4.0 h1:Im6hONhd3pLkfDFsbRgu68RDNkGF1r3dvMUtDTo2cv8=
|
||||
github.com/kelseyhightower/envconfig v1.4.0/go.mod h1:cccZRl6mQpaq41TPp5QxidR+Sa3axMbJDNb//FQX6Gg=
|
||||
github.com/kirides/go-d3d v1.0.1 h1:ZDANfvo34vskBMET1uwUUMNw8545Kbe8qYSiRwlNIuA=
|
||||
github.com/kirides/go-d3d v1.0.1/go.mod h1:99AjD+5mRTFEnkpRWkwq8UYMQDljGIIvLn2NyRdVImY=
|
||||
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||
@@ -408,6 +414,7 @@ github.com/koron/go-ssdp v0.0.4/go.mod h1:oDXq+E5IL5q0U8uSBcoAXzTzInwy5lEgC91HoK
|
||||
github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8=
|
||||
github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
@@ -752,6 +759,7 @@ goauthentik.io/api/v3 v3.2023051.3/go.mod h1:nYECml4jGbp/541hj8GcylKQG1gVBsKppHy
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE=
|
||||
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
|
||||
|
||||
@@ -2,6 +2,7 @@ package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
@@ -98,10 +99,7 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set
|
||||
|
||||
sshConfig := &proto.SSHConfig{
|
||||
SshEnabled: peer.SSHEnabled || enableSSH,
|
||||
}
|
||||
|
||||
if sshConfig.SshEnabled {
|
||||
sshConfig.JwtConfig = buildJWTConfig(httpConfig, deviceFlowConfig)
|
||||
JwtConfig: buildJWTConfig(httpConfig, deviceFlowConfig),
|
||||
}
|
||||
|
||||
peerConfig := &proto.PeerConfig{
|
||||
@@ -134,13 +132,14 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
||||
includeIPv6 := peer.SupportsIPv6() && peer.IPv6.IsValid()
|
||||
useSourcePrefixes := peer.SupportsSourcePrefixes()
|
||||
|
||||
peerConfig := toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH)
|
||||
response := &proto.SyncResponse{
|
||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
|
||||
PeerConfig: peerConfig,
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: networkMap.Network.CurrentSerial(),
|
||||
Routes: toProtocolRoutes(networkMap.Routes),
|
||||
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
|
||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
|
||||
PeerConfig: peerConfig,
|
||||
},
|
||||
Checks: toProtocolChecks(ctx, checks),
|
||||
}
|
||||
@@ -149,8 +148,6 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
||||
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings)
|
||||
response.NetbirdConfig = extendedConfig
|
||||
|
||||
response.NetworkMap.PeerConfig = response.PeerConfig
|
||||
|
||||
remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
|
||||
remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName, includeIPv6)
|
||||
response.RemotePeers = remotePeers
|
||||
@@ -176,18 +173,59 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
||||
response.NetworkMap.ForwardingRules = forwardingRules
|
||||
}
|
||||
|
||||
userIDClaim := auth.DefaultUserIDClaim
|
||||
if httpConfig != nil && httpConfig.AuthUserIDClaim != "" {
|
||||
userIDClaim = httpConfig.AuthUserIDClaim
|
||||
}
|
||||
|
||||
if networkMap.AuthorizedUsers != nil {
|
||||
hashedUsers, machineUsers := buildAuthorizedUsersProto(ctx, networkMap.AuthorizedUsers)
|
||||
userIDClaim := auth.DefaultUserIDClaim
|
||||
if httpConfig != nil && httpConfig.AuthUserIDClaim != "" {
|
||||
userIDClaim = httpConfig.AuthUserIDClaim
|
||||
}
|
||||
response.NetworkMap.SshAuth = &proto.SSHAuth{AuthorizedUsers: hashedUsers, MachineUsers: machineUsers, UserIDClaim: userIDClaim}
|
||||
}
|
||||
|
||||
if networkMap.VNCAuthorizedUsers != nil {
|
||||
hashedUsers, machineUsers := buildAuthorizedUsersProto(ctx, networkMap.VNCAuthorizedUsers)
|
||||
response.NetworkMap.VncAuth = &proto.VNCAuth{
|
||||
AuthorizedUsers: hashedUsers,
|
||||
MachineUsers: machineUsers,
|
||||
SessionPubKeys: buildSessionPubKeysProto(ctx, networkMap.VNCSessionPubKeys),
|
||||
}
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
// buildSessionPubKeysProto decodes base64 X25519 session pubkeys and
|
||||
// hashes the user IDs they belong to, emitting the proto entries the
|
||||
// daemon's authorizer indexes by pubkey.
|
||||
func buildSessionPubKeysProto(ctx context.Context, in []types.VNCSessionPubKey) []*proto.SessionPubKey {
|
||||
if len(in) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*proto.SessionPubKey, 0, len(in))
|
||||
for _, e := range in {
|
||||
pub, err := base64.StdEncoding.DecodeString(e.PubKey)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("decode VNC session pubkey: %v", err)
|
||||
continue
|
||||
}
|
||||
if len(pub) != 32 {
|
||||
log.WithContext(ctx).Warnf("VNC session pubkey wrong length: %d", len(pub))
|
||||
continue
|
||||
}
|
||||
hash, err := sshauth.HashUserID(e.UserID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("hash VNC session user id: %v", err)
|
||||
continue
|
||||
}
|
||||
out = append(out, &proto.SessionPubKey{
|
||||
PubKey: pub,
|
||||
UserIdHash: hash[:],
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func buildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]map[string]struct{}) ([][]byte, map[string]*proto.MachineUserIndexes) {
|
||||
userIDToIndex := make(map[string]uint32)
|
||||
var hashedUsers [][]byte
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -137,12 +136,9 @@ type proxyConnection struct {
|
||||
tokenID string
|
||||
capabilities *proto.ProxyCapabilities
|
||||
stream proto.ProxyService_GetMappingUpdateServer
|
||||
// syncStream is set when the proxy connected via SyncMappings.
|
||||
// When non-nil, the sender goroutine uses this instead of stream.
|
||||
syncStream proto.ProxyService_SyncMappingsServer
|
||||
sendChan chan *proto.GetMappingUpdateResponse
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
sendChan chan *proto.GetMappingUpdateResponse
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func enforceAccountScope(ctx context.Context, requestAccountID string) error {
|
||||
@@ -210,322 +206,145 @@ func (s *ProxyServiceServer) SetProxyController(proxyController proxy.Controller
|
||||
s.proxyController = proxyController
|
||||
}
|
||||
|
||||
// proxyConnectParams holds the validated parameters extracted from either
|
||||
// a GetMappingUpdateRequest or a SyncMappingsInit message.
|
||||
type proxyConnectParams struct {
|
||||
proxyID string
|
||||
address string
|
||||
capabilities *proto.ProxyCapabilities
|
||||
}
|
||||
|
||||
// GetMappingUpdate handles the control stream with proxy clients
|
||||
func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest, stream proto.ProxyService_GetMappingUpdateServer) error {
|
||||
params, err := s.validateProxyConnect(req.GetProxyId(), req.GetAddress(), stream.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
params.capabilities = req.GetCapabilities()
|
||||
ctx := stream.Context()
|
||||
|
||||
conn, proxyRecord, err := s.registerProxyConnection(stream.Context(), params, &proxyConnection{
|
||||
stream: stream,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.sendSnapshot(stream.Context(), conn); err != nil {
|
||||
s.cleanupFailedSnapshot(stream.Context(), conn)
|
||||
return fmt.Errorf("send snapshot to proxy %s: %w", params.proxyID, err)
|
||||
}
|
||||
|
||||
errChan := make(chan error, 2)
|
||||
go s.sender(conn, errChan)
|
||||
|
||||
return s.serveProxyConnection(conn, proxyRecord, errChan, false)
|
||||
}
|
||||
|
||||
// SyncMappings implements the bidirectional SyncMappings RPC.
|
||||
// It mirrors GetMappingUpdate but provides application-level back-pressure:
|
||||
// management waits for an ack from the proxy before sending the next batch.
|
||||
func (s *ProxyServiceServer) SyncMappings(stream proto.ProxyService_SyncMappingsServer) error {
|
||||
init, err := recvSyncInit(stream)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
params, err := s.validateProxyConnect(init.GetProxyId(), init.GetAddress(), stream.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
params.capabilities = init.GetCapabilities()
|
||||
|
||||
conn, proxyRecord, err := s.registerProxyConnection(stream.Context(), params, &proxyConnection{
|
||||
syncStream: stream,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.sendSnapshotSync(stream.Context(), conn, stream); err != nil {
|
||||
s.cleanupFailedSnapshot(stream.Context(), conn)
|
||||
return fmt.Errorf("send snapshot to proxy %s: %w", params.proxyID, err)
|
||||
}
|
||||
|
||||
errChan := make(chan error, 2)
|
||||
go s.sender(conn, errChan)
|
||||
go s.drainRecv(stream, errChan)
|
||||
|
||||
return s.serveProxyConnection(conn, proxyRecord, errChan, true)
|
||||
}
|
||||
|
||||
// recvSyncInit receives and validates the first message on a SyncMappings stream.
|
||||
func recvSyncInit(stream proto.ProxyService_SyncMappingsServer) (*proto.SyncMappingsInit, error) {
|
||||
firstMsg, err := stream.Recv()
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "receive init: %v", err)
|
||||
}
|
||||
init := firstMsg.GetInit()
|
||||
if init == nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "first message must be init")
|
||||
}
|
||||
return init, nil
|
||||
}
|
||||
|
||||
// validateProxyConnect validates the proxy ID and address, and checks cluster
|
||||
// address availability for account-scoped tokens.
|
||||
func (s *ProxyServiceServer) validateProxyConnect(proxyID, address string, ctx context.Context) (proxyConnectParams, error) {
|
||||
if proxyID == "" {
|
||||
return proxyConnectParams{}, status.Errorf(codes.InvalidArgument, "proxy_id is required")
|
||||
}
|
||||
if !isProxyAddressValid(address) {
|
||||
return proxyConnectParams{}, status.Errorf(codes.InvalidArgument, "proxy address is invalid")
|
||||
}
|
||||
|
||||
token := GetProxyTokenFromContext(ctx)
|
||||
if token != nil && token.AccountID != nil {
|
||||
available, err := s.proxyManager.IsClusterAddressAvailable(ctx, address, *token.AccountID)
|
||||
if err != nil {
|
||||
return proxyConnectParams{}, status.Errorf(codes.Internal, "check cluster address: %v", err)
|
||||
}
|
||||
if !available {
|
||||
return proxyConnectParams{}, status.Errorf(codes.AlreadyExists, "cluster address %s is already in use", address)
|
||||
}
|
||||
}
|
||||
|
||||
return proxyConnectParams{proxyID: proxyID, address: address}, nil
|
||||
}
|
||||
|
||||
// registerProxyConnection creates a proxyConnection, registers it with the
|
||||
// proxy manager and cluster, and stores it in connectedProxies. The caller
|
||||
// provides a partially initialised connSeed with stream-specific fields set;
|
||||
// the remaining fields are filled in here.
|
||||
func (s *ProxyServiceServer) registerProxyConnection(ctx context.Context, params proxyConnectParams, connSeed *proxyConnection) (*proxyConnection, *proxy.Proxy, error) {
|
||||
peerInfo := PeerIPFromContext(ctx)
|
||||
log.Infof("New proxy connection from %s", peerInfo)
|
||||
|
||||
proxyID := req.GetProxyId()
|
||||
if proxyID == "" {
|
||||
return status.Errorf(codes.InvalidArgument, "proxy_id is required")
|
||||
}
|
||||
|
||||
proxyAddress := req.GetAddress()
|
||||
if !isProxyAddressValid(proxyAddress) {
|
||||
return status.Errorf(codes.InvalidArgument, "proxy address is invalid")
|
||||
}
|
||||
|
||||
var accountID *string
|
||||
var tokenID string
|
||||
if token := GetProxyTokenFromContext(ctx); token != nil {
|
||||
if token.AccountID != nil {
|
||||
accountID = token.AccountID
|
||||
token := GetProxyTokenFromContext(ctx)
|
||||
if token != nil && token.AccountID != nil {
|
||||
accountID = token.AccountID
|
||||
|
||||
available, err := s.proxyManager.IsClusterAddressAvailable(ctx, proxyAddress, *accountID)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "check cluster address: %v", err)
|
||||
}
|
||||
if !available {
|
||||
return status.Errorf(codes.AlreadyExists, "cluster address %s is already in use", proxyAddress)
|
||||
}
|
||||
}
|
||||
|
||||
var tokenID string
|
||||
if token != nil {
|
||||
tokenID = token.ID
|
||||
}
|
||||
|
||||
sessionID := uuid.NewString()
|
||||
s.supersedePriorConnection(params.proxyID, sessionID)
|
||||
|
||||
if old, loaded := s.connectedProxies.Load(proxyID); loaded {
|
||||
oldConn := old.(*proxyConnection)
|
||||
log.WithFields(log.Fields{
|
||||
"proxy_id": proxyID,
|
||||
"old_session_id": oldConn.sessionID,
|
||||
"new_session_id": sessionID,
|
||||
}).Info("Superseding existing proxy connection")
|
||||
oldConn.cancel()
|
||||
}
|
||||
|
||||
connCtx, cancel := context.WithCancel(ctx)
|
||||
connSeed.proxyID = params.proxyID
|
||||
connSeed.sessionID = sessionID
|
||||
connSeed.address = params.address
|
||||
connSeed.accountID = accountID
|
||||
connSeed.tokenID = tokenID
|
||||
connSeed.capabilities = params.capabilities
|
||||
connSeed.sendChan = make(chan *proto.GetMappingUpdateResponse, 100)
|
||||
connSeed.ctx = connCtx
|
||||
connSeed.cancel = cancel
|
||||
conn := &proxyConnection{
|
||||
proxyID: proxyID,
|
||||
sessionID: sessionID,
|
||||
address: proxyAddress,
|
||||
accountID: accountID,
|
||||
tokenID: tokenID,
|
||||
capabilities: req.GetCapabilities(),
|
||||
stream: stream,
|
||||
sendChan: make(chan *proto.GetMappingUpdateResponse, 100),
|
||||
ctx: connCtx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
var caps *proxy.Capabilities
|
||||
if c := params.capabilities; c != nil {
|
||||
if c := req.GetCapabilities(); c != nil {
|
||||
caps = &proxy.Capabilities{
|
||||
SupportsCustomPorts: c.SupportsCustomPorts,
|
||||
RequireSubdomain: c.RequireSubdomain,
|
||||
SupportsCrowdsec: c.SupportsCrowdsec,
|
||||
}
|
||||
}
|
||||
|
||||
proxyRecord, err := s.proxyManager.Connect(ctx, params.proxyID, sessionID, params.address, peerInfo, accountID, caps)
|
||||
proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, accountID, caps)
|
||||
if err != nil {
|
||||
cancel()
|
||||
if accountID != nil {
|
||||
return nil, nil, status.Errorf(codes.Internal, "failed to register BYOP proxy: %v", err)
|
||||
return status.Errorf(codes.Internal, "failed to register BYOP proxy: %v", err)
|
||||
}
|
||||
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", params.proxyID, err)
|
||||
return nil, nil, status.Errorf(codes.Internal, "register proxy in database: %v", err)
|
||||
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
|
||||
return status.Errorf(codes.Internal, "register proxy in database: %v", err)
|
||||
}
|
||||
|
||||
s.connectedProxies.Store(params.proxyID, connSeed)
|
||||
if err := s.proxyController.RegisterProxyToCluster(ctx, params.address, params.proxyID); err != nil {
|
||||
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", params.proxyID, err)
|
||||
s.connectedProxies.Store(proxyID, conn)
|
||||
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
|
||||
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
|
||||
}
|
||||
|
||||
return connSeed, proxyRecord, nil
|
||||
}
|
||||
|
||||
// supersedePriorConnection cancels any existing connection for the given proxy.
|
||||
func (s *ProxyServiceServer) supersedePriorConnection(proxyID, newSessionID string) {
|
||||
if old, loaded := s.connectedProxies.Load(proxyID); loaded {
|
||||
oldConn := old.(*proxyConnection)
|
||||
log.WithFields(log.Fields{
|
||||
"proxy_id": proxyID,
|
||||
"old_session_id": oldConn.sessionID,
|
||||
"new_session_id": newSessionID,
|
||||
}).Info("Superseding existing proxy connection")
|
||||
oldConn.cancel()
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupFailedSnapshot removes the connection from the cluster and store
|
||||
// after a snapshot send failure.
|
||||
func (s *ProxyServiceServer) cleanupFailedSnapshot(ctx context.Context, conn *proxyConnection) {
|
||||
if s.connectedProxies.CompareAndDelete(conn.proxyID, conn) {
|
||||
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, conn.proxyID); err != nil {
|
||||
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", conn.proxyID, err)
|
||||
if err := s.sendSnapshot(ctx, conn); err != nil {
|
||||
if s.connectedProxies.CompareAndDelete(proxyID, conn) {
|
||||
if unregErr := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); unregErr != nil {
|
||||
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, unregErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
conn.cancel()
|
||||
if err := s.proxyManager.Disconnect(context.Background(), conn.proxyID, conn.sessionID); err != nil {
|
||||
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", conn.proxyID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// drainRecv consumes and discards messages from a bidirectional stream.
|
||||
// The proxy sends an ack for every incremental update; we don't need them
|
||||
// after the snapshot phase. Recv errors are forwarded to errChan.
|
||||
func (s *ProxyServiceServer) drainRecv(stream proto.ProxyService_SyncMappingsServer, errChan chan<- error) {
|
||||
for {
|
||||
if _, err := stream.Recv(); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
cancel()
|
||||
if disconnErr := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); disconnErr != nil {
|
||||
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, disconnErr)
|
||||
}
|
||||
return fmt.Errorf("send snapshot to proxy %s: %w", proxyID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// serveProxyConnection runs the post-snapshot lifecycle: heartbeat, sender,
|
||||
// and wait for termination. When bidi is true, normal stream closure (EOF,
|
||||
// canceled) is treated as a clean disconnect rather than an error.
|
||||
func (s *ProxyServiceServer) serveProxyConnection(conn *proxyConnection, proxyRecord *proxy.Proxy, errChan <-chan error, bidi bool) error {
|
||||
errChan := make(chan error, 2)
|
||||
go s.sender(conn, errChan)
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"proxy_id": conn.proxyID,
|
||||
"session_id": conn.sessionID,
|
||||
"address": conn.address,
|
||||
"cluster_addr": conn.address,
|
||||
"account_id": conn.accountID,
|
||||
"proxy_id": proxyID,
|
||||
"session_id": sessionID,
|
||||
"address": proxyAddress,
|
||||
"cluster_addr": proxyAddress,
|
||||
"account_id": accountID,
|
||||
"total_proxies": len(s.GetConnectedProxies()),
|
||||
}).Info("Proxy registered in cluster")
|
||||
defer func() {
|
||||
if !s.connectedProxies.CompareAndDelete(proxyID, conn) {
|
||||
log.Infof("Proxy %s session %s: skipping cleanup, superseded by new connection", proxyID, sessionID)
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
|
||||
defer s.disconnectProxy(conn)
|
||||
go s.heartbeat(conn.ctx, conn, proxyRecord)
|
||||
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil {
|
||||
log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err)
|
||||
}
|
||||
if err := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); err != nil {
|
||||
log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err)
|
||||
}
|
||||
|
||||
cancel()
|
||||
log.Infof("Proxy %s session %s disconnected", proxyID, sessionID)
|
||||
}()
|
||||
|
||||
go s.heartbeat(connCtx, conn, proxyRecord)
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
if bidi && isStreamClosed(err) {
|
||||
log.Infof("Proxy %s stream closed", conn.proxyID)
|
||||
return nil
|
||||
}
|
||||
log.Warnf("Failed to send update: %v", err)
|
||||
return fmt.Errorf("send update to proxy %s: %w", conn.proxyID, err)
|
||||
case <-conn.ctx.Done():
|
||||
log.Infof("Proxy %s context canceled", conn.proxyID)
|
||||
return conn.ctx.Err()
|
||||
log.WithContext(ctx).Warnf("Failed to send update: %v", err)
|
||||
return fmt.Errorf("send update to proxy %s: %w", proxyID, err)
|
||||
case <-connCtx.Done():
|
||||
log.WithContext(ctx).Infof("Proxy %s context canceled", proxyID)
|
||||
return connCtx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// disconnectProxy removes the connection from cluster and store, unless it
|
||||
// has already been superseded by a newer connection.
|
||||
func (s *ProxyServiceServer) disconnectProxy(conn *proxyConnection) {
|
||||
if !s.connectedProxies.CompareAndDelete(conn.proxyID, conn) {
|
||||
log.Infof("Proxy %s session %s: skipping cleanup, superseded by new connection", conn.proxyID, conn.sessionID)
|
||||
conn.cancel()
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, conn.proxyID); err != nil {
|
||||
log.Warnf("Failed to unregister proxy %s from cluster: %v", conn.proxyID, err)
|
||||
}
|
||||
if err := s.proxyManager.Disconnect(context.Background(), conn.proxyID, conn.sessionID); err != nil {
|
||||
log.Warnf("Failed to mark proxy %s as disconnected: %v", conn.proxyID, err)
|
||||
}
|
||||
|
||||
conn.cancel()
|
||||
log.Infof("Proxy %s session %s disconnected", conn.proxyID, conn.sessionID)
|
||||
}
|
||||
|
||||
// sendSnapshotSync sends the initial snapshot with back-pressure: it sends
|
||||
// one batch, then waits for the proxy to ack before sending the next.
|
||||
func (s *ProxyServiceServer) sendSnapshotSync(ctx context.Context, conn *proxyConnection, stream proto.ProxyService_SyncMappingsServer) error {
|
||||
if !isProxyAddressValid(conn.address) {
|
||||
return fmt.Errorf("proxy address is invalid")
|
||||
}
|
||||
if s.snapshotBatchSize <= 0 {
|
||||
return fmt.Errorf("invalid snapshot batch size: %d", s.snapshotBatchSize)
|
||||
}
|
||||
|
||||
mappings, err := s.snapshotServiceMappings(ctx, conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i := 0; i < len(mappings); i += s.snapshotBatchSize {
|
||||
end := i + s.snapshotBatchSize
|
||||
if end > len(mappings) {
|
||||
end = len(mappings)
|
||||
}
|
||||
for _, m := range mappings[i:end] {
|
||||
token, err := s.tokenStore.GenerateToken(m.AccountId, m.Id, s.proxyTokenTTL())
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate auth token for service %s: %w", m.Id, err)
|
||||
}
|
||||
m.AuthToken = token
|
||||
}
|
||||
if err := stream.Send(&proto.SyncMappingsResponse{
|
||||
Mapping: mappings[i:end],
|
||||
InitialSyncComplete: end == len(mappings),
|
||||
}); err != nil {
|
||||
return fmt.Errorf("send snapshot batch: %w", err)
|
||||
}
|
||||
|
||||
if err := waitForAck(stream); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(mappings) == 0 {
|
||||
if err := stream.Send(&proto.SyncMappingsResponse{
|
||||
InitialSyncComplete: true,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("send snapshot completion: %w", err)
|
||||
}
|
||||
|
||||
if err := waitForAck(stream); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func waitForAck(stream proto.ProxyService_SyncMappingsServer) error {
|
||||
msg, err := stream.Recv()
|
||||
if err != nil {
|
||||
return fmt.Errorf("receive ack: %w", err)
|
||||
}
|
||||
if msg.GetAck() == nil {
|
||||
return fmt.Errorf("expected ack, got %T", msg.GetMsg())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// heartbeat updates the proxy's last_seen timestamp every minute and
|
||||
// disconnects the proxy if its access token has been revoked.
|
||||
func (s *ProxyServiceServer) heartbeat(ctx context.Context, conn *proxyConnection, p *proxy.Proxy) {
|
||||
@@ -562,9 +381,6 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
|
||||
if !isProxyAddressValid(conn.address) {
|
||||
return fmt.Errorf("proxy address is invalid")
|
||||
}
|
||||
if s.snapshotBatchSize <= 0 {
|
||||
return fmt.Errorf("invalid snapshot batch size: %d", s.snapshotBatchSize)
|
||||
}
|
||||
|
||||
mappings, err := s.snapshotServiceMappings(ctx, conn)
|
||||
if err != nil {
|
||||
@@ -644,26 +460,12 @@ func isProxyAddressValid(addr string) bool {
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// isStreamClosed returns true for errors that indicate normal stream
|
||||
// termination: io.EOF, context cancellation, or gRPC Canceled.
|
||||
func isStreamClosed(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) {
|
||||
return true
|
||||
}
|
||||
return status.Code(err) == codes.Canceled
|
||||
}
|
||||
|
||||
// sender handles sending messages to proxy.
|
||||
// When conn.syncStream is set the message is sent as SyncMappingsResponse;
|
||||
// otherwise the legacy GetMappingUpdateResponse stream is used.
|
||||
// sender handles sending messages to proxy
|
||||
func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error) {
|
||||
for {
|
||||
select {
|
||||
case resp := <-conn.sendChan:
|
||||
if err := conn.sendResponse(resp); err != nil {
|
||||
if err := conn.stream.Send(resp); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
@@ -673,17 +475,6 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error)
|
||||
}
|
||||
}
|
||||
|
||||
// sendResponse sends a mapping update on whichever stream the proxy connected with.
|
||||
func (conn *proxyConnection) sendResponse(resp *proto.GetMappingUpdateResponse) error {
|
||||
if conn.syncStream != nil {
|
||||
return conn.syncStream.Send(&proto.SyncMappingsResponse{
|
||||
Mapping: resp.Mapping,
|
||||
InitialSyncComplete: resp.InitialSyncComplete,
|
||||
})
|
||||
}
|
||||
return conn.stream.Send(resp)
|
||||
}
|
||||
|
||||
// SendAccessLog processes access log from proxy
|
||||
func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendAccessLogRequest) (*proto.SendAccessLogResponse, error) {
|
||||
accessLog := req.GetLog()
|
||||
@@ -750,8 +541,8 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes
|
||||
return true
|
||||
}
|
||||
connUpdate = &proto.GetMappingUpdateResponse{
|
||||
Mapping: filtered,
|
||||
InitialSyncComplete: update.InitialSyncComplete,
|
||||
Mapping: filtered,
|
||||
InitialSyncComplete: update.InitialSyncComplete,
|
||||
}
|
||||
}
|
||||
resp := s.perProxyMessage(connUpdate, conn.proxyID)
|
||||
|
||||
@@ -674,6 +674,7 @@ func extractPeerMeta(ctx context.Context, meta *proto.PeerSystemMeta) nbpeer.Pee
|
||||
RosenpassEnabled: meta.GetFlags().GetRosenpassEnabled(),
|
||||
RosenpassPermissive: meta.GetFlags().GetRosenpassPermissive(),
|
||||
ServerSSHAllowed: meta.GetFlags().GetServerSSHAllowed(),
|
||||
ServerVNCAllowed: meta.GetFlags().GetServerVNCAllowed(),
|
||||
DisableClientRoutes: meta.GetFlags().GetDisableClientRoutes(),
|
||||
DisableServerRoutes: meta.GetFlags().GetDisableServerRoutes(),
|
||||
DisableDNS: meta.GetFlags().GetDisableDNS(),
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user