mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-21 16:19:56 +00:00
Compare commits
76 Commits
feature/se
...
embedded-v
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
412193c602 | ||
|
|
5e67febf57 | ||
|
|
ee348ba007 | ||
|
|
3d3055dc7f | ||
|
|
2f4ddf0796 | ||
|
|
98d533c8e8 | ||
|
|
ef4ea2e311 | ||
|
|
b41d11bbbe | ||
|
|
f37e228cc2 | ||
|
|
640a267556 | ||
|
|
17359cdc1e | ||
|
|
7e5846a1ee | ||
|
|
517bea0daf | ||
|
|
9192b4f029 | ||
|
|
c784b02550 | ||
|
|
896530fd82 | ||
|
|
354fd004c7 | ||
|
|
c28e41e82b | ||
|
|
02b9fe704b | ||
|
|
5e200fa571 | ||
|
|
7d61975f6c | ||
|
|
62b36112ea | ||
|
|
df9a6fb020 | ||
|
|
b1b04f9ec6 | ||
|
|
fe15688f20 | ||
|
|
2285db2b62 | ||
|
|
b3f0f53a23 | ||
|
|
5eec9962ba | ||
|
|
393c102f45 | ||
|
|
b41fbad5e1 | ||
|
|
24a5f2252c | ||
|
|
9d189bb3e8 | ||
|
|
8e2505b59c | ||
|
|
97bc1eebde | ||
|
|
32a5a061b8 | ||
|
|
d927ef468a | ||
|
|
d3f3e08035 | ||
|
|
6bb66e0fad | ||
|
|
d250f92c43 | ||
|
|
80966ab1b0 | ||
|
|
bc407527f4 | ||
|
|
5543404188 | ||
|
|
c2fdf62f1f | ||
|
|
b9f5264e36 | ||
|
|
97d0a6776f | ||
|
|
7e7e056f3a | ||
|
|
785f94d13f | ||
|
|
bfb6750b13 | ||
|
|
f5e1057127 | ||
|
|
ee393d0e62 | ||
|
|
0b8fc5da59 | ||
|
|
2d0a54f31a | ||
|
|
61ec8d67de | ||
|
|
76add0b9b2 | ||
|
|
a11341f57a | ||
|
|
b135d462d6 | ||
|
|
da37a28951 | ||
|
|
4f884d9f30 | ||
|
|
2bed8b641b | ||
|
|
b4f696272a | ||
|
|
6d937af7a0 | ||
|
|
db5b6cfbb7 | ||
|
|
e75948753a | ||
|
|
047cc958b5 | ||
|
|
cd005ef9a9 | ||
|
|
44ed0c1992 | ||
|
|
d6d3fa95c7 | ||
|
|
fa90283781 | ||
|
|
8bf13b0d0c | ||
|
|
a8541a1529 | ||
|
|
94068d3ebc | ||
|
|
738c585ee7 | ||
|
|
9b5541d17d | ||
|
|
7123e6d1f4 | ||
|
|
62cf9e873b | ||
|
|
9f0aa1ce26 |
1
.github/pull_request_template.md
vendored
1
.github/pull_request_template.md
vendored
@@ -12,6 +12,7 @@
|
||||
- [ ] Is a feature enhancement
|
||||
- [ ] It is a refactor
|
||||
- [ ] Created tests that fail without the change (if possible)
|
||||
- [ ] This change does **not** modify the public API, gRPC protocols, functionality behavior, CLI / service flags, or introduce a new feature — **OR** I have discussed it with the NetBird team beforehand (link the issue / Slack thread in the description). See [CONTRIBUTING.md](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTING.md#discuss-changes-with-the-netbird-team-first).
|
||||
|
||||
> By submitting this pull request, you confirm that you have read and agree to the terms of the [Contributor License Agreement](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTOR_LICENSE_AGREEMENT.md).
|
||||
|
||||
|
||||
4
.github/workflows/wasm-build-validation.yml
vendored
4
.github/workflows/wasm-build-validation.yml
vendored
@@ -61,8 +61,8 @@ jobs:
|
||||
|
||||
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
|
||||
|
||||
if [ ${SIZE} -gt 58720256 ]; then
|
||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 56MB limit!"
|
||||
if [ ${SIZE} -gt 62914560 ]; then
|
||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 60MB limit!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ If you haven't already, join our slack workspace [here](https://docs.netbird.io/
|
||||
- [Contributing to NetBird](#contributing-to-netbird)
|
||||
- [Contents](#contents)
|
||||
- [Code of conduct](#code-of-conduct)
|
||||
- [Discuss changes with the NetBird team first](#discuss-changes-with-the-netbird-team-first)
|
||||
- [Directory structure](#directory-structure)
|
||||
- [Development setup](#development-setup)
|
||||
- [Requirements](#requirements)
|
||||
@@ -33,6 +34,14 @@ Conduct which can be found in the file [CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md).
|
||||
By participating, you are expected to uphold this code. Please report
|
||||
unacceptable behavior to community@netbird.io.
|
||||
|
||||
## Discuss changes with the NetBird team first
|
||||
|
||||
Changes to the **public API**, **gRPC protocols**, **functionality behavior**, **CLI / service flags**, or **new features** should be discussed with the NetBird team before you start the work. These surfaces are part of NetBird's contract with operators, self-hosters, and downstream integrators, and changes to them have compatibility, security, and release-planning implications that benefit from an early conversation.
|
||||
|
||||
Open an issue or reach out on [Slack](https://docs.netbird.io/slack-url) to talk through what you have in mind. We'll help shape the change, flag any constraints we know about, and confirm the direction so the PR review can focus on implementation rather than design.
|
||||
|
||||
Typical bug fixes, internal refactors, documentation updates, and tests do not need pre-discussion — open the PR directly.
|
||||
|
||||
## Directory structure
|
||||
|
||||
The NetBird project monorepo is organized to maintain most of its individual dependencies code within their directories, except for a few auxiliary or shared packages.
|
||||
|
||||
@@ -361,6 +361,9 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
|
||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||
req.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
if cmd.Flag(serverVNCAllowedFlag).Changed {
|
||||
req.ServerVNCAllowed = &serverVNCAllowed
|
||||
}
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
req.EnableSSHRoot = &enableSSHRoot
|
||||
}
|
||||
@@ -467,6 +470,9 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
|
||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||
ic.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
if cmd.Flag(serverVNCAllowedFlag).Changed {
|
||||
ic.ServerVNCAllowed = &serverVNCAllowed
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
ic.EnableSSHRoot = &enableSSHRoot
|
||||
@@ -595,6 +601,9 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
|
||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||
loginRequest.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
if cmd.Flag(serverVNCAllowedFlag).Changed {
|
||||
loginRequest.ServerVNCAllowed = &serverVNCAllowed
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
loginRequest.EnableSSHRoot = &enableSSHRoot
|
||||
|
||||
73
client/cmd/vnc_agent.go
Normal file
73
client/cmd/vnc_agent.go
Normal file
@@ -0,0 +1,73 @@
|
||||
//go:build windows || (darwin && !ios)
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
)
|
||||
|
||||
var vncAgentPort uint16
|
||||
|
||||
func init() {
|
||||
vncAgentCmd.Flags().Uint16Var(&vncAgentPort, "port", 15900, "Port for the VNC agent to listen on")
|
||||
rootCmd.AddCommand(vncAgentCmd)
|
||||
}
|
||||
|
||||
// vncAgentCmd runs a VNC server inside the user's interactive session,
|
||||
// listening on localhost. The NetBird service spawns it: on Windows via
|
||||
// CreateProcessAsUser into the console session, on macOS via
|
||||
// launchctl asuser into the Aqua session.
|
||||
var vncAgentCmd = &cobra.Command{
|
||||
Use: "vnc-agent",
|
||||
Short: "Run VNC capture agent (internal, spawned by service)",
|
||||
Hidden: true,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
log.SetReportCaller(true)
|
||||
log.SetFormatter(&log.JSONFormatter{})
|
||||
log.SetOutput(os.Stderr)
|
||||
|
||||
log.Infof("VNC agent starting on 127.0.0.1:%d", vncAgentPort)
|
||||
|
||||
token := os.Getenv("NB_VNC_AGENT_TOKEN")
|
||||
if token == "" {
|
||||
return fmt.Errorf("NB_VNC_AGENT_TOKEN not set; agent requires a token from the service")
|
||||
}
|
||||
// Drop the token from our process environment so any child the
|
||||
// agent spawns does not inherit it, and casual debugging tools
|
||||
// that dump /proc/<pid>/environ (or the Windows equivalent) on a
|
||||
// running agent don't surface the loopback shared secret.
|
||||
if err := os.Unsetenv("NB_VNC_AGENT_TOKEN"); err != nil {
|
||||
log.Debugf("unset NB_VNC_AGENT_TOKEN: %v", err)
|
||||
}
|
||||
|
||||
capturer, injector, err := newAgentResources()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// The per-user agent listens only on loopback and is gated by an
|
||||
// agent token shared with the daemon, so no X25519 identity key
|
||||
// is needed; auth is disabled at the RFB layer.
|
||||
srv := vncserver.New(capturer, injector, nil)
|
||||
srv.SetDisableAuth(true)
|
||||
srv.SetAgentToken(token)
|
||||
|
||||
addr := netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), vncAgentPort)
|
||||
loopback := netip.PrefixFrom(netip.AddrFrom4([4]byte{127, 0, 0, 0}), 8)
|
||||
if err := srv.Start(cmd.Context(), addr, loopback); err != nil {
|
||||
return fmt.Errorf("start vnc server: %w", err)
|
||||
}
|
||||
log.Infof("vnc-agent listening on 127.0.0.1:%d, ready", vncAgentPort)
|
||||
|
||||
<-cmd.Context().Done()
|
||||
log.Info("vnc-agent context cancelled, shutting down")
|
||||
return srv.Stop()
|
||||
},
|
||||
SilenceUsage: true,
|
||||
}
|
||||
18
client/cmd/vnc_agent_darwin.go
Normal file
18
client/cmd/vnc_agent_darwin.go
Normal file
@@ -0,0 +1,18 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
)
|
||||
|
||||
func newAgentResources() (vncserver.ScreenCapturer, vncserver.InputInjector, error) {
|
||||
capturer := vncserver.NewMacPoller()
|
||||
injector, err := vncserver.NewMacInputInjector()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("macOS input injector: %w", err)
|
||||
}
|
||||
return capturer, injector, nil
|
||||
}
|
||||
15
client/cmd/vnc_agent_windows.go
Normal file
15
client/cmd/vnc_agent_windows.go
Normal file
@@ -0,0 +1,15 @@
|
||||
//go:build windows
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
)
|
||||
|
||||
func newAgentResources() (vncserver.ScreenCapturer, vncserver.InputInjector, error) {
|
||||
sessionID := vncserver.GetCurrentSessionID()
|
||||
log.Infof("VNC agent running in Windows session %d", sessionID)
|
||||
return vncserver.NewDesktopCapturer(), vncserver.NewWindowsInputInjector(), nil
|
||||
}
|
||||
9
client/cmd/vnc_flags.go
Normal file
9
client/cmd/vnc_flags.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package cmd
|
||||
|
||||
const serverVNCAllowedFlag = "allow-server-vnc"
|
||||
|
||||
var serverVNCAllowed bool
|
||||
|
||||
func init() {
|
||||
upCmd.PersistentFlags().BoolVar(&serverVNCAllowed, serverVNCAllowedFlag, false, "Allow embedded VNC server on peer")
|
||||
}
|
||||
@@ -315,6 +315,7 @@ func (a *Auth) setSystemInfoFlags(info *system.Info) {
|
||||
a.config.RosenpassEnabled,
|
||||
a.config.RosenpassPermissive,
|
||||
a.config.ServerSSHAllowed,
|
||||
a.config.ServerVNCAllowed,
|
||||
a.config.DisableClientRoutes,
|
||||
a.config.DisableServerRoutes,
|
||||
a.config.DisableDNS,
|
||||
|
||||
@@ -562,6 +562,7 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
|
||||
RosenpassEnabled: config.RosenpassEnabled,
|
||||
RosenpassPermissive: config.RosenpassPermissive,
|
||||
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
|
||||
ServerVNCAllowed: config.ServerVNCAllowed != nil && *config.ServerVNCAllowed,
|
||||
EnableSSHRoot: config.EnableSSHRoot,
|
||||
EnableSSHSFTP: config.EnableSSHSFTP,
|
||||
EnableSSHLocalPortForwarding: config.EnableSSHLocalPortForwarding,
|
||||
@@ -644,6 +645,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
|
||||
config.RosenpassEnabled,
|
||||
config.RosenpassPermissive,
|
||||
config.ServerSSHAllowed,
|
||||
config.ServerVNCAllowed,
|
||||
config.DisableClientRoutes,
|
||||
config.DisableServerRoutes,
|
||||
config.DisableDNS,
|
||||
|
||||
@@ -636,6 +636,9 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
|
||||
if g.internalConfig.SSHJWTCacheTTL != nil {
|
||||
configContent.WriteString(fmt.Sprintf("SSHJWTCacheTTL: %d\n", *g.internalConfig.SSHJWTCacheTTL))
|
||||
}
|
||||
if g.internalConfig.ServerVNCAllowed != nil {
|
||||
configContent.WriteString(fmt.Sprintf("ServerVNCAllowed: %v\n", *g.internalConfig.ServerVNCAllowed))
|
||||
}
|
||||
|
||||
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
|
||||
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))
|
||||
|
||||
@@ -862,6 +862,7 @@ func TestAddConfig_AllFieldsCovered(t *testing.T) {
|
||||
RosenpassEnabled: true,
|
||||
RosenpassPermissive: true,
|
||||
ServerSSHAllowed: &bTrue,
|
||||
ServerVNCAllowed: &bTrue,
|
||||
EnableSSHRoot: &bTrue,
|
||||
EnableSSHSFTP: &bTrue,
|
||||
EnableSSHLocalPortForwarding: &bTrue,
|
||||
|
||||
@@ -123,6 +123,7 @@ type EngineConfig struct {
|
||||
RosenpassPermissive bool
|
||||
|
||||
ServerSSHAllowed bool
|
||||
ServerVNCAllowed bool
|
||||
EnableSSHRoot *bool
|
||||
EnableSSHSFTP *bool
|
||||
EnableSSHLocalPortForwarding *bool
|
||||
@@ -205,6 +206,7 @@ type Engine struct {
|
||||
networkMonitor *networkmonitor.NetworkMonitor
|
||||
|
||||
sshServer sshServer
|
||||
vncSrv vncServer
|
||||
|
||||
statusRecorder *peer.Status
|
||||
|
||||
@@ -320,6 +322,10 @@ func (e *Engine) Stop() error {
|
||||
log.Warnf("failed to stop SSH server: %v", err)
|
||||
}
|
||||
|
||||
if err := e.stopVNCServer(); err != nil {
|
||||
log.Warnf("failed to stop VNC server: %v", err)
|
||||
}
|
||||
|
||||
e.cleanupSSHConfig()
|
||||
|
||||
if e.ingressGatewayMgr != nil {
|
||||
@@ -1010,6 +1016,7 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
||||
e.config.RosenpassEnabled,
|
||||
e.config.RosenpassPermissive,
|
||||
&e.config.ServerSSHAllowed,
|
||||
&e.config.ServerVNCAllowed,
|
||||
e.config.DisableClientRoutes,
|
||||
e.config.DisableServerRoutes,
|
||||
e.config.DisableDNS,
|
||||
@@ -1057,6 +1064,10 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
}
|
||||
}
|
||||
|
||||
if err := e.updateVNC(); err != nil {
|
||||
log.Warnf("failed handling VNC server setup: %v", err)
|
||||
}
|
||||
|
||||
state := e.statusRecorder.GetLocalPeerState()
|
||||
state.IP = e.wgInterface.Address().String()
|
||||
state.IPv6 = e.wgInterface.Address().IPv6String()
|
||||
@@ -1182,6 +1193,7 @@ func (e *Engine) receiveManagementEvents() {
|
||||
e.config.RosenpassEnabled,
|
||||
e.config.RosenpassPermissive,
|
||||
&e.config.ServerSSHAllowed,
|
||||
&e.config.ServerVNCAllowed,
|
||||
e.config.DisableClientRoutes,
|
||||
e.config.DisableServerRoutes,
|
||||
e.config.DisableDNS,
|
||||
@@ -1371,6 +1383,11 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
e.updateSSHServerAuth(networkMap.GetSshAuth())
|
||||
}
|
||||
|
||||
// VNC auth: always sync, including nil so cleared auth on the management
|
||||
// side is applied locally, and so it isn't skipped on the RemotePeersIsEmpty
|
||||
// cleanup path.
|
||||
e.updateVNCServerAuth(networkMap.GetVncAuth())
|
||||
|
||||
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
|
||||
excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, remotePeers)
|
||||
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
|
||||
@@ -1826,6 +1843,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
|
||||
e.config.RosenpassEnabled,
|
||||
e.config.RosenpassPermissive,
|
||||
&e.config.ServerSSHAllowed,
|
||||
&e.config.ServerVNCAllowed,
|
||||
e.config.DisableClientRoutes,
|
||||
e.config.DisableServerRoutes,
|
||||
e.config.DisableDNS,
|
||||
|
||||
236
client/internal/engine_vnc.go
Normal file
236
client/internal/engine_vnc.go
Normal file
@@ -0,0 +1,236 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/metrics"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
|
||||
)
|
||||
|
||||
const (
|
||||
vncExternalPort uint16 = 5900
|
||||
vncInternalPort uint16 = 25900
|
||||
)
|
||||
|
||||
type vncServer interface {
|
||||
Start(ctx context.Context, addr netip.AddrPort, network netip.Prefix) error
|
||||
Stop() error
|
||||
ActiveSessions() []vncserver.ActiveSessionInfo
|
||||
}
|
||||
|
||||
func (e *Engine) setupVNCPortRedirection() error {
|
||||
if e.firewall == nil || e.wgInterface == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
localAddr := e.wgInterface.Address().IP
|
||||
if !localAddr.IsValid() {
|
||||
return errors.New("invalid local NetBird address")
|
||||
}
|
||||
|
||||
if err := e.firewall.AddInboundDNAT(localAddr, firewallManager.ProtocolTCP, vncExternalPort, vncInternalPort); err != nil {
|
||||
return fmt.Errorf("add VNC port redirection: %w", err)
|
||||
}
|
||||
log.Infof("VNC port redirection: %s:%d -> %s:%d", localAddr, vncExternalPort, localAddr, vncInternalPort)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) cleanupVNCPortRedirection() error {
|
||||
if e.firewall == nil || e.wgInterface == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
localAddr := e.wgInterface.Address().IP
|
||||
if !localAddr.IsValid() {
|
||||
return errors.New("invalid local NetBird address")
|
||||
}
|
||||
|
||||
if err := e.firewall.RemoveInboundDNAT(localAddr, firewallManager.ProtocolTCP, vncExternalPort, vncInternalPort); err != nil {
|
||||
return fmt.Errorf("remove VNC port redirection: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateVNC handles starting/stopping the VNC server based on the config flag.
|
||||
func (e *Engine) updateVNC() error {
|
||||
if !e.config.ServerVNCAllowed {
|
||||
if e.vncSrv != nil {
|
||||
log.Info("VNC server disabled, stopping")
|
||||
}
|
||||
return e.stopVNCServer()
|
||||
}
|
||||
|
||||
if e.config.BlockInbound {
|
||||
log.Info("VNC server disabled because inbound connections are blocked")
|
||||
return e.stopVNCServer()
|
||||
}
|
||||
|
||||
if e.vncSrv != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return e.startVNCServer()
|
||||
}
|
||||
|
||||
func (e *Engine) startVNCServer() error {
|
||||
if e.wgInterface == nil {
|
||||
return errors.New("wg interface not initialized")
|
||||
}
|
||||
|
||||
capturer, injector, ok := newPlatformVNC()
|
||||
if !ok {
|
||||
log.Debug("VNC server not supported on this platform")
|
||||
return nil
|
||||
}
|
||||
|
||||
netbirdIP := e.wgInterface.Address().IP
|
||||
|
||||
srv := vncserver.New(capturer, injector, e.config.WgPrivateKey[:])
|
||||
if e.clientMetrics != nil {
|
||||
srv.SetSessionRecorder(func(t vncserver.SessionTick) {
|
||||
e.clientMetrics.RecordVNCSessionTick(e.ctx, metrics.VNCSessionTick{
|
||||
Period: t.Period,
|
||||
BytesOut: t.BytesOut,
|
||||
Writes: t.Writes,
|
||||
FBUs: t.FBUs,
|
||||
MaxFBUBytes: t.MaxFBUBytes,
|
||||
MaxFBURects: t.MaxFBURects,
|
||||
MaxWriteBytes: t.MaxWriteBytes,
|
||||
WriteNanos: t.WriteNanos,
|
||||
})
|
||||
})
|
||||
}
|
||||
if vncNeedsServiceMode() {
|
||||
log.Info("VNC: running in Session 0, enabling service mode (agent proxy)")
|
||||
srv.SetServiceMode(true)
|
||||
}
|
||||
|
||||
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||
srv.SetNetstackNet(netstackNet)
|
||||
}
|
||||
|
||||
listenAddr := netip.AddrPortFrom(netbirdIP, vncInternalPort)
|
||||
network := e.wgInterface.Address().Network
|
||||
if err := srv.Start(e.ctx, listenAddr, network); err != nil {
|
||||
return fmt.Errorf("start VNC server: %w", err)
|
||||
}
|
||||
|
||||
e.vncSrv = srv
|
||||
|
||||
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||
if registrar, ok := e.firewall.(interface {
|
||||
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||
}); ok {
|
||||
registrar.RegisterNetstackService(nftypes.TCP, vncInternalPort)
|
||||
log.Debugf("registered VNC service with netstack for TCP:%d", vncInternalPort)
|
||||
}
|
||||
}
|
||||
|
||||
if err := e.setupVNCPortRedirection(); err != nil {
|
||||
log.Warnf("setup VNC port redirection: %v", err)
|
||||
}
|
||||
|
||||
log.Info("VNC server enabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
// updateVNCServerAuth updates VNC fine-grained access control from management.
|
||||
func (e *Engine) updateVNCServerAuth(vncAuth *mgmProto.VNCAuth) {
|
||||
if vncAuth == nil || e.vncSrv == nil {
|
||||
return
|
||||
}
|
||||
|
||||
vncSrv, ok := e.vncSrv.(*vncserver.Server)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
protoUsers := vncAuth.GetAuthorizedUsers()
|
||||
authorizedUsers := make([]sshuserhash.UserIDHash, len(protoUsers))
|
||||
for i, hash := range protoUsers {
|
||||
if len(hash) != 16 {
|
||||
log.Warnf("invalid VNC auth hash length %d, expected 16", len(hash))
|
||||
return
|
||||
}
|
||||
authorizedUsers[i] = sshuserhash.UserIDHash(hash)
|
||||
}
|
||||
|
||||
machineUsers := make(map[string][]uint32)
|
||||
for osUser, indexes := range vncAuth.GetMachineUsers() {
|
||||
machineUsers[osUser] = indexes.GetIndexes()
|
||||
}
|
||||
|
||||
sessionPubKeys := make([]sshauth.SessionPubKey, 0, len(vncAuth.GetSessionPubKeys()))
|
||||
for _, e := range vncAuth.GetSessionPubKeys() {
|
||||
pub := e.GetPubKey()
|
||||
if len(pub) != 32 {
|
||||
log.Warnf("VNC session pubkey wrong length %d", len(pub))
|
||||
continue
|
||||
}
|
||||
hash := e.GetUserIdHash()
|
||||
if len(hash) != 16 {
|
||||
log.Warnf("VNC session user id hash wrong length %d", len(hash))
|
||||
continue
|
||||
}
|
||||
sessionPubKeys = append(sessionPubKeys, sshauth.SessionPubKey{
|
||||
PubKey: pub,
|
||||
UserIDHash: sshuserhash.UserIDHash(hash),
|
||||
})
|
||||
}
|
||||
|
||||
vncSrv.UpdateVNCAuth(&sshauth.Config{
|
||||
AuthorizedUsers: authorizedUsers,
|
||||
MachineUsers: machineUsers,
|
||||
SessionPubKeys: sessionPubKeys,
|
||||
})
|
||||
}
|
||||
|
||||
// GetVNCServerStatus returns whether the VNC server is running and the list
|
||||
// of active VNC sessions.
|
||||
func (e *Engine) GetVNCServerStatus() (enabled bool, sessions []vncserver.ActiveSessionInfo) {
|
||||
if e.vncSrv == nil {
|
||||
return false, nil
|
||||
}
|
||||
return true, e.vncSrv.ActiveSessions()
|
||||
}
|
||||
|
||||
func (e *Engine) stopVNCServer() error {
|
||||
if e.vncSrv == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := e.cleanupVNCPortRedirection(); err != nil {
|
||||
log.Warnf("cleanup VNC port redirection: %v", err)
|
||||
}
|
||||
|
||||
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||
if registrar, ok := e.firewall.(interface {
|
||||
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||
}); ok {
|
||||
registrar.UnregisterNetstackService(nftypes.TCP, vncInternalPort)
|
||||
}
|
||||
}
|
||||
|
||||
log.Info("stopping VNC server")
|
||||
err := e.vncSrv.Stop()
|
||||
e.vncSrv = nil
|
||||
if err != nil {
|
||||
return fmt.Errorf("stop VNC server: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
31
client/internal/engine_vnc_console_freebsd.go
Normal file
31
client/internal/engine_vnc_console_freebsd.go
Normal file
@@ -0,0 +1,31 @@
|
||||
//go:build freebsd
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
)
|
||||
|
||||
// newConsoleVNC builds the FreeBSD console fallback: vt(4) framebuffer
|
||||
// for capture, /dev/uinput for input. The uinput device requires the
|
||||
// `uinput` kernel module (`kldload uinput`); without it, input init
|
||||
// fails and we drop to a stub injector so the user still gets a
|
||||
// view-only screen mirror.
|
||||
func newConsoleVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, error) {
|
||||
poller := vncserver.NewFBPoller("")
|
||||
w, h := poller.Width(), poller.Height()
|
||||
if w == 0 || h == 0 {
|
||||
poller.Close()
|
||||
return nil, nil, fmt.Errorf("vt framebuffer init failed (vt may not allow mmap on this driver)")
|
||||
}
|
||||
if inj, err := vncserver.NewUInputInjector(w, h); err == nil {
|
||||
return poller, inj, nil
|
||||
} else {
|
||||
log.Infof("VNC console: uinput unavailable (%v); view-only mode. Run `kldload uinput` to enable input.", err)
|
||||
return poller, &vncserver.StubInputInjector{}, nil
|
||||
}
|
||||
}
|
||||
30
client/internal/engine_vnc_console_linux.go
Normal file
30
client/internal/engine_vnc_console_linux.go
Normal file
@@ -0,0 +1,30 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
)
|
||||
|
||||
// newConsoleVNC builds a framebuffer + uinput VNC backend for boxes
|
||||
// without a running X server. Used as the auto-fallback when
|
||||
// newPlatformVNC can't reach X. Returns an error when /dev/fb0 or
|
||||
// /dev/uinput aren't usable so the caller can drop back to a stub.
|
||||
func newConsoleVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, error) {
|
||||
poller := vncserver.NewFBPoller("")
|
||||
w, h := poller.Width(), poller.Height()
|
||||
if w == 0 || h == 0 {
|
||||
poller.Close()
|
||||
return nil, nil, fmt.Errorf("framebuffer capturer init failed (is /dev/fb0 readable?)")
|
||||
}
|
||||
inj, err := vncserver.NewUInputInjector(w, h)
|
||||
if err != nil {
|
||||
log.Debugf("uinput unavailable, falling back to view-only VNC: %v", err)
|
||||
return poller, &vncserver.StubInputInjector{}, nil
|
||||
}
|
||||
return poller, inj, nil
|
||||
}
|
||||
34
client/internal/engine_vnc_darwin.go
Normal file
34
client/internal/engine_vnc_darwin.go
Normal file
@@ -0,0 +1,34 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
)
|
||||
|
||||
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, bool) {
|
||||
capturer := vncserver.NewMacPoller()
|
||||
// Prompt for Screen Recording at server-enable time rather than first
|
||||
// client-connect. The native prompt is far easier for users to act on
|
||||
// in the moment they toggled VNC on than later when "the screen looks
|
||||
// like wallpaper" would otherwise be the only clue.
|
||||
vncserver.PrimeScreenCapturePermission()
|
||||
injector, err := vncserver.NewMacInputInjector()
|
||||
if err != nil {
|
||||
log.Debugf("VNC: macOS input injector: %v", err)
|
||||
return capturer, &vncserver.StubInputInjector{}, true
|
||||
}
|
||||
return capturer, injector, true
|
||||
}
|
||||
|
||||
// vncNeedsServiceMode reports whether the running process is a system
|
||||
// LaunchDaemon (root, parented by launchd). Daemons sit in the global
|
||||
// bootstrap namespace and cannot talk to WindowServer; we route capture
|
||||
// through a per-user agent in that case.
|
||||
func vncNeedsServiceMode() bool {
|
||||
return os.Geteuid() == 0 && os.Getppid() == 1
|
||||
}
|
||||
17
client/internal/engine_vnc_stub.go
Normal file
17
client/internal/engine_vnc_stub.go
Normal file
@@ -0,0 +1,17 @@
|
||||
//go:build js || ios || android
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
type vncServer interface{}
|
||||
|
||||
func (e *Engine) updateVNC() error { return nil }
|
||||
|
||||
func (e *Engine) updateVNCServerAuth(_ *mgmProto.VNCAuth) {
|
||||
// no-op on platforms without a VNC server
|
||||
}
|
||||
|
||||
func (e *Engine) stopVNCServer() error { return nil }
|
||||
13
client/internal/engine_vnc_windows.go
Normal file
13
client/internal/engine_vnc_windows.go
Normal file
@@ -0,0 +1,13 @@
|
||||
//go:build windows
|
||||
|
||||
package internal
|
||||
|
||||
import vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
|
||||
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, bool) {
|
||||
return vncserver.NewDesktopCapturer(), vncserver.NewWindowsInputInjector(), true
|
||||
}
|
||||
|
||||
func vncNeedsServiceMode() bool {
|
||||
return vncserver.GetCurrentSessionID() == 0
|
||||
}
|
||||
35
client/internal/engine_vnc_x11.go
Normal file
35
client/internal/engine_vnc_x11.go
Normal file
@@ -0,0 +1,35 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
)
|
||||
|
||||
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, bool) {
|
||||
// Prefer X11 when an X server is reachable. NewX11InputInjector probes
|
||||
// DISPLAY (and /proc) eagerly, so a non-nil error here means no X.
|
||||
injector, err := vncserver.NewX11InputInjector("")
|
||||
if err == nil {
|
||||
return vncserver.NewX11Poller(""), injector, true
|
||||
}
|
||||
log.Debugf("VNC: X11 not available: %v", err)
|
||||
|
||||
// Fallback for headless / pre-X states (kernel console, login manager
|
||||
// without X, physical server in recovery): stream the framebuffer and
|
||||
// inject input via /dev/uinput.
|
||||
consoleCap, consoleInj, err := newConsoleVNC()
|
||||
if err == nil {
|
||||
log.Infof("VNC: using framebuffer console capture (%dx%d)", consoleCap.Width(), consoleCap.Height())
|
||||
return consoleCap, consoleInj, true
|
||||
}
|
||||
log.Debugf("VNC: framebuffer console fallback unavailable: %v", err)
|
||||
|
||||
return &vncserver.StubCapturer{}, &vncserver.StubInputInjector{}, false
|
||||
}
|
||||
|
||||
func vncNeedsServiceMode() bool {
|
||||
return false
|
||||
}
|
||||
@@ -120,6 +120,36 @@ func (m *influxDBMetrics) RecordSyncDuration(_ context.Context, agentInfo AgentI
|
||||
m.trimLocked()
|
||||
}
|
||||
|
||||
func (m *influxDBMetrics) RecordVNCSessionTick(_ context.Context, agentInfo AgentInfo, tick VNCSessionTick) {
|
||||
tags := fmt.Sprintf("deployment_type=%s,version=%s,os=%s,arch=%s,peer_id=%s",
|
||||
agentInfo.DeploymentType.String(),
|
||||
agentInfo.Version,
|
||||
agentInfo.OS,
|
||||
agentInfo.Arch,
|
||||
agentInfo.peerID,
|
||||
)
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.samples = append(m.samples, influxSample{
|
||||
measurement: "netbird_vnc_traffic",
|
||||
tags: tags,
|
||||
fields: map[string]float64{
|
||||
"period_seconds": tick.Period.Seconds(),
|
||||
"bytes_out": float64(tick.BytesOut),
|
||||
"writes": float64(tick.Writes),
|
||||
"fbus": float64(tick.FBUs),
|
||||
"max_fbu_bytes": float64(tick.MaxFBUBytes),
|
||||
"max_fbu_rects": float64(tick.MaxFBURects),
|
||||
"max_write_bytes": float64(tick.MaxWriteBytes),
|
||||
"write_time_seconds": float64(tick.WriteNanos) / 1e9,
|
||||
},
|
||||
timestamp: time.Now(),
|
||||
})
|
||||
m.trimLocked()
|
||||
}
|
||||
|
||||
func (m *influxDBMetrics) RecordLoginDuration(_ context.Context, agentInfo AgentInfo, duration time.Duration, success bool) {
|
||||
result := "success"
|
||||
if !success {
|
||||
|
||||
@@ -59,6 +59,11 @@ type metricsImplementation interface {
|
||||
// RecordLoginDuration records how long the login to management took
|
||||
RecordLoginDuration(ctx context.Context, agentInfo AgentInfo, duration time.Duration, success bool)
|
||||
|
||||
// RecordVNCSessionTick records a periodic snapshot of one VNC
|
||||
// session's wire activity. Called once per metricsConn tick interval
|
||||
// (and once at session close), only when the tick saw activity.
|
||||
RecordVNCSessionTick(ctx context.Context, agentInfo AgentInfo, tick VNCSessionTick)
|
||||
|
||||
// Export exports metrics in InfluxDB line protocol format
|
||||
Export(w io.Writer) error
|
||||
|
||||
@@ -78,6 +83,21 @@ type ClientMetrics struct {
|
||||
pushCancel context.CancelFunc
|
||||
}
|
||||
|
||||
// VNCSessionTick is one sampling slice of a VNC session's wire activity.
|
||||
// BytesOut / Writes / FBUs / WriteNanos are deltas observed during this
|
||||
// tick; Max* fields are the high-water marks observed during the tick.
|
||||
// Period is the wall-clock duration the deltas cover.
|
||||
type VNCSessionTick struct {
|
||||
Period time.Duration
|
||||
BytesOut uint64
|
||||
Writes uint64
|
||||
FBUs uint64
|
||||
MaxFBUBytes uint64
|
||||
MaxFBURects uint64
|
||||
MaxWriteBytes uint64
|
||||
WriteNanos uint64
|
||||
}
|
||||
|
||||
// ConnectionStageTimestamps holds timestamps for each connection stage
|
||||
type ConnectionStageTimestamps struct {
|
||||
SignalingReceived time.Time // First signal received from remote peer (both initial and reconnection)
|
||||
@@ -127,6 +147,17 @@ func (c *ClientMetrics) RecordSyncDuration(ctx context.Context, duration time.Du
|
||||
c.impl.RecordSyncDuration(ctx, agentInfo, duration)
|
||||
}
|
||||
|
||||
// RecordVNCSessionTick records a periodic snapshot of one VNC session.
|
||||
func (c *ClientMetrics) RecordVNCSessionTick(ctx context.Context, tick VNCSessionTick) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.mu.RLock()
|
||||
agentInfo := c.agentInfo
|
||||
c.mu.RUnlock()
|
||||
c.impl.RecordVNCSessionTick(ctx, agentInfo, tick)
|
||||
}
|
||||
|
||||
// RecordLoginDuration records how long the login to management server took
|
||||
func (c *ClientMetrics) RecordLoginDuration(ctx context.Context, duration time.Duration, success bool) {
|
||||
if c == nil {
|
||||
|
||||
@@ -73,6 +73,9 @@ func (m *mockMetrics) RecordSyncDuration(_ context.Context, _ AgentInfo, _ time.
|
||||
func (m *mockMetrics) RecordLoginDuration(_ context.Context, _ AgentInfo, _ time.Duration, _ bool) {
|
||||
}
|
||||
|
||||
func (m *mockMetrics) RecordVNCSessionTick(_ context.Context, _ AgentInfo, _ VNCSessionTick) {
|
||||
}
|
||||
|
||||
func (m *mockMetrics) Export(w io.Writer) error {
|
||||
if m.exportData != "" {
|
||||
_, err := w.Write([]byte(m.exportData))
|
||||
|
||||
@@ -65,6 +65,7 @@ type ConfigInput struct {
|
||||
StateFilePath string
|
||||
PreSharedKey *string
|
||||
ServerSSHAllowed *bool
|
||||
ServerVNCAllowed *bool
|
||||
EnableSSHRoot *bool
|
||||
EnableSSHSFTP *bool
|
||||
EnableSSHLocalPortForwarding *bool
|
||||
@@ -116,6 +117,7 @@ type Config struct {
|
||||
RosenpassEnabled bool
|
||||
RosenpassPermissive bool
|
||||
ServerSSHAllowed *bool
|
||||
ServerVNCAllowed *bool
|
||||
EnableSSHRoot *bool
|
||||
EnableSSHSFTP *bool
|
||||
EnableSSHLocalPortForwarding *bool
|
||||
@@ -418,6 +420,21 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.ServerVNCAllowed != nil {
|
||||
if config.ServerVNCAllowed == nil || *input.ServerVNCAllowed != *config.ServerVNCAllowed {
|
||||
if *input.ServerVNCAllowed {
|
||||
log.Infof("enabling VNC server")
|
||||
} else {
|
||||
log.Infof("disabling VNC server")
|
||||
}
|
||||
config.ServerVNCAllowed = input.ServerVNCAllowed
|
||||
updated = true
|
||||
}
|
||||
} else if config.ServerVNCAllowed == nil {
|
||||
config.ServerVNCAllowed = util.True()
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.EnableSSHRoot != nil && input.EnableSSHRoot != config.EnableSSHRoot {
|
||||
if *input.EnableSSHRoot {
|
||||
log.Infof("enabling SSH root login")
|
||||
|
||||
@@ -188,7 +188,9 @@ func (d *Detector) triggerCallback(event EventType, cb func(event EventType), do
|
||||
}
|
||||
|
||||
doneChan := make(chan struct{})
|
||||
timeout := time.NewTimer(500 * time.Millisecond)
|
||||
// macOS forces sleep ~30s after kIOMessageSystemWillSleep, so block long
|
||||
// enough for teardown to finish while staying under that deadline.
|
||||
timeout := time.NewTimer(20 * time.Second)
|
||||
defer timeout.Stop()
|
||||
|
||||
go func() {
|
||||
|
||||
@@ -74,6 +74,14 @@ func New(filePath string) *Manager {
|
||||
}
|
||||
}
|
||||
|
||||
// FilePath returns the path of the underlying state file.
|
||||
func (m *Manager) FilePath() string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
return m.filePath
|
||||
}
|
||||
|
||||
// Start starts the state manager periodic save routine
|
||||
func (m *Manager) Start() {
|
||||
if m == nil {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -205,6 +205,8 @@ message LoginRequest {
|
||||
optional bool disableSSHAuth = 38;
|
||||
optional int32 sshJWTCacheTTL = 39;
|
||||
optional bool disable_ipv6 = 40;
|
||||
|
||||
optional bool serverVNCAllowed = 41;
|
||||
}
|
||||
|
||||
message LoginResponse {
|
||||
@@ -314,6 +316,8 @@ message GetConfigResponse {
|
||||
int32 sshJWTCacheTTL = 26;
|
||||
|
||||
bool disable_ipv6 = 27;
|
||||
|
||||
bool serverVNCAllowed = 28;
|
||||
}
|
||||
|
||||
// PeerState contains the latest state of a peer
|
||||
@@ -394,6 +398,22 @@ message SSHServerState {
|
||||
repeated SSHSessionInfo sessions = 2;
|
||||
}
|
||||
|
||||
// VNCSessionInfo contains information about an active VNC session
|
||||
message VNCSessionInfo {
|
||||
string remoteAddress = 1;
|
||||
string mode = 2;
|
||||
string username = 3;
|
||||
// userID is the Noise-verified session identity (hashed user ID from
|
||||
// the ACL session-key entry), empty when auth is disabled.
|
||||
string userID = 4;
|
||||
}
|
||||
|
||||
// VNCServerState contains the latest state of the VNC server
|
||||
message VNCServerState {
|
||||
bool enabled = 1;
|
||||
repeated VNCSessionInfo sessions = 2;
|
||||
}
|
||||
|
||||
// FullStatus contains the full state held by the Status instance
|
||||
message FullStatus {
|
||||
ManagementState managementState = 1;
|
||||
@@ -408,6 +428,7 @@ message FullStatus {
|
||||
|
||||
bool lazyConnectionEnabled = 9;
|
||||
SSHServerState sshServerState = 10;
|
||||
VNCServerState vncServerState = 11;
|
||||
}
|
||||
|
||||
// Networks
|
||||
@@ -678,6 +699,8 @@ message SetConfigRequest {
|
||||
optional bool disableSSHAuth = 33;
|
||||
optional int32 sshJWTCacheTTL = 34;
|
||||
optional bool disable_ipv6 = 35;
|
||||
|
||||
optional bool serverVNCAllowed = 36;
|
||||
}
|
||||
|
||||
message SetConfigResponse{}
|
||||
|
||||
@@ -376,6 +376,7 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
|
||||
config.RosenpassPermissive = msg.RosenpassPermissive
|
||||
config.DisableAutoConnect = msg.DisableAutoConnect
|
||||
config.ServerSSHAllowed = msg.ServerSSHAllowed
|
||||
config.ServerVNCAllowed = msg.ServerVNCAllowed
|
||||
config.NetworkMonitor = msg.NetworkMonitor
|
||||
config.DisableClientRoutes = msg.DisableClientRoutes
|
||||
config.DisableServerRoutes = msg.DisableServerRoutes
|
||||
@@ -1136,6 +1137,7 @@ func (s *Server) Status(
|
||||
pbFullStatus := fullStatus.ToProto()
|
||||
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
|
||||
pbFullStatus.SshServerState = s.getSSHServerState()
|
||||
pbFullStatus.VncServerState = s.getVNCServerState()
|
||||
statusResponse.FullStatus = pbFullStatus
|
||||
}
|
||||
|
||||
@@ -1175,6 +1177,37 @@ func (s *Server) getSSHServerState() *proto.SSHServerState {
|
||||
return sshServerState
|
||||
}
|
||||
|
||||
// getVNCServerState retrieves the current VNC server state.
|
||||
func (s *Server) getVNCServerState() *proto.VNCServerState {
|
||||
s.mutex.Lock()
|
||||
connectClient := s.connectClient
|
||||
s.mutex.Unlock()
|
||||
|
||||
if connectClient == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
engine := connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
enabled, sessions := engine.GetVNCServerStatus()
|
||||
pbSessions := make([]*proto.VNCSessionInfo, 0, len(sessions))
|
||||
for _, sess := range sessions {
|
||||
pbSessions = append(pbSessions, &proto.VNCSessionInfo{
|
||||
RemoteAddress: sess.RemoteAddress,
|
||||
Mode: sess.Mode,
|
||||
Username: sess.Username,
|
||||
UserID: sess.UserID,
|
||||
})
|
||||
}
|
||||
return &proto.VNCServerState{
|
||||
Enabled: enabled,
|
||||
Sessions: pbSessions,
|
||||
}
|
||||
}
|
||||
|
||||
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
|
||||
func (s *Server) GetPeerSSHHostKey(
|
||||
ctx context.Context,
|
||||
@@ -1531,6 +1564,7 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
|
||||
Mtu: int64(cfg.MTU),
|
||||
DisableAutoConnect: cfg.DisableAutoConnect,
|
||||
ServerSSHAllowed: *cfg.ServerSSHAllowed,
|
||||
ServerVNCAllowed: cfg.ServerVNCAllowed != nil && *cfg.ServerVNCAllowed,
|
||||
RosenpassEnabled: cfg.RosenpassEnabled,
|
||||
RosenpassPermissive: cfg.RosenpassPermissive,
|
||||
LazyConnectionEnabled: cfg.LazyConnectionEnabled,
|
||||
|
||||
@@ -58,6 +58,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
rosenpassEnabled := true
|
||||
rosenpassPermissive := true
|
||||
serverSSHAllowed := true
|
||||
serverVNCAllowed := true
|
||||
interfaceName := "utun100"
|
||||
wireguardPort := int64(51820)
|
||||
preSharedKey := "test-psk"
|
||||
@@ -83,6 +84,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
RosenpassEnabled: &rosenpassEnabled,
|
||||
RosenpassPermissive: &rosenpassPermissive,
|
||||
ServerSSHAllowed: &serverSSHAllowed,
|
||||
ServerVNCAllowed: &serverVNCAllowed,
|
||||
InterfaceName: &interfaceName,
|
||||
WireguardPort: &wireguardPort,
|
||||
OptionalPreSharedKey: &preSharedKey,
|
||||
@@ -127,6 +129,8 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
require.Equal(t, rosenpassPermissive, cfg.RosenpassPermissive)
|
||||
require.NotNil(t, cfg.ServerSSHAllowed)
|
||||
require.Equal(t, serverSSHAllowed, *cfg.ServerSSHAllowed)
|
||||
require.NotNil(t, cfg.ServerVNCAllowed)
|
||||
require.Equal(t, serverVNCAllowed, *cfg.ServerVNCAllowed)
|
||||
require.Equal(t, interfaceName, cfg.WgIface)
|
||||
require.Equal(t, int(wireguardPort), cfg.WgPort)
|
||||
require.Equal(t, preSharedKey, cfg.PreSharedKey)
|
||||
@@ -179,6 +183,7 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
|
||||
"RosenpassEnabled": true,
|
||||
"RosenpassPermissive": true,
|
||||
"ServerSSHAllowed": true,
|
||||
"ServerVNCAllowed": true,
|
||||
"InterfaceName": true,
|
||||
"WireguardPort": true,
|
||||
"OptionalPreSharedKey": true,
|
||||
@@ -240,6 +245,7 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) {
|
||||
"enable-rosenpass": "RosenpassEnabled",
|
||||
"rosenpass-permissive": "RosenpassPermissive",
|
||||
"allow-server-ssh": "ServerSSHAllowed",
|
||||
"allow-server-vnc": "ServerVNCAllowed",
|
||||
"interface-name": "InterfaceName",
|
||||
"wireguard-port": "WireguardPort",
|
||||
"preshared-key": "OptionalPreSharedKey",
|
||||
|
||||
@@ -15,13 +15,16 @@ const (
|
||||
DefaultUserIDClaim = "sub"
|
||||
// Wildcard is a special user ID that matches all users
|
||||
Wildcard = "*"
|
||||
// sessionPubKeyLen is the size of an X25519 static public key in bytes.
|
||||
sessionPubKeyLen = 32
|
||||
)
|
||||
|
||||
var (
|
||||
ErrEmptyUserID = errors.New("JWT user ID is empty")
|
||||
ErrUserNotAuthorized = errors.New("user is not authorized to access this peer")
|
||||
ErrNoMachineUserMapping = errors.New("no authorization mapping for OS user")
|
||||
ErrUserNotMappedToOSUser = errors.New("user is not authorized to login as OS user")
|
||||
ErrEmptyUserID = errors.New("JWT user ID is empty")
|
||||
ErrUserNotAuthorized = errors.New("user is not authorized to access this peer")
|
||||
ErrNoMachineUserMapping = errors.New("no authorization mapping for OS user")
|
||||
ErrUserNotMappedToOSUser = errors.New("user is not authorized to login as OS user")
|
||||
ErrSessionKeyNotKnown = errors.New("session pubkey not registered")
|
||||
)
|
||||
|
||||
// Authorizer handles SSH fine-grained access control authorization
|
||||
@@ -35,6 +38,12 @@ type Authorizer struct {
|
||||
// machineUsers maps OS login usernames to lists of authorized user indexes
|
||||
machineUsers map[string][]uint32
|
||||
|
||||
// sessionPubKeys maps an X25519 static public key (as map-safe
|
||||
// array) to the hashed user identity that key authenticates as.
|
||||
// Populated from management's temporary-access flow; used by VNC to
|
||||
// authenticate via the Noise_IK handshake.
|
||||
sessionPubKeys map[[sessionPubKeyLen]byte]sshuserhash.UserIDHash
|
||||
|
||||
// mu protects the list of users
|
||||
mu sync.RWMutex
|
||||
}
|
||||
@@ -50,13 +59,25 @@ type Config struct {
|
||||
// MachineUsers maps OS login usernames to indexes in AuthorizedUsers
|
||||
// If a user wants to login as a specific OS user, their index must be in the corresponding list
|
||||
MachineUsers map[string][]uint32
|
||||
|
||||
// SessionPubKeys binds ephemeral X25519 static public keys to hashed
|
||||
// user identities. Populated for VNC; ignored on the SSH side.
|
||||
SessionPubKeys []SessionPubKey
|
||||
}
|
||||
|
||||
// SessionPubKey is a single ephemeral-key entry: the 32-byte X25519
|
||||
// static public key plus the hashed user identity it authenticates as.
|
||||
type SessionPubKey struct {
|
||||
PubKey []byte
|
||||
UserIDHash sshuserhash.UserIDHash
|
||||
}
|
||||
|
||||
// NewAuthorizer creates a new SSH authorizer with empty configuration
|
||||
func NewAuthorizer() *Authorizer {
|
||||
a := &Authorizer{
|
||||
userIDClaim: DefaultUserIDClaim,
|
||||
machineUsers: make(map[string][]uint32),
|
||||
userIDClaim: DefaultUserIDClaim,
|
||||
machineUsers: make(map[string][]uint32),
|
||||
sessionPubKeys: make(map[[sessionPubKeyLen]byte]sshuserhash.UserIDHash),
|
||||
}
|
||||
|
||||
return a
|
||||
@@ -72,6 +93,7 @@ func (a *Authorizer) Update(config *Config) {
|
||||
a.userIDClaim = DefaultUserIDClaim
|
||||
a.authorizedUsers = []sshuserhash.UserIDHash{}
|
||||
a.machineUsers = make(map[string][]uint32)
|
||||
a.sessionPubKeys = make(map[[sessionPubKeyLen]byte]sshuserhash.UserIDHash)
|
||||
log.Info("SSH authorization cleared")
|
||||
return
|
||||
}
|
||||
@@ -94,8 +116,29 @@ func (a *Authorizer) Update(config *Config) {
|
||||
}
|
||||
a.machineUsers = machineUsers
|
||||
|
||||
log.Debugf("SSH auth: updated with %d authorized users, %d machine user mappings",
|
||||
len(config.AuthorizedUsers), len(machineUsers))
|
||||
sessionPubKeys := make(map[[sessionPubKeyLen]byte]sshuserhash.UserIDHash, len(config.SessionPubKeys))
|
||||
conflicted := make(map[[sessionPubKeyLen]byte]struct{})
|
||||
for _, e := range config.SessionPubKeys {
|
||||
if len(e.PubKey) != sessionPubKeyLen {
|
||||
continue
|
||||
}
|
||||
var key [sessionPubKeyLen]byte
|
||||
copy(key[:], e.PubKey)
|
||||
if _, bad := conflicted[key]; bad {
|
||||
continue
|
||||
}
|
||||
if existing, ok := sessionPubKeys[key]; ok && existing != e.UserIDHash {
|
||||
log.Warnf("SSH auth: session pubkey bound to conflicting user hashes; dropping binding")
|
||||
delete(sessionPubKeys, key)
|
||||
conflicted[key] = struct{}{}
|
||||
continue
|
||||
}
|
||||
sessionPubKeys[key] = e.UserIDHash
|
||||
}
|
||||
a.sessionPubKeys = sessionPubKeys
|
||||
|
||||
log.Debugf("SSH auth: updated with %d authorized users, %d machine user mappings, %d session pubkeys",
|
||||
len(config.AuthorizedUsers), len(machineUsers), len(sessionPubKeys))
|
||||
}
|
||||
|
||||
// Authorize validates if a user is authorized to login as the specified OS user.
|
||||
@@ -155,6 +198,38 @@ func (a *Authorizer) GetUserIDClaim() string {
|
||||
return a.userIDClaim
|
||||
}
|
||||
|
||||
// LookupSessionKey resolves a Noise-verified static public key to the
|
||||
// hashed user identity registered with it. Fails closed when the key is
|
||||
// unknown.
|
||||
func (a *Authorizer) LookupSessionKey(pubKey []byte) (sshuserhash.UserIDHash, error) {
|
||||
var zero sshuserhash.UserIDHash
|
||||
if len(pubKey) != sessionPubKeyLen {
|
||||
return zero, fmt.Errorf("session pubkey wrong length: %d", len(pubKey))
|
||||
}
|
||||
var key [sessionPubKeyLen]byte
|
||||
copy(key[:], pubKey)
|
||||
a.mu.RLock()
|
||||
hash, ok := a.sessionPubKeys[key]
|
||||
a.mu.RUnlock()
|
||||
if !ok {
|
||||
return zero, ErrSessionKeyNotKnown
|
||||
}
|
||||
return hash, nil
|
||||
}
|
||||
|
||||
// AuthorizeOSUserBySessionKey resolves the OS-user mapping for a session
|
||||
// key. Mirrors Authorize but skips the JWT-hash step since the key has
|
||||
// already been verified and the user identity hash is in hand.
|
||||
func (a *Authorizer) AuthorizeOSUserBySessionKey(userIDHash sshuserhash.UserIDHash, osUsername string) (string, error) {
|
||||
a.mu.RLock()
|
||||
defer a.mu.RUnlock()
|
||||
userIndex, found := a.findUserIndex(userIDHash)
|
||||
if !found {
|
||||
return "", fmt.Errorf("session user (hash: %s) not in authorized list for OS user %q: %w", userIDHash, osUsername, ErrUserNotAuthorized)
|
||||
}
|
||||
return a.checkMachineUserMapping("session", osUsername, userIndex)
|
||||
}
|
||||
|
||||
// findUserIndex finds the index of a hashed user ID in the authorized users list
|
||||
// Returns the index and true if found, 0 and false if not found
|
||||
func (a *Authorizer) findUserIndex(hashedUserID sshuserhash.UserIDHash) (int, bool) {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -610,3 +611,61 @@ func TestAuthorizer_Wildcard_WithPartialIndexes_AllowsAllUsers(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrUserNotAuthorized, "unauthorized user should be denied")
|
||||
}
|
||||
|
||||
func TestAuthorizer_LookupSessionKey_Valid(t *testing.T) {
|
||||
pub := bytesRepeat(0x11, sessionPubKeyLen)
|
||||
userHash, err := sshauth.HashUserID("alice")
|
||||
require.NoError(t, err)
|
||||
|
||||
a := NewAuthorizer()
|
||||
a.Update(&Config{
|
||||
AuthorizedUsers: []sshauth.UserIDHash{userHash},
|
||||
MachineUsers: map[string][]uint32{Wildcard: {0}},
|
||||
SessionPubKeys: []SessionPubKey{{PubKey: pub, UserIDHash: userHash}},
|
||||
})
|
||||
|
||||
got, err := a.LookupSessionKey(pub)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, userHash, got)
|
||||
|
||||
if _, err := a.AuthorizeOSUserBySessionKey(got, "alice"); err != nil {
|
||||
t.Fatalf("AuthorizeOSUserBySessionKey: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorizer_LookupSessionKey_UnknownPub(t *testing.T) {
|
||||
a := NewAuthorizer()
|
||||
a.Update(&Config{})
|
||||
_, err := a.LookupSessionKey(bytesRepeat(0x22, sessionPubKeyLen))
|
||||
require.ErrorIs(t, err, ErrSessionKeyNotKnown)
|
||||
}
|
||||
|
||||
func TestAuthorizer_LookupSessionKey_WrongLength(t *testing.T) {
|
||||
a := NewAuthorizer()
|
||||
_, err := a.LookupSessionKey([]byte("short"))
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestAuthorizer_LookupSessionKey_UpdateClears(t *testing.T) {
|
||||
pub := bytesRepeat(0x33, sessionPubKeyLen)
|
||||
userHash, err := sshauth.HashUserID("alice")
|
||||
require.NoError(t, err)
|
||||
|
||||
a := NewAuthorizer()
|
||||
a.Update(&Config{SessionPubKeys: []SessionPubKey{{PubKey: pub, UserIDHash: userHash}}})
|
||||
if _, err := a.LookupSessionKey(pub); err != nil {
|
||||
t.Fatalf("setup lookup: %v", err)
|
||||
}
|
||||
a.Update(&Config{})
|
||||
if _, err := a.LookupSessionKey(pub); !errors.Is(err, ErrSessionKeyNotKnown) {
|
||||
t.Fatalf("expected ErrSessionKeyNotKnown, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func bytesRepeat(b byte, n int) []byte {
|
||||
out := make([]byte, n)
|
||||
for i := range out {
|
||||
out[i] = b
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -131,6 +131,18 @@ type SSHServerStateOutput struct {
|
||||
Sessions []SSHSessionOutput `json:"sessions" yaml:"sessions"`
|
||||
}
|
||||
|
||||
type VNCSessionOutput struct {
|
||||
RemoteAddress string `json:"remoteAddress" yaml:"remoteAddress"`
|
||||
Mode string `json:"mode" yaml:"mode"`
|
||||
Username string `json:"username,omitempty" yaml:"username,omitempty"`
|
||||
UserID string `json:"userID,omitempty" yaml:"userID,omitempty"`
|
||||
}
|
||||
|
||||
type VNCServerStateOutput struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
Sessions []VNCSessionOutput `json:"sessions" yaml:"sessions"`
|
||||
}
|
||||
|
||||
type OutputOverview struct {
|
||||
Peers PeersStateOutput `json:"peers" yaml:"peers"`
|
||||
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
|
||||
@@ -153,6 +165,7 @@ type OutputOverview struct {
|
||||
LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"`
|
||||
ProfileName string `json:"profileName" yaml:"profileName"`
|
||||
SSHServerState SSHServerStateOutput `json:"sshServer" yaml:"sshServer"`
|
||||
VNCServerState VNCServerStateOutput `json:"vncServer" yaml:"vncServer"`
|
||||
}
|
||||
|
||||
// ConvertToStatusOutputOverview converts protobuf status to the output overview.
|
||||
@@ -173,6 +186,7 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, opts ConvertO
|
||||
|
||||
relayOverview := mapRelays(pbFullStatus.GetRelays())
|
||||
sshServerOverview := mapSSHServer(pbFullStatus.GetSshServerState())
|
||||
vncServerOverview := mapVNCServer(pbFullStatus.GetVncServerState())
|
||||
peersOverview := mapPeers(pbFullStatus.GetPeers(), opts.StatusFilter, opts.PrefixNamesFilter, opts.PrefixNamesFilterMap, opts.IPsFilter, opts.ConnectionTypeFilter)
|
||||
|
||||
overview := OutputOverview{
|
||||
@@ -197,6 +211,7 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, opts ConvertO
|
||||
LazyConnectionEnabled: pbFullStatus.GetLazyConnectionEnabled(),
|
||||
ProfileName: opts.ProfileName,
|
||||
SSHServerState: sshServerOverview,
|
||||
VNCServerState: vncServerOverview,
|
||||
}
|
||||
|
||||
if opts.Anonymize {
|
||||
@@ -271,6 +286,25 @@ func mapSSHServer(sshServerState *proto.SSHServerState) SSHServerStateOutput {
|
||||
}
|
||||
}
|
||||
|
||||
func mapVNCServer(state *proto.VNCServerState) VNCServerStateOutput {
|
||||
if state == nil {
|
||||
return VNCServerStateOutput{Sessions: []VNCSessionOutput{}}
|
||||
}
|
||||
sessions := make([]VNCSessionOutput, 0, len(state.GetSessions()))
|
||||
for _, sess := range state.GetSessions() {
|
||||
sessions = append(sessions, VNCSessionOutput{
|
||||
RemoteAddress: sess.GetRemoteAddress(),
|
||||
Mode: sess.GetMode(),
|
||||
Username: sess.GetUsername(),
|
||||
UserID: sess.GetUserID(),
|
||||
})
|
||||
}
|
||||
return VNCServerStateOutput{
|
||||
Enabled: state.GetEnabled(),
|
||||
Sessions: sessions,
|
||||
}
|
||||
}
|
||||
|
||||
func mapPeers(
|
||||
peers []*proto.PeerState,
|
||||
statusFilter string,
|
||||
@@ -533,6 +567,34 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
||||
}
|
||||
}
|
||||
|
||||
vncServerStatus := "Disabled"
|
||||
if o.VNCServerState.Enabled {
|
||||
vncSessionCount := len(o.VNCServerState.Sessions)
|
||||
if vncSessionCount > 0 {
|
||||
sessionWord := "session"
|
||||
if vncSessionCount > 1 {
|
||||
sessionWord = "sessions"
|
||||
}
|
||||
vncServerStatus = fmt.Sprintf("Enabled (%d active %s)", vncSessionCount, sessionWord)
|
||||
} else {
|
||||
vncServerStatus = "Enabled"
|
||||
}
|
||||
|
||||
if showSSHSessions && vncSessionCount > 0 {
|
||||
for _, sess := range o.VNCServerState.Sessions {
|
||||
var line string
|
||||
if sess.UserID != "" {
|
||||
line = fmt.Sprintf("[%s@%s -> %s] mode=%s",
|
||||
sess.UserID, sess.RemoteAddress, sess.Username, sess.Mode)
|
||||
} else {
|
||||
line = fmt.Sprintf("[%s] mode=%s user=%s",
|
||||
sess.RemoteAddress, sess.Mode, sess.Username)
|
||||
}
|
||||
vncServerStatus += "\n " + line
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
peersCountString := fmt.Sprintf("%d/%d Connected", o.Peers.Connected, o.Peers.Total)
|
||||
|
||||
var forwardingRulesString string
|
||||
@@ -563,6 +625,7 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
||||
"Quantum resistance: %s\n"+
|
||||
"Lazy connection: %s\n"+
|
||||
"SSH Server: %s\n"+
|
||||
"VNC Server: %s\n"+
|
||||
"Networks: %s\n"+
|
||||
"%s"+
|
||||
"Peers count: %s\n",
|
||||
@@ -581,6 +644,7 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
||||
rosenpassEnabledStatus,
|
||||
lazyConnectionEnabledStatus,
|
||||
sshServerStatus,
|
||||
vncServerStatus,
|
||||
networks,
|
||||
forwardingRulesString,
|
||||
peersCountString,
|
||||
@@ -960,6 +1024,19 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *OutputOverview) {
|
||||
overview.Relays.Details[i] = detail
|
||||
}
|
||||
|
||||
anonymizeNSServerGroups(a, overview)
|
||||
|
||||
for i, route := range overview.Networks {
|
||||
overview.Networks[i] = a.AnonymizeRoute(route)
|
||||
}
|
||||
|
||||
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
|
||||
|
||||
anonymizeEvents(a, overview)
|
||||
anonymizeServerSessions(a, overview)
|
||||
}
|
||||
|
||||
func anonymizeNSServerGroups(a *anonymize.Anonymizer, overview *OutputOverview) {
|
||||
for i, nsGroup := range overview.NSServerGroups {
|
||||
for j, domain := range nsGroup.Domains {
|
||||
overview.NSServerGroups[i].Domains[j] = a.AnonymizeDomain(domain)
|
||||
@@ -971,13 +1048,9 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *OutputOverview) {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for i, route := range overview.Networks {
|
||||
overview.Networks[i] = a.AnonymizeRoute(route)
|
||||
}
|
||||
|
||||
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
|
||||
|
||||
func anonymizeEvents(a *anonymize.Anonymizer, overview *OutputOverview) {
|
||||
for i, event := range overview.Events {
|
||||
overview.Events[i].Message = a.AnonymizeString(event.Message)
|
||||
overview.Events[i].UserMessage = a.AnonymizeString(event.UserMessage)
|
||||
@@ -986,13 +1059,23 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *OutputOverview) {
|
||||
event.Metadata[k] = a.AnonymizeString(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func anonymizeRemoteAddress(a *anonymize.Anonymizer, addr string) string {
|
||||
if host, port, err := net.SplitHostPort(addr); err == nil {
|
||||
return fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port)
|
||||
}
|
||||
return a.AnonymizeIPString(addr)
|
||||
}
|
||||
|
||||
func anonymizeServerSessions(a *anonymize.Anonymizer, overview *OutputOverview) {
|
||||
for i, session := range overview.SSHServerState.Sessions {
|
||||
if host, port, err := net.SplitHostPort(session.RemoteAddress); err == nil {
|
||||
overview.SSHServerState.Sessions[i].RemoteAddress = fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port)
|
||||
} else {
|
||||
overview.SSHServerState.Sessions[i].RemoteAddress = a.AnonymizeIPString(session.RemoteAddress)
|
||||
}
|
||||
overview.SSHServerState.Sessions[i].RemoteAddress = anonymizeRemoteAddress(a, session.RemoteAddress)
|
||||
overview.SSHServerState.Sessions[i].Command = a.AnonymizeString(session.Command)
|
||||
}
|
||||
for i, sess := range overview.VNCServerState.Sessions {
|
||||
overview.VNCServerState.Sessions[i].RemoteAddress = anonymizeRemoteAddress(a, sess.RemoteAddress)
|
||||
overview.VNCServerState.Sessions[i].Username = a.AnonymizeString(sess.Username)
|
||||
overview.VNCServerState.Sessions[i].UserID = a.AnonymizeString(sess.UserID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -240,6 +240,10 @@ var overview = OutputOverview{
|
||||
Enabled: false,
|
||||
Sessions: []SSHSessionOutput{},
|
||||
},
|
||||
VNCServerState: VNCServerStateOutput{
|
||||
Enabled: false,
|
||||
Sessions: []VNCSessionOutput{},
|
||||
},
|
||||
}
|
||||
|
||||
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
|
||||
@@ -404,6 +408,10 @@ func TestParsingToJSON(t *testing.T) {
|
||||
"sshServer":{
|
||||
"enabled":false,
|
||||
"sessions":[]
|
||||
},
|
||||
"vncServer":{
|
||||
"enabled":false,
|
||||
"sessions":[]
|
||||
}
|
||||
}`
|
||||
// @formatter:on
|
||||
@@ -513,6 +521,9 @@ profileName: ""
|
||||
sshServer:
|
||||
enabled: false
|
||||
sessions: []
|
||||
vncServer:
|
||||
enabled: false
|
||||
sessions: []
|
||||
`
|
||||
|
||||
assert.Equal(t, expectedYAML, yaml)
|
||||
@@ -582,6 +593,7 @@ Interface type: Kernel
|
||||
Quantum resistance: false
|
||||
Lazy connection: false
|
||||
SSH Server: Disabled
|
||||
VNC Server: Disabled
|
||||
Networks: 10.10.0.0/24
|
||||
Peers count: 2/2 Connected
|
||||
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
|
||||
@@ -607,6 +619,7 @@ Interface type: Kernel
|
||||
Quantum resistance: false
|
||||
Lazy connection: false
|
||||
SSH Server: Disabled
|
||||
VNC Server: Disabled
|
||||
Networks: 10.10.0.0/24
|
||||
Peers count: 2/2 Connected
|
||||
`
|
||||
|
||||
@@ -62,6 +62,7 @@ type Info struct {
|
||||
RosenpassEnabled bool
|
||||
RosenpassPermissive bool
|
||||
ServerSSHAllowed bool
|
||||
ServerVNCAllowed bool
|
||||
|
||||
DisableClientRoutes bool
|
||||
DisableServerRoutes bool
|
||||
@@ -83,6 +84,7 @@ type Info struct {
|
||||
func (i *Info) SetFlags(
|
||||
rosenpassEnabled, rosenpassPermissive bool,
|
||||
serverSSHAllowed *bool,
|
||||
serverVNCAllowed *bool,
|
||||
disableClientRoutes, disableServerRoutes,
|
||||
disableDNS, disableFirewall, blockLANAccess, blockInbound, disableIPv6, lazyConnectionEnabled bool,
|
||||
enableSSHRoot, enableSSHSFTP, enableSSHLocalPortForwarding, enableSSHRemotePortForwarding *bool,
|
||||
@@ -93,6 +95,9 @@ func (i *Info) SetFlags(
|
||||
if serverSSHAllowed != nil {
|
||||
i.ServerSSHAllowed = *serverSSHAllowed
|
||||
}
|
||||
if serverVNCAllowed != nil {
|
||||
i.ServerVNCAllowed = *serverVNCAllowed
|
||||
}
|
||||
|
||||
i.DisableClientRoutes = disableClientRoutes
|
||||
i.DisableServerRoutes = disableServerRoutes
|
||||
|
||||
@@ -249,6 +249,7 @@ type serviceClient struct {
|
||||
mQuit *systray.MenuItem
|
||||
mNetworks *systray.MenuItem
|
||||
mAllowSSH *systray.MenuItem
|
||||
mAllowVNC *systray.MenuItem
|
||||
mAutoConnect *systray.MenuItem
|
||||
mEnableRosenpass *systray.MenuItem
|
||||
mLazyConnEnabled *systray.MenuItem
|
||||
@@ -1045,6 +1046,7 @@ func (s *serviceClient) onTrayReady() {
|
||||
|
||||
s.mSettings = systray.AddMenuItem("Settings", disabledMenuDescr)
|
||||
s.mAllowSSH = s.mSettings.AddSubMenuItemCheckbox("Allow SSH", allowSSHMenuDescr, false)
|
||||
s.mAllowVNC = s.mSettings.AddSubMenuItemCheckbox("Allow VNC", allowVNCMenuDescr, false)
|
||||
s.mAutoConnect = s.mSettings.AddSubMenuItemCheckbox("Connect on Startup", autoConnectMenuDescr, false)
|
||||
s.mEnableRosenpass = s.mSettings.AddSubMenuItemCheckbox("Enable Quantum-Resistance", quantumResistanceMenuDescr, false)
|
||||
s.mLazyConnEnabled = s.mSettings.AddSubMenuItemCheckbox("Enable Lazy Connections", lazyConnMenuDescr, false)
|
||||
@@ -1452,6 +1454,7 @@ func protoConfigToConfig(cfg *proto.GetConfigResponse) *profilemanager.Config {
|
||||
|
||||
config.DisableAutoConnect = cfg.DisableAutoConnect
|
||||
config.ServerSSHAllowed = &cfg.ServerSSHAllowed
|
||||
config.ServerVNCAllowed = &cfg.ServerVNCAllowed
|
||||
config.RosenpassEnabled = cfg.RosenpassEnabled
|
||||
config.RosenpassPermissive = cfg.RosenpassPermissive
|
||||
config.DisableNotifications = &cfg.DisableNotifications
|
||||
@@ -1547,6 +1550,12 @@ func (s *serviceClient) loadSettings() {
|
||||
s.mAllowSSH.Uncheck()
|
||||
}
|
||||
|
||||
if cfg.ServerVNCAllowed {
|
||||
s.mAllowVNC.Check()
|
||||
} else {
|
||||
s.mAllowVNC.Uncheck()
|
||||
}
|
||||
|
||||
if cfg.DisableAutoConnect {
|
||||
s.mAutoConnect.Uncheck()
|
||||
} else {
|
||||
@@ -1586,6 +1595,7 @@ func (s *serviceClient) loadSettings() {
|
||||
func (s *serviceClient) updateConfig() error {
|
||||
disableAutoStart := !s.mAutoConnect.Checked()
|
||||
sshAllowed := s.mAllowSSH.Checked()
|
||||
vncAllowed := s.mAllowVNC.Checked()
|
||||
rosenpassEnabled := s.mEnableRosenpass.Checked()
|
||||
lazyConnectionEnabled := s.mLazyConnEnabled.Checked()
|
||||
blockInbound := s.mBlockInbound.Checked()
|
||||
@@ -1614,6 +1624,7 @@ func (s *serviceClient) updateConfig() error {
|
||||
Username: currUser.Username,
|
||||
DisableAutoConnect: &disableAutoStart,
|
||||
ServerSSHAllowed: &sshAllowed,
|
||||
ServerVNCAllowed: &vncAllowed,
|
||||
RosenpassEnabled: &rosenpassEnabled,
|
||||
LazyConnectionEnabled: &lazyConnectionEnabled,
|
||||
BlockInbound: &blockInbound,
|
||||
|
||||
@@ -2,6 +2,7 @@ package main
|
||||
|
||||
const (
|
||||
allowSSHMenuDescr = "Allow SSH connections"
|
||||
allowVNCMenuDescr = "Allow embedded VNC server"
|
||||
autoConnectMenuDescr = "Connect automatically when the service starts"
|
||||
quantumResistanceMenuDescr = "Enable post-quantum security via Rosenpass"
|
||||
lazyConnMenuDescr = "[Experimental] Enable lazy connections"
|
||||
|
||||
@@ -39,6 +39,8 @@ func (h *eventHandler) listen(ctx context.Context) {
|
||||
h.handleDisconnectClick()
|
||||
case <-h.client.mAllowSSH.ClickedCh:
|
||||
h.handleAllowSSHClick()
|
||||
case <-h.client.mAllowVNC.ClickedCh:
|
||||
h.handleAllowVNCClick()
|
||||
case <-h.client.mAutoConnect.ClickedCh:
|
||||
h.handleAutoConnectClick()
|
||||
case <-h.client.mEnableRosenpass.ClickedCh:
|
||||
@@ -134,6 +136,15 @@ func (h *eventHandler) handleAllowSSHClick() {
|
||||
|
||||
}
|
||||
|
||||
func (h *eventHandler) handleAllowVNCClick() {
|
||||
h.toggleCheckbox(h.client.mAllowVNC)
|
||||
if err := h.updateConfigWithErr(); err != nil {
|
||||
h.toggleCheckbox(h.client.mAllowVNC) // revert checkbox state on error
|
||||
log.Errorf("failed to update config: %v", err)
|
||||
h.client.notifier.Send("Error", "Failed to update VNC settings")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *eventHandler) handleAutoConnectClick() {
|
||||
h.toggleCheckbox(h.client.mAutoConnect)
|
||||
if err := h.updateConfigWithErr(); err != nil {
|
||||
|
||||
327
client/vnc/server/agent_darwin.go
Normal file
327
client/vnc/server/agent_darwin.go
Normal file
@@ -0,0 +1,327 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// darwinAgentManager spawns a per-user VNC agent on demand and keeps it
|
||||
// alive across multiple client connections within the same console-user
|
||||
// session. A new agent is spawned the first time a client connects, or
|
||||
// whenever the console user changes underneath us.
|
||||
//
|
||||
// Lifecycle is lazy by design: a daemon that never receives a VNC
|
||||
// connection never spawns anything. The trade-off versus an eager spawn
|
||||
// (the Windows model) is that the first VNC client pays the launchctl
|
||||
// asuser + listen-readiness wait, ~hundreds of milliseconds in practice.
|
||||
// That cost only repeats on user switch.
|
||||
type darwinAgentManager struct {
|
||||
mu sync.Mutex
|
||||
authToken string
|
||||
port uint16
|
||||
uid uint32
|
||||
running bool
|
||||
}
|
||||
|
||||
func newDarwinAgentManager(ctx context.Context) *darwinAgentManager {
|
||||
m := &darwinAgentManager{port: agentPort}
|
||||
go m.watchConsoleUser(ctx)
|
||||
return m
|
||||
}
|
||||
|
||||
// watchConsoleUser kills the cached agent whenever the console user
|
||||
// changes (logout, fast user switch, login window). Without it the daemon
|
||||
// keeps proxying to an agent whose TCC grant and WindowServer access
|
||||
// belong to a user who is no longer at the screen, so the new user only
|
||||
// ever sees the locked-screen wallpaper. Killing the agent breaks the
|
||||
// loopback TCP that the daemon proxies into, the client disconnects, and
|
||||
// the next reconnect runs ensure() against the new console uid.
|
||||
func (m *darwinAgentManager) watchConsoleUser(ctx context.Context) {
|
||||
t := time.NewTicker(2 * time.Second)
|
||||
defer t.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-t.C:
|
||||
uid, err := consoleUserID()
|
||||
m.mu.Lock()
|
||||
if !m.running {
|
||||
m.mu.Unlock()
|
||||
continue
|
||||
}
|
||||
if err != nil || uid != m.uid {
|
||||
prev := m.uid
|
||||
m.killLocked()
|
||||
m.mu.Unlock()
|
||||
if err != nil {
|
||||
log.Infof("console user gone (was uid=%d): %v; agent stopped", prev, err)
|
||||
} else {
|
||||
log.Infof("console user changed %d -> %d; agent stopped, will respawn on next connect", prev, uid)
|
||||
}
|
||||
continue
|
||||
}
|
||||
m.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ensure returns a token good for proxyToAgent. It spawns or respawns the
|
||||
// per-user agent process as needed and waits until it is listening on the
|
||||
// loopback port. Each ensure call is serialized so concurrent VNC clients
|
||||
// share the same agent.
|
||||
func (m *darwinAgentManager) ensure(ctx context.Context) (string, error) {
|
||||
consoleUID, err := consoleUserID()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("no console user: %w", err)
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if m.running && m.uid == consoleUID && vncAgentRunning() {
|
||||
return m.authToken, nil
|
||||
}
|
||||
m.killLocked()
|
||||
// Reap any stray external vnc-agent so the new token is the only one
|
||||
// the freshly spawned agent will accept on the loopback port.
|
||||
killAllVNCAgents()
|
||||
|
||||
token, err := generateAuthToken()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("generate agent auth token: %w", err)
|
||||
}
|
||||
if err := spawnAgentForUser(consoleUID, m.port, token); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := waitForAgent(ctx, m.port, 5*time.Second); err != nil {
|
||||
killAllVNCAgents()
|
||||
return "", fmt.Errorf("agent did not start listening: %w", err)
|
||||
}
|
||||
m.authToken = token
|
||||
m.uid = consoleUID
|
||||
m.running = true
|
||||
log.Infof("spawned VNC agent for console uid=%d on port %d", consoleUID, m.port)
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// stop terminates the spawned agent, if any. Intended for daemon shutdown.
|
||||
func (m *darwinAgentManager) stop() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.killLocked()
|
||||
}
|
||||
|
||||
func (m *darwinAgentManager) killLocked() {
|
||||
if !m.running {
|
||||
return
|
||||
}
|
||||
killAllVNCAgents()
|
||||
m.running = false
|
||||
m.authToken = ""
|
||||
m.uid = 0
|
||||
}
|
||||
|
||||
// errNoConsoleUser is the sentinel callers use to recognise the
|
||||
// "login window showing, no user signed in" state and surface it as a
|
||||
// distinct condition to the VNC client.
|
||||
var errNoConsoleUser = errors.New("no user logged into console")
|
||||
|
||||
// consoleUserID returns the uid of the user currently sitting at the
|
||||
// console (the one whose Aqua session is active). Returns
|
||||
// errNoConsoleUser when nobody is logged in: at the login window
|
||||
// /dev/console is owned by root.
|
||||
func consoleUserID() (uint32, error) {
|
||||
info, err := os.Stat("/dev/console")
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("stat /dev/console: %w", err)
|
||||
}
|
||||
st, ok := info.Sys().(*syscall.Stat_t)
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("/dev/console stat has unexpected type")
|
||||
}
|
||||
if st.Uid == 0 {
|
||||
return 0, errNoConsoleUser
|
||||
}
|
||||
return st.Uid, nil
|
||||
}
|
||||
|
||||
// spawnAgentForUser uses launchctl asuser to start a netbird vnc-agent
|
||||
// process inside the target user's launchd bootstrap namespace. That is
|
||||
// the only spawn mode on macOS that gives the child access to the user's
|
||||
// WindowServer. The agent's stderr is relogged into the daemon log so
|
||||
// startup failures are not silently lost when the readiness check times
|
||||
// out.
|
||||
func spawnAgentForUser(uid uint32, port uint16, token string) error {
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolve own executable: %w", err)
|
||||
}
|
||||
cmd := exec.Command(
|
||||
"/bin/launchctl", "asuser", strconv.FormatUint(uint64(uid), 10),
|
||||
exe, vncAgentSubcommand, "--port", strconv.FormatUint(uint64(port), 10),
|
||||
)
|
||||
cmd.Env = append(os.Environ(), agentTokenEnvVar+"="+token)
|
||||
stderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("agent stderr pipe: %w", err)
|
||||
}
|
||||
if err := cmd.Start(); err != nil {
|
||||
return fmt.Errorf("launchctl asuser: %w", err)
|
||||
}
|
||||
go func() {
|
||||
defer stderr.Close()
|
||||
relogAgentStream(stderr)
|
||||
}()
|
||||
go func() { _ = cmd.Wait() }()
|
||||
return nil
|
||||
}
|
||||
|
||||
// waitForAgent dials the loopback port until the agent answers. Used to
|
||||
// gate proxy attempts until the spawned process has finished its Start.
|
||||
func waitForAgent(ctx context.Context, port uint16, wait time.Duration) error {
|
||||
addr := fmt.Sprintf("127.0.0.1:%d", port)
|
||||
deadline := time.Now().Add(wait)
|
||||
for time.Now().Before(deadline) {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
c, err := net.DialTimeout("tcp", addr, 200*time.Millisecond)
|
||||
if err == nil {
|
||||
_ = c.Close()
|
||||
return nil
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
return fmt.Errorf("timeout dialing %s", addr)
|
||||
}
|
||||
|
||||
// vncAgentRunning reports whether any vnc-agent process exists on the
|
||||
// system. The daemon owns the only port-15900 listener model, so any
|
||||
// match is "the" agent.
|
||||
func vncAgentRunning() bool {
|
||||
pids, err := vncAgentPIDs()
|
||||
if err != nil {
|
||||
log.Debugf("scan for vnc-agent: %v", err)
|
||||
return false
|
||||
}
|
||||
return len(pids) > 0
|
||||
}
|
||||
|
||||
// killAllVNCAgents sends SIGTERM to every process whose argv contains
|
||||
// "vnc-agent", waits briefly for them to exit, and escalates to SIGKILL
|
||||
// for any that remain. We enumerate kern.proc.all rather than
|
||||
// kern.proc.uid because launchctl asuser preserves the caller's uid
|
||||
// (root) on the spawned child, so a uid-scoped filter would never match.
|
||||
func killAllVNCAgents() {
|
||||
pids, err := vncAgentPIDs()
|
||||
if err != nil {
|
||||
log.Debugf("scan for vnc-agent: %v", err)
|
||||
return
|
||||
}
|
||||
for _, pid := range pids {
|
||||
_ = syscall.Kill(pid, syscall.SIGTERM)
|
||||
}
|
||||
if len(pids) == 0 {
|
||||
return
|
||||
}
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
remaining, _ := vncAgentPIDs()
|
||||
if len(remaining) == 0 {
|
||||
return
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
leftover, _ := vncAgentPIDs()
|
||||
for _, pid := range leftover {
|
||||
_ = syscall.Kill(pid, syscall.SIGKILL)
|
||||
}
|
||||
}
|
||||
|
||||
// vncAgentPIDs returns the pids of vnc-agent subprocesses spawned from
|
||||
// this binary. Matches exactly on argv[0] == our own executable path
|
||||
// AND argv[1] == "vnc-agent" so unrelated processes that happen to have
|
||||
// the same name elsewhere in argv are not targeted. Skips pid 0 and 1
|
||||
// defensively.
|
||||
func vncAgentPIDs() ([]int, error) {
|
||||
procs, err := unix.SysctlKinfoProcSlice("kern.proc.all")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sysctl kern.proc.all: %w", err)
|
||||
}
|
||||
ownExe, err := os.Executable()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolve own executable: %w", err)
|
||||
}
|
||||
var out []int
|
||||
for i := range procs {
|
||||
pid := int(procs[i].Proc.P_pid)
|
||||
if pid <= 1 {
|
||||
continue
|
||||
}
|
||||
argv, err := procArgv(pid)
|
||||
if err != nil || !argvIsVNCAgent(argv, ownExe) {
|
||||
continue
|
||||
}
|
||||
out = append(out, pid)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// procArgv reads the kernel's stored argv for pid via the kern.procargs2
|
||||
// sysctl. Format: 4-byte argc, then argv[0..argc) each NUL-terminated,
|
||||
// then envp, then padding. We only need argv so we stop after argc.
|
||||
func procArgv(pid int) ([]string, error) {
|
||||
raw, err := unix.SysctlRaw("kern.procargs2", pid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(raw) < 4 {
|
||||
return nil, fmt.Errorf("procargs2 truncated")
|
||||
}
|
||||
argc := int(raw[0]) | int(raw[1])<<8 | int(raw[2])<<16 | int(raw[3])<<24
|
||||
body := raw[4:]
|
||||
// Skip the executable path (NUL-terminated) and any zero padding that
|
||||
// follows before argv[0].
|
||||
end := bytes.IndexByte(body, 0)
|
||||
if end < 0 {
|
||||
return nil, fmt.Errorf("procargs2 path unterminated")
|
||||
}
|
||||
body = body[end+1:]
|
||||
for len(body) > 0 && body[0] == 0 {
|
||||
body = body[1:]
|
||||
}
|
||||
args := make([]string, 0, argc)
|
||||
for i := 0; i < argc; i++ {
|
||||
end := bytes.IndexByte(body, 0)
|
||||
if end < 0 {
|
||||
break
|
||||
}
|
||||
args = append(args, string(body[:end]))
|
||||
body = body[end+1:]
|
||||
}
|
||||
return args, nil
|
||||
}
|
||||
|
||||
// argvIsVNCAgent reports whether argv belongs to a vnc-agent subprocess
|
||||
// spawned from our binary. Requires argv[0] to match ownExe exactly and
|
||||
// argv[1] to be the vnc-agent subcommand. Matches the spawn shape in
|
||||
// spawnAgentForUser and rejects anything else.
|
||||
func argvIsVNCAgent(argv []string, ownExe string) bool {
|
||||
if len(argv) < 2 || ownExe == "" {
|
||||
return false
|
||||
}
|
||||
return argv[0] == ownExe && argv[1] == vncAgentSubcommand
|
||||
}
|
||||
178
client/vnc/server/agent_ipc.go
Normal file
178
client/vnc/server/agent_ipc.go
Normal file
@@ -0,0 +1,178 @@
|
||||
//go:build darwin || windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
crand "crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// agentPort is the TCP loopback port on which a per-session VNC agent
|
||||
// listens. The daemon dials this port and presents agentToken before
|
||||
// proxying VNC bytes. The choice of TCP (rather than a Unix socket or
|
||||
// named pipe) is intentional: it lets the same proxy/handshake code
|
||||
// run on every platform; the token does the access control.
|
||||
agentPort uint16 = 15900
|
||||
|
||||
// agentTokenLen is the size of the random per-spawn token in bytes.
|
||||
agentTokenLen = 32
|
||||
|
||||
// agentTokenEnvVar names the environment variable the daemon uses to
|
||||
// hand the per-spawn token to the agent child. Out-of-band channels
|
||||
// like this keep the secret out of the command line, where listings
|
||||
// such as `ps` or Windows tasklist would expose it.
|
||||
agentTokenEnvVar = "NB_VNC_AGENT_TOKEN" // #nosec G101 -- env var name, not a credential
|
||||
|
||||
// vncAgentSubcommand is the CLI subcommand the daemon invokes to start
|
||||
// the per-session agent process. Must match cmd.vncAgentCmd.Use in
|
||||
// client/cmd/vnc_agent.go.
|
||||
vncAgentSubcommand = "vnc-agent"
|
||||
)
|
||||
|
||||
// generateAuthToken returns a fresh hex-encoded random token for one
|
||||
// daemon→agent session. The daemon hands this to the spawned agent
|
||||
// out-of-band (env var on Windows) and verifies it on every connection
|
||||
// the agent accepts.
|
||||
func generateAuthToken() (string, error) {
|
||||
b := make([]byte, agentTokenLen)
|
||||
if _, err := crand.Read(b); err != nil {
|
||||
return "", fmt.Errorf("read random: %w", err)
|
||||
}
|
||||
return hex.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// proxyToAgent dials the per-session agent on TCP loopback, writes the
|
||||
// raw token bytes, and then copies bytes in both directions until either
|
||||
// side closes. The token has to land on the wire before any VNC byte so
|
||||
// the agent's listening Server can apply verifyAgentToken before letting
|
||||
// real RFB traffic through.
|
||||
func proxyToAgent(ctx context.Context, client net.Conn, port uint16, authToken string) {
|
||||
defer client.Close()
|
||||
|
||||
addr := fmt.Sprintf("127.0.0.1:%d", port)
|
||||
agentConn, err := dialAgentWithRetry(ctx, addr)
|
||||
if err != nil {
|
||||
log.Warnf("proxy cannot reach agent at %s: %v", addr, err)
|
||||
return
|
||||
}
|
||||
defer agentConn.Close()
|
||||
|
||||
tokenBytes, err := hex.DecodeString(authToken)
|
||||
if err != nil || len(tokenBytes) != agentTokenLen {
|
||||
log.Warnf("invalid auth token (len=%d): %v", len(tokenBytes), err)
|
||||
return
|
||||
}
|
||||
if _, err := agentConn.Write(tokenBytes); err != nil {
|
||||
log.Warnf("send auth token to agent: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("proxy connected to agent, starting bidirectional copy")
|
||||
done := make(chan struct{}, 2)
|
||||
cp := func(label string, dst, src net.Conn) {
|
||||
n, err := io.Copy(dst, src)
|
||||
log.Debugf("proxy %s: %d bytes, err=%v", label, n, err)
|
||||
done <- struct{}{}
|
||||
}
|
||||
go cp("client→agent", agentConn, client)
|
||||
go cp("agent→client", client, agentConn)
|
||||
<-done
|
||||
}
|
||||
|
||||
// relogAgentStream reads log lines from the agent's stderr and re-emits
|
||||
// them through the daemon's logrus, so the merged log keeps a single
|
||||
// format. JSON lines (the agent's normal output) are parsed and dispatched
|
||||
// by level; plain-text lines (cobra errors, panic traces) are forwarded
|
||||
// verbatim so early-startup failures stay visible.
|
||||
func relogAgentStream(r io.Reader) {
|
||||
entry := log.WithField("component", "vnc-agent")
|
||||
scanner := bufio.NewScanner(r)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
if line[0] != '{' {
|
||||
entry.Warn(string(line))
|
||||
continue
|
||||
}
|
||||
var m map[string]any
|
||||
if err := json.Unmarshal(line, &m); err != nil {
|
||||
entry.Warn(string(line))
|
||||
continue
|
||||
}
|
||||
msg, _ := m["msg"].(string)
|
||||
if msg == "" {
|
||||
continue
|
||||
}
|
||||
fields := make(log.Fields)
|
||||
for k, v := range m {
|
||||
switch k {
|
||||
case "msg", "level", "time", "func":
|
||||
continue
|
||||
case "caller":
|
||||
fields["source"] = v
|
||||
default:
|
||||
fields[k] = v
|
||||
}
|
||||
}
|
||||
e := entry.WithFields(fields)
|
||||
switch m["level"] {
|
||||
case "error":
|
||||
e.Error(msg)
|
||||
case "warning":
|
||||
e.Warn(msg)
|
||||
case "debug":
|
||||
e.Debug(msg)
|
||||
case "trace":
|
||||
e.Trace(msg)
|
||||
default:
|
||||
e.Info(msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// dialAgentWithRetry retries the loopback connect for up to ~10 s so the
|
||||
// daemon does not race the agent's first listen. Returns the live conn or
|
||||
// the final error. Aborts early when ctx is cancelled so a Stop() during
|
||||
// service-mode startup doesn't leave a goroutine sleeping for 10 s.
|
||||
func dialAgentWithRetry(ctx context.Context, addr string) (net.Conn, error) {
|
||||
var d net.Dialer
|
||||
var lastErr error
|
||||
for range 50 {
|
||||
if err := ctx.Err(); err != nil {
|
||||
if lastErr == nil {
|
||||
lastErr = err
|
||||
}
|
||||
return nil, lastErr
|
||||
}
|
||||
dialCtx, cancel := context.WithTimeout(ctx, time.Second)
|
||||
c, err := d.DialContext(dialCtx, "tcp", addr)
|
||||
cancel()
|
||||
if err == nil {
|
||||
return c, nil
|
||||
}
|
||||
lastErr = err
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if errors.Is(lastErr, context.Canceled) || errors.Is(lastErr, context.DeadlineExceeded) {
|
||||
lastErr = ctx.Err()
|
||||
}
|
||||
return nil, lastErr
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
return nil, lastErr
|
||||
}
|
||||
692
client/vnc/server/agent_windows.go
Normal file
692
client/vnc/server/agent_windows.go
Normal file
@@ -0,0 +1,692 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
const (
|
||||
stillActive = 259
|
||||
|
||||
tokenPrimary = 1
|
||||
securityImpersonation = 2
|
||||
tokenSessionID = 12
|
||||
|
||||
createUnicodeEnvironment = 0x00000400
|
||||
createNoWindow = 0x08000000
|
||||
createSuspended = 0x00000004
|
||||
createBreakawayFromJob = 0x01000000
|
||||
)
|
||||
|
||||
var (
|
||||
kernel32 = windows.NewLazySystemDLL("kernel32.dll")
|
||||
advapi32 = windows.NewLazySystemDLL("advapi32.dll")
|
||||
userenv = windows.NewLazySystemDLL("userenv.dll")
|
||||
|
||||
procWTSGetActiveConsoleSessionId = kernel32.NewProc("WTSGetActiveConsoleSessionId")
|
||||
procCreateJobObjectW = kernel32.NewProc("CreateJobObjectW")
|
||||
procSetInformationJobObject = kernel32.NewProc("SetInformationJobObject")
|
||||
procAssignProcessToJobObject = kernel32.NewProc("AssignProcessToJobObject")
|
||||
procSetTokenInformation = advapi32.NewProc("SetTokenInformation")
|
||||
procCreateEnvironmentBlock = userenv.NewProc("CreateEnvironmentBlock")
|
||||
procDestroyEnvironmentBlock = userenv.NewProc("DestroyEnvironmentBlock")
|
||||
|
||||
wtsapi32 = windows.NewLazySystemDLL("wtsapi32.dll")
|
||||
procWTSEnumerateSessionsW = wtsapi32.NewProc("WTSEnumerateSessionsW")
|
||||
procWTSFreeMemory = wtsapi32.NewProc("WTSFreeMemory")
|
||||
procWTSQuerySessionInformation = wtsapi32.NewProc("WTSQuerySessionInformationW")
|
||||
|
||||
iphlpapi = windows.NewLazySystemDLL("iphlpapi.dll")
|
||||
procGetExtendedTcpTable = iphlpapi.NewProc("GetExtendedTcpTable")
|
||||
)
|
||||
|
||||
// GetCurrentSessionID returns the session ID of the current process.
|
||||
func GetCurrentSessionID() uint32 {
|
||||
var token windows.Token
|
||||
if err := windows.OpenProcessToken(windows.CurrentProcess(),
|
||||
windows.TOKEN_QUERY, &token); err != nil {
|
||||
return 0
|
||||
}
|
||||
defer token.Close()
|
||||
var id uint32
|
||||
var ret uint32
|
||||
_ = windows.GetTokenInformation(token, windows.TokenSessionId,
|
||||
(*byte)(unsafe.Pointer(&id)), 4, &ret)
|
||||
return id
|
||||
}
|
||||
|
||||
func getConsoleSessionID() uint32 {
|
||||
r, _, _ := procWTSGetActiveConsoleSessionId.Call()
|
||||
return uint32(r)
|
||||
}
|
||||
|
||||
const (
|
||||
wtsActive = 0
|
||||
wtsConnected = 1
|
||||
wtsDisconnected = 4
|
||||
)
|
||||
|
||||
// getActiveSessionID returns the session ID of the best session to attach to.
|
||||
// On a Windows Server with no console display attached, session 1 still
|
||||
// reports WTSActive (login screen "owns" the console), so a naive
|
||||
// first-active-wins pick lands on a session with no actual rendering.
|
||||
// Preference order:
|
||||
// 1. Active session with a user logged in (RDP user in session ≥2)
|
||||
// 2. Active session without a user (console at login screen)
|
||||
// 3. Console session ID
|
||||
func getActiveSessionID() uint32 {
|
||||
var sessionInfo uintptr
|
||||
var count uint32
|
||||
|
||||
r, _, _ := procWTSEnumerateSessionsW.Call(
|
||||
0, // WTS_CURRENT_SERVER_HANDLE
|
||||
0, // reserved
|
||||
1, // version
|
||||
uintptr(unsafe.Pointer(&sessionInfo)),
|
||||
uintptr(unsafe.Pointer(&count)),
|
||||
)
|
||||
if r == 0 || count == 0 {
|
||||
return getConsoleSessionID()
|
||||
}
|
||||
defer func() { _, _, _ = procWTSFreeMemory.Call(sessionInfo) }()
|
||||
|
||||
type wtsSession struct {
|
||||
SessionID uint32
|
||||
Station *uint16
|
||||
State uint32
|
||||
}
|
||||
sessions := unsafe.Slice((*wtsSession)(unsafe.Pointer(sessionInfo)), count)
|
||||
|
||||
var withUser uint32
|
||||
var withUserFound bool
|
||||
var anyActive uint32
|
||||
var anyActiveFound bool
|
||||
for _, s := range sessions {
|
||||
if s.SessionID == 0 {
|
||||
continue
|
||||
}
|
||||
if s.State != wtsActive {
|
||||
continue
|
||||
}
|
||||
if !anyActiveFound {
|
||||
anyActive = s.SessionID
|
||||
anyActiveFound = true
|
||||
}
|
||||
if !withUserFound && wtsSessionHasUser(s.SessionID) {
|
||||
withUser = s.SessionID
|
||||
withUserFound = true
|
||||
}
|
||||
}
|
||||
if withUserFound {
|
||||
return withUser
|
||||
}
|
||||
if anyActiveFound {
|
||||
return anyActive
|
||||
}
|
||||
return getConsoleSessionID()
|
||||
}
|
||||
|
||||
// reapOrphanOnPort finds any process listening on 127.0.0.1:port and, if
|
||||
// it's a netbird vnc-agent left over from a previous service instance,
|
||||
// terminates it. Verified by image-name match so we never kill an
|
||||
// unrelated process that happens to use the same port.
|
||||
func reapOrphanOnPort(port uint16) {
|
||||
pid := tcpListenerPID(port)
|
||||
if pid == 0 || pid == uint32(windows.GetCurrentProcessId()) {
|
||||
return
|
||||
}
|
||||
h, err := windows.OpenProcess(windows.PROCESS_QUERY_LIMITED_INFORMATION|windows.PROCESS_TERMINATE|windows.SYNCHRONIZE, false, pid)
|
||||
if err != nil {
|
||||
log.Warnf("reap on port %d: open PID=%d: %v", port, pid, err)
|
||||
return
|
||||
}
|
||||
defer func() { _ = windows.CloseHandle(h) }()
|
||||
if !isOurAgentProcess(h) {
|
||||
log.Warnf("reap on port %d: PID=%d is not a netbird vnc-agent, leaving it alone", port, pid)
|
||||
return
|
||||
}
|
||||
if err := windows.TerminateProcess(h, 0); err != nil {
|
||||
log.Warnf("reap on port %d: terminate PID=%d: %v", port, pid, err)
|
||||
return
|
||||
}
|
||||
log.Infof("reaped orphan vnc-agent PID=%d holding port %d", pid, port)
|
||||
}
|
||||
|
||||
// isOurAgentProcess returns true if the given process handle points at a
|
||||
// netbird.exe binary at the same path as the current process. We compare
|
||||
// full paths (case-insensitive on Windows) so co-installed netbird binaries
|
||||
// from a different install dir or unrelated apps named netbird.exe don't
|
||||
// get killed.
|
||||
func isOurAgentProcess(h windows.Handle) bool {
|
||||
var size uint32 = windows.MAX_PATH
|
||||
buf := make([]uint16, size)
|
||||
if err := windows.QueryFullProcessImageName(h, 0, &buf[0], &size); err != nil {
|
||||
return false
|
||||
}
|
||||
target := strings.ToLower(windows.UTF16ToString(buf[:size]))
|
||||
selfExe, err := os.Executable()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return target == strings.ToLower(selfExe)
|
||||
}
|
||||
|
||||
// tcpListenerPID returns the PID of the process listening on 127.0.0.1:port,
|
||||
// or 0 if none. Uses GetExtendedTcpTable with TCP_TABLE_OWNER_PID_LISTENER.
|
||||
func tcpListenerPID(port uint16) uint32 {
|
||||
const tcpTableOwnerPidListener = 3
|
||||
const afInet = 2
|
||||
|
||||
// MIB_TCPROW_OWNER_PID layout: state(4) + localAddr(4) + localPort(4) +
|
||||
// remoteAddr(4) + remotePort(4) + owningPid(4) = 24 bytes.
|
||||
const rowSize = 24
|
||||
|
||||
var size uint32
|
||||
_, _, _ = procGetExtendedTcpTable.Call(0, uintptr(unsafe.Pointer(&size)), 0, afInet, tcpTableOwnerPidListener, 0)
|
||||
if size == 0 {
|
||||
return 0
|
||||
}
|
||||
buf := make([]byte, size)
|
||||
r, _, _ := procGetExtendedTcpTable.Call(
|
||||
uintptr(unsafe.Pointer(&buf[0])),
|
||||
uintptr(unsafe.Pointer(&size)),
|
||||
0, afInet, tcpTableOwnerPidListener, 0,
|
||||
)
|
||||
if r != 0 {
|
||||
return 0
|
||||
}
|
||||
count := binary.LittleEndian.Uint32(buf[:4])
|
||||
for i := uint32(0); i < count; i++ {
|
||||
off := 4 + int(i)*rowSize
|
||||
if off+rowSize > len(buf) {
|
||||
break
|
||||
}
|
||||
// localPort is stored big-endian in the high 16 bits of a 32-bit field.
|
||||
localPort := uint16(buf[off+8])<<8 | uint16(buf[off+9])
|
||||
if localPort != port {
|
||||
continue
|
||||
}
|
||||
localAddr := binary.LittleEndian.Uint32(buf[off+4 : off+8])
|
||||
// 0x0100007f == 127.0.0.1 in network byte order on little-endian.
|
||||
// We accept 0.0.0.0 too in case the orphan bound to all interfaces.
|
||||
if localAddr != 0x0100007f && localAddr != 0 {
|
||||
continue
|
||||
}
|
||||
return binary.LittleEndian.Uint32(buf[off+20 : off+24])
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// wtsSessionHasUser returns true if the session has a non-empty user name,
|
||||
// i.e. someone is logged in (vs. the login/Welcome screen). The console
|
||||
// session at the lock screen has WTSUserName == "".
|
||||
const wtsUserName = 5
|
||||
|
||||
func wtsSessionHasUser(sessionID uint32) bool {
|
||||
var buf uintptr
|
||||
var bytesReturned uint32
|
||||
r, _, _ := procWTSQuerySessionInformation.Call(
|
||||
0, // WTS_CURRENT_SERVER_HANDLE
|
||||
uintptr(sessionID),
|
||||
uintptr(wtsUserName),
|
||||
uintptr(unsafe.Pointer(&buf)),
|
||||
uintptr(unsafe.Pointer(&bytesReturned)),
|
||||
)
|
||||
if r == 0 || buf == 0 {
|
||||
return false
|
||||
}
|
||||
defer func() { _, _, _ = procWTSFreeMemory.Call(buf) }()
|
||||
// First UTF-16 code unit non-zero ⇒ non-empty username.
|
||||
return *(*uint16)(unsafe.Pointer(buf)) != 0
|
||||
}
|
||||
|
||||
// getSystemTokenForSession duplicates the current SYSTEM token and sets its
|
||||
// session ID so the spawned process runs in the target session. Using a SYSTEM
|
||||
// token gives access to both Default and Winlogon desktops plus UIPI bypass.
|
||||
func getSystemTokenForSession(sessionID uint32) (windows.Token, error) {
|
||||
var cur windows.Token
|
||||
if err := windows.OpenProcessToken(windows.CurrentProcess(),
|
||||
windows.MAXIMUM_ALLOWED, &cur); err != nil {
|
||||
return 0, fmt.Errorf("OpenProcessToken: %w", err)
|
||||
}
|
||||
defer cur.Close()
|
||||
|
||||
var dup windows.Token
|
||||
if err := windows.DuplicateTokenEx(cur, windows.MAXIMUM_ALLOWED, nil,
|
||||
securityImpersonation, tokenPrimary, &dup); err != nil {
|
||||
return 0, fmt.Errorf("DuplicateTokenEx: %w", err)
|
||||
}
|
||||
|
||||
sid := sessionID
|
||||
r, _, err := procSetTokenInformation.Call(
|
||||
uintptr(dup),
|
||||
uintptr(tokenSessionID),
|
||||
uintptr(unsafe.Pointer(&sid)),
|
||||
unsafe.Sizeof(sid),
|
||||
)
|
||||
if r == 0 {
|
||||
dup.Close()
|
||||
return 0, fmt.Errorf("SetTokenInformation(SessionId=%d): %w", sessionID, err)
|
||||
}
|
||||
return dup, nil
|
||||
}
|
||||
|
||||
// injectEnvVar appends a KEY=VALUE entry to a Unicode environment block.
|
||||
// The block is a sequence of null-terminated UTF-16 strings, terminated by
|
||||
// an extra null. Returns the new []uint16 backing slice; the caller must
|
||||
// hold the returned slice alive until CreateProcessAsUser completes.
|
||||
func injectEnvVar(envBlock uintptr, key, value string) []uint16 {
|
||||
entry := key + "=" + value
|
||||
|
||||
// Walk the existing block to find its total length.
|
||||
ptr := (*uint16)(unsafe.Pointer(envBlock))
|
||||
var totalChars int
|
||||
for {
|
||||
ch := *(*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + uintptr(totalChars)*2))
|
||||
if ch == 0 {
|
||||
// Check for double-null terminator.
|
||||
next := *(*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + uintptr(totalChars+1)*2))
|
||||
totalChars++
|
||||
if next == 0 {
|
||||
// End of block (don't count the final null yet, we'll rebuild).
|
||||
break
|
||||
}
|
||||
} else {
|
||||
totalChars++
|
||||
}
|
||||
}
|
||||
|
||||
entryUTF16, _ := windows.UTF16FromString(entry)
|
||||
// New block: existing entries + new entry (null-terminated) + final null.
|
||||
newLen := totalChars + len(entryUTF16) + 1
|
||||
newBlock := make([]uint16, newLen)
|
||||
// Copy existing entries (up to but not including the final null).
|
||||
for i := range totalChars {
|
||||
newBlock[i] = *(*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + uintptr(i)*2))
|
||||
}
|
||||
copy(newBlock[totalChars:], entryUTF16)
|
||||
newBlock[newLen-1] = 0 // final null terminator
|
||||
|
||||
return newBlock
|
||||
}
|
||||
|
||||
func spawnAgentInSession(sessionID uint32, port uint16, authToken string, jobHandle windows.Handle) (windows.Handle, error) {
|
||||
token, err := getSystemTokenForSession(sessionID)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("get SYSTEM token for session %d: %w", sessionID, err)
|
||||
}
|
||||
defer token.Close()
|
||||
|
||||
var envBlock uintptr
|
||||
r, _, e := procCreateEnvironmentBlock.Call(
|
||||
uintptr(unsafe.Pointer(&envBlock)),
|
||||
uintptr(token),
|
||||
0,
|
||||
)
|
||||
if r == 0 {
|
||||
// Without an environment block we cannot inject NB_VNC_AGENT_TOKEN;
|
||||
// the agent would start unauthenticated. Abort instead of launching.
|
||||
return 0, fmt.Errorf("CreateEnvironmentBlock: %w", e)
|
||||
}
|
||||
defer func() { _, _, _ = procDestroyEnvironmentBlock.Call(envBlock) }()
|
||||
|
||||
// Inject the auth token into the environment block so it doesn't appear
|
||||
// in the process command line (visible via tasklist/wmic). injectedBlock
|
||||
// must stay alive until CreateProcessAsUser returns.
|
||||
injectedBlock := injectEnvVar(envBlock, agentTokenEnvVar, authToken)
|
||||
|
||||
exePath, err := os.Executable()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("get executable path: %w", err)
|
||||
}
|
||||
|
||||
cmdLine := fmt.Sprintf(`"%s" %s --port %d`, exePath, vncAgentSubcommand, port)
|
||||
cmdLineW, err := windows.UTF16PtrFromString(cmdLine)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("UTF16 cmdline: %w", err)
|
||||
}
|
||||
|
||||
// Create an inheritable pipe for the agent's stderr so we can relog
|
||||
// its output in the service process.
|
||||
var sa windows.SecurityAttributes
|
||||
sa.Length = uint32(unsafe.Sizeof(sa))
|
||||
sa.InheritHandle = 1
|
||||
|
||||
var stderrRead, stderrWrite windows.Handle
|
||||
if err := windows.CreatePipe(&stderrRead, &stderrWrite, &sa, 0); err != nil {
|
||||
return 0, fmt.Errorf("create stderr pipe: %w", err)
|
||||
}
|
||||
// The read end must NOT be inherited by the child.
|
||||
_ = windows.SetHandleInformation(stderrRead, windows.HANDLE_FLAG_INHERIT, 0)
|
||||
|
||||
desktop, _ := windows.UTF16PtrFromString(`WinSta0\Default`)
|
||||
si := windows.StartupInfo{
|
||||
Cb: uint32(unsafe.Sizeof(windows.StartupInfo{})),
|
||||
Desktop: desktop,
|
||||
Flags: windows.STARTF_USESHOWWINDOW | windows.STARTF_USESTDHANDLES,
|
||||
ShowWindow: 0,
|
||||
StdErr: stderrWrite,
|
||||
StdOutput: stderrWrite,
|
||||
}
|
||||
var pi windows.ProcessInformation
|
||||
|
||||
var envPtr *uint16
|
||||
if len(injectedBlock) > 0 {
|
||||
envPtr = &injectedBlock[0]
|
||||
} else if envBlock != 0 {
|
||||
envPtr = (*uint16)(unsafe.Pointer(envBlock))
|
||||
}
|
||||
|
||||
// CREATE_SUSPENDED so we can assign the process to our Job Object
|
||||
// before it executes. Without this the agent could spawn its own child
|
||||
// processes and have them inherit the SCM service-job (not ours), or
|
||||
// briefly listen on the agent port before we tear it down on rollback.
|
||||
// CREATE_BREAKAWAY_FROM_JOB lets the child leave the SCM-managed
|
||||
// service job; harmless if that job allows breakaway, and is required
|
||||
// before AssignProcessToJobObject can succeed in the no-nested-jobs case.
|
||||
err = windows.CreateProcessAsUser(
|
||||
token, nil, cmdLineW,
|
||||
nil, nil, true, // inheritHandles=true for the pipe
|
||||
createUnicodeEnvironment|createNoWindow|createSuspended|createBreakawayFromJob,
|
||||
envPtr, nil, &si, &pi,
|
||||
)
|
||||
runtime.KeepAlive(injectedBlock)
|
||||
// Close the write end in the parent so reads will get EOF when the child exits.
|
||||
_ = windows.CloseHandle(stderrWrite)
|
||||
if err != nil {
|
||||
_ = windows.CloseHandle(stderrRead)
|
||||
return 0, fmt.Errorf("CreateProcessAsUser: %w", err)
|
||||
}
|
||||
|
||||
if jobHandle != 0 {
|
||||
r, _, e := procAssignProcessToJobObject.Call(uintptr(jobHandle), uintptr(pi.Process))
|
||||
if r == 0 {
|
||||
log.Warnf("assign agent to job object: %v (orphan possible on service crash)", e)
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := windows.ResumeThread(pi.Thread); err != nil {
|
||||
log.Warnf("resume agent main thread: %v", err)
|
||||
}
|
||||
_ = windows.CloseHandle(pi.Thread)
|
||||
|
||||
// Relog agent output in the service with a [vnc-agent] prefix.
|
||||
go relogAgentOutput(stderrRead)
|
||||
|
||||
log.Infof("spawned agent PID=%d in session %d on port %d", pi.ProcessId, sessionID, port)
|
||||
return pi.Process, nil
|
||||
}
|
||||
|
||||
// sessionManager monitors the active console session and ensures a VNC agent
|
||||
// process is running in it. When the session changes (e.g., user switch, RDP
|
||||
// connect/disconnect), it kills the old agent and spawns a new one.
|
||||
type sessionManager struct {
|
||||
port uint16
|
||||
mu sync.Mutex
|
||||
agentProc windows.Handle
|
||||
everSpawned bool
|
||||
agentStartedAt time.Time
|
||||
spawnFailures int
|
||||
nextSpawnAt time.Time
|
||||
sessionID uint32
|
||||
authToken string
|
||||
done chan struct{}
|
||||
// jobHandle owns the agent processes via a Windows Job Object with
|
||||
// JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE. When the service exits or crashes,
|
||||
// the OS closes the handle and terminates every assigned agent: no
|
||||
// orphaned listeners holding the agent port across restarts.
|
||||
jobHandle windows.Handle
|
||||
}
|
||||
|
||||
func newSessionManager(port uint16) *sessionManager {
|
||||
m := &sessionManager{port: port, sessionID: ^uint32(0), done: make(chan struct{})}
|
||||
if h, err := createKillOnCloseJob(); err != nil {
|
||||
log.Warnf("create job object for vnc-agent (orphan agents possible after crash): %v", err)
|
||||
} else {
|
||||
m.jobHandle = h
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// createKillOnCloseJob returns a Job Object configured so that closing its
|
||||
// handle (process exit or explicit Close) terminates every process assigned
|
||||
// to it. Used to keep orphaned vnc-agent processes from outliving the service.
|
||||
func createKillOnCloseJob() (windows.Handle, error) {
|
||||
r, _, e := procCreateJobObjectW.Call(0, 0)
|
||||
if r == 0 {
|
||||
return 0, fmt.Errorf("CreateJobObject: %w", e)
|
||||
}
|
||||
job := windows.Handle(r)
|
||||
|
||||
// JOBOBJECT_EXTENDED_LIMIT_INFORMATION on amd64 = 144 bytes.
|
||||
//
|
||||
// JOBOBJECT_BASIC_LIMIT_INFORMATION (64 bytes with alignment padding)
|
||||
// PerProcessUserTimeLimit LARGE_INTEGER off 0
|
||||
// PerJobUserTimeLimit LARGE_INTEGER off 8
|
||||
// LimitFlags DWORD off 16
|
||||
// [4 byte pad to align SIZE_T]
|
||||
// MinimumWorkingSetSize SIZE_T off 24
|
||||
// MaximumWorkingSetSize SIZE_T off 32
|
||||
// ActiveProcessLimit DWORD off 40
|
||||
// [4 byte pad to align ULONG_PTR]
|
||||
// Affinity ULONG_PTR off 48
|
||||
// PriorityClass DWORD off 56
|
||||
// SchedulingClass DWORD off 60
|
||||
// IO_COUNTERS (48) + 4 * SIZE_T (32) = 144 total.
|
||||
//
|
||||
// We only set LimitFlags; the rest stays zero.
|
||||
const sizeofExtended = 144
|
||||
const offsetLimitFlags = 16
|
||||
const jobObjectExtendedLimitInformation = 9
|
||||
const jobObjectLimitKillOnJobClose = 0x00002000
|
||||
|
||||
var info [sizeofExtended]byte
|
||||
binary.LittleEndian.PutUint32(info[offsetLimitFlags:offsetLimitFlags+4], jobObjectLimitKillOnJobClose)
|
||||
|
||||
r, _, e = procSetInformationJobObject.Call(
|
||||
uintptr(job),
|
||||
uintptr(jobObjectExtendedLimitInformation),
|
||||
uintptr(unsafe.Pointer(&info[0])),
|
||||
uintptr(sizeofExtended),
|
||||
)
|
||||
if r == 0 {
|
||||
_ = windows.CloseHandle(job)
|
||||
return 0, fmt.Errorf("SetInformationJobObject(KILL_ON_JOB_CLOSE): %w", e)
|
||||
}
|
||||
return job, nil
|
||||
}
|
||||
|
||||
// AuthToken returns the current agent authentication token.
|
||||
func (m *sessionManager) AuthToken() string {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.authToken
|
||||
}
|
||||
|
||||
// Stop signals the session manager to exit its polling loop and closes the
|
||||
// Job Object handle, which Windows uses as the trigger to terminate every
|
||||
// agent process this manager spawned.
|
||||
func (m *sessionManager) Stop() {
|
||||
select {
|
||||
case <-m.done:
|
||||
default:
|
||||
close(m.done)
|
||||
}
|
||||
m.mu.Lock()
|
||||
if m.jobHandle != 0 {
|
||||
_ = windows.CloseHandle(m.jobHandle)
|
||||
m.jobHandle = 0
|
||||
}
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
func (m *sessionManager) run() {
|
||||
ticker := time.NewTicker(2 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
if !m.tick() {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-m.done:
|
||||
m.mu.Lock()
|
||||
m.killAgent()
|
||||
m.mu.Unlock()
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// tick performs one session/agent-state update. Returns false if the manager
|
||||
// should permanently stop (e.g. missing SYSTEM privileges).
|
||||
func (m *sessionManager) tick() bool {
|
||||
sid := getActiveSessionID()
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.handleSessionChange(sid)
|
||||
m.reapExitedAgent()
|
||||
return m.maybeSpawnAgent(sid)
|
||||
}
|
||||
|
||||
func (m *sessionManager) handleSessionChange(sid uint32) {
|
||||
if sid == m.sessionID {
|
||||
return
|
||||
}
|
||||
log.Infof("active session changed: %d -> %d", m.sessionID, sid)
|
||||
m.killAgent()
|
||||
m.sessionID = sid
|
||||
}
|
||||
|
||||
func (m *sessionManager) reapExitedAgent() {
|
||||
if m.agentProc == 0 {
|
||||
return
|
||||
}
|
||||
var code uint32
|
||||
if err := windows.GetExitCodeProcess(m.agentProc, &code); err != nil {
|
||||
log.Debugf("GetExitCodeProcess: %v", err)
|
||||
return
|
||||
}
|
||||
if code == stillActive {
|
||||
return
|
||||
}
|
||||
m.scheduleNextSpawn(code, time.Since(m.agentStartedAt))
|
||||
if err := windows.CloseHandle(m.agentProc); err != nil {
|
||||
log.Debugf("close agent handle: %v", err)
|
||||
}
|
||||
m.agentProc = 0
|
||||
}
|
||||
|
||||
// scheduleNextSpawn applies an exponential backoff on fast crashes (<5s) and
|
||||
// resets immediately otherwise.
|
||||
func (m *sessionManager) scheduleNextSpawn(exitCode uint32, lifetime time.Duration) {
|
||||
if lifetime < 5*time.Second {
|
||||
m.spawnFailures++
|
||||
backoff := time.Duration(1<<min(m.spawnFailures, 5)) * time.Second
|
||||
if backoff > 30*time.Second {
|
||||
backoff = 30 * time.Second
|
||||
}
|
||||
m.nextSpawnAt = time.Now().Add(backoff)
|
||||
log.Warnf("agent exited (code=%d) after %v, retrying in %v (failures=%d)", exitCode, lifetime.Round(time.Millisecond), backoff, m.spawnFailures)
|
||||
return
|
||||
}
|
||||
m.spawnFailures = 0
|
||||
m.nextSpawnAt = time.Time{}
|
||||
log.Infof("agent exited (code=%d) after %v, respawning", exitCode, lifetime.Round(time.Second))
|
||||
}
|
||||
|
||||
// maybeSpawnAgent spawns a new agent if there's no current one and the backoff
|
||||
// window has elapsed. Returns false to permanently stop the manager when the
|
||||
// service lacks the privileges needed to spawn cross-session.
|
||||
func (m *sessionManager) maybeSpawnAgent(sid uint32) bool {
|
||||
if m.agentProc != 0 || sid == 0xFFFFFFFF || !time.Now().After(m.nextSpawnAt) {
|
||||
return true
|
||||
}
|
||||
// Reap any orphan still holding the agent port from a previous
|
||||
// service instance, only on our very first spawn. Once we own
|
||||
// an agent, we manage its lifecycle ourselves and never need to
|
||||
// kill an unknown listener; if a kill+respawn races on port
|
||||
// release, the spawn-failure backoff handles it without forcing
|
||||
// a synchronous wait or duplicate kill.
|
||||
if !m.everSpawned {
|
||||
reapOrphanOnPort(m.port)
|
||||
}
|
||||
token, err := generateAuthToken()
|
||||
if err != nil {
|
||||
log.Warnf("generate agent auth token: %v", err)
|
||||
return true
|
||||
}
|
||||
m.authToken = token
|
||||
h, err := spawnAgentInSession(sid, m.port, m.authToken, m.jobHandle)
|
||||
if err != nil {
|
||||
m.authToken = ""
|
||||
if errors.Is(err, windows.ERROR_PRIVILEGE_NOT_HELD) {
|
||||
// SE_TCB_NAME (token-impersonation across sessions) is only
|
||||
// granted to SYSTEM. Without it spawnAgent will fail every 2
|
||||
// seconds forever: log once and give up.
|
||||
log.Warnf("VNC service mode disabled: agent spawn requires SYSTEM privileges (got: %v)", err)
|
||||
return false
|
||||
}
|
||||
log.Warnf("spawn agent in session %d: %v", sid, err)
|
||||
return true
|
||||
}
|
||||
m.agentProc = h
|
||||
m.agentStartedAt = time.Now()
|
||||
m.everSpawned = true
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *sessionManager) killAgent() {
|
||||
if m.agentProc == 0 {
|
||||
return
|
||||
}
|
||||
_ = windows.TerminateProcess(m.agentProc, 0)
|
||||
_ = windows.CloseHandle(m.agentProc)
|
||||
m.agentProc = 0
|
||||
log.Info("killed old agent")
|
||||
}
|
||||
|
||||
// relogAgentOutput reads log lines from the agent's stderr pipe and
|
||||
// relogs them with the service's formatter. The *os.File owns the
|
||||
// underlying handle, so closing it suffices.
|
||||
func relogAgentOutput(pipe windows.Handle) {
|
||||
f := os.NewFile(uintptr(pipe), "vnc-agent-stderr")
|
||||
defer func() { _ = f.Close() }()
|
||||
relogAgentStream(f)
|
||||
}
|
||||
|
||||
// logCleanupCall invokes a Windows syscall used solely as a cleanup primitive
|
||||
// (CloseClipboard, ReleaseDC, etc.) and logs failures at trace level. The
|
||||
// indirection lets us satisfy errcheck without scattering ignored returns at
|
||||
// each call site, while still capturing diagnostic info when the OS reports
|
||||
// a failure.
|
||||
func logCleanupCall(name string, proc *windows.LazyProc) {
|
||||
r, _, err := proc.Call()
|
||||
if r == 0 && err != nil && err != windows.NTE_OP_OK {
|
||||
log.Tracef("%s: %v", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
// logCleanupCallArgs is logCleanupCall with one argument; common pattern for
|
||||
// release-by-handle syscalls.
|
||||
func logCleanupCallArgs(name string, proc *windows.LazyProc, args ...uintptr) {
|
||||
r, _, err := proc.Call(args...)
|
||||
if r == 0 && err != nil && err != windows.NTE_OP_OK {
|
||||
log.Tracef("%s: %v", name, err)
|
||||
}
|
||||
}
|
||||
643
client/vnc/server/capture_darwin.go
Normal file
643
client/vnc/server/capture_darwin.go
Normal file
@@ -0,0 +1,643 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/maphash"
|
||||
"image"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var darwinCaptureOnce sync.Once
|
||||
|
||||
var (
|
||||
cgMainDisplayID func() uint32
|
||||
cgDisplayPixelsWide func(uint32) uintptr
|
||||
cgDisplayPixelsHigh func(uint32) uintptr
|
||||
cgDisplayCreateImage func(uint32) uintptr
|
||||
cgImageGetWidth func(uintptr) uintptr
|
||||
cgImageGetHeight func(uintptr) uintptr
|
||||
cgImageGetBytesPerRow func(uintptr) uintptr
|
||||
cgImageGetBitsPerPixel func(uintptr) uintptr
|
||||
cgImageGetDataProvider func(uintptr) uintptr
|
||||
cgDataProviderCopyData func(uintptr) uintptr
|
||||
cgImageRelease func(uintptr)
|
||||
cfDataGetLength func(uintptr) int64
|
||||
cfDataGetBytePtr func(uintptr) uintptr
|
||||
cfRelease func(uintptr)
|
||||
cgRequestScreenCaptureAccess func() bool
|
||||
cgEventCreate func(uintptr) uintptr
|
||||
cgEventGetLocation func(uintptr) cgPoint
|
||||
darwinCaptureReady bool
|
||||
)
|
||||
|
||||
// cgPoint mirrors CoreGraphics CGPoint: two doubles, 16 bytes, returned
|
||||
// in registers on Darwin amd64/arm64. Used to receive cursor coordinates
|
||||
// from CGEventGetLocation via purego.
|
||||
type cgPoint struct {
|
||||
X, Y float64
|
||||
}
|
||||
|
||||
func initDarwinCapture() {
|
||||
darwinCaptureOnce.Do(func() {
|
||||
cg, err := purego.Dlopen("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
log.Debugf("load CoreGraphics: %v", err)
|
||||
return
|
||||
}
|
||||
cf, err := purego.Dlopen("/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
log.Debugf("load CoreFoundation: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
purego.RegisterLibFunc(&cgMainDisplayID, cg, "CGMainDisplayID")
|
||||
purego.RegisterLibFunc(&cgDisplayPixelsWide, cg, "CGDisplayPixelsWide")
|
||||
purego.RegisterLibFunc(&cgDisplayPixelsHigh, cg, "CGDisplayPixelsHigh")
|
||||
purego.RegisterLibFunc(&cgDisplayCreateImage, cg, "CGDisplayCreateImage")
|
||||
purego.RegisterLibFunc(&cgImageGetWidth, cg, "CGImageGetWidth")
|
||||
purego.RegisterLibFunc(&cgImageGetHeight, cg, "CGImageGetHeight")
|
||||
purego.RegisterLibFunc(&cgImageGetBytesPerRow, cg, "CGImageGetBytesPerRow")
|
||||
purego.RegisterLibFunc(&cgImageGetBitsPerPixel, cg, "CGImageGetBitsPerPixel")
|
||||
purego.RegisterLibFunc(&cgImageGetDataProvider, cg, "CGImageGetDataProvider")
|
||||
purego.RegisterLibFunc(&cgDataProviderCopyData, cg, "CGDataProviderCopyData")
|
||||
purego.RegisterLibFunc(&cgImageRelease, cg, "CGImageRelease")
|
||||
purego.RegisterLibFunc(&cfDataGetLength, cf, "CFDataGetLength")
|
||||
purego.RegisterLibFunc(&cfDataGetBytePtr, cf, "CFDataGetBytePtr")
|
||||
purego.RegisterLibFunc(&cfRelease, cf, "CFRelease")
|
||||
|
||||
// CGRequestScreenCaptureAccess (macOS 11+) prompts on first call and
|
||||
// is a cheap no-op once granted. The Preflight companion is unreliable
|
||||
// on Sequoia (returns false even when access is granted), so we drive
|
||||
// the permission flow from actual capture failures instead.
|
||||
if sym, err := purego.Dlsym(cg, "CGRequestScreenCaptureAccess"); err == nil {
|
||||
purego.RegisterFunc(&cgRequestScreenCaptureAccess, sym)
|
||||
}
|
||||
// CGEventCreate / CGEventGetLocation feed the cursor position used
|
||||
// by remote-cursor compositing. Optional; absence reports as a
|
||||
// position-source error and disables that feature on this host.
|
||||
if sym, err := purego.Dlsym(cg, "CGEventCreate"); err == nil {
|
||||
purego.RegisterFunc(&cgEventCreate, sym)
|
||||
}
|
||||
if sym, err := purego.Dlsym(cg, "CGEventGetLocation"); err == nil {
|
||||
purego.RegisterFunc(&cgEventGetLocation, sym)
|
||||
}
|
||||
|
||||
darwinCaptureReady = true
|
||||
})
|
||||
}
|
||||
|
||||
// CGCapturer captures the macOS main display using Core Graphics.
|
||||
type CGCapturer struct {
|
||||
displayID uint32
|
||||
w, h int
|
||||
// downscale is 1 for pixel-perfect, 2 for Retina 2:1 box-filter downscale.
|
||||
downscale int
|
||||
hashSeed maphash.Seed
|
||||
lastHash uint64
|
||||
hasHash bool
|
||||
// cursor lazily binds the private CGSCreateCurrentCursorImage symbol
|
||||
// so we can emit the Cursor pseudo-encoding without a per-frame cost
|
||||
// on builds that never query it.
|
||||
cursorOnce sync.Once
|
||||
cursor *cgCursor
|
||||
}
|
||||
|
||||
// PrimeScreenCapturePermission triggers the macOS Screen Recording
|
||||
// permission prompt without creating a full capturer. The platform wiring
|
||||
// calls this at VNC-server enable time so the user sees the prompt the
|
||||
// moment they turn the feature on. CGRequestScreenCaptureAccess is a
|
||||
// no-op when the grant already exists, so calling it on every enable is
|
||||
// cheap and safe.
|
||||
func PrimeScreenCapturePermission() {
|
||||
initDarwinCapture()
|
||||
if !darwinCaptureReady {
|
||||
return
|
||||
}
|
||||
if cgRequestScreenCaptureAccess != nil {
|
||||
cgRequestScreenCaptureAccess()
|
||||
}
|
||||
}
|
||||
|
||||
// notifyScreenRecordingMissing nudges the user once per agent process to
|
||||
// approve Screen Recording. The capturer init retries on backoff when the
|
||||
// grant is missing; without the sync.Once we would reopen System Settings
|
||||
// every tick and flood the daemon log with the same warning.
|
||||
var screenRecordingNotifyOnce sync.Once
|
||||
|
||||
func notifyScreenRecordingMissing() {
|
||||
screenRecordingNotifyOnce.Do(func() {
|
||||
if cgRequestScreenCaptureAccess != nil {
|
||||
cgRequestScreenCaptureAccess()
|
||||
}
|
||||
openPrivacyPane("Privacy_ScreenCapture")
|
||||
log.Warn("Screen Recording permission not granted. " +
|
||||
"Opened System Settings > Privacy & Security > Screen Recording; enable netbird and restart.")
|
||||
})
|
||||
}
|
||||
|
||||
// NewCGCapturer creates a screen capturer for the main display.
|
||||
func NewCGCapturer() (*CGCapturer, error) {
|
||||
initDarwinCapture()
|
||||
if !darwinCaptureReady {
|
||||
return nil, fmt.Errorf("CoreGraphics not available")
|
||||
}
|
||||
|
||||
displayID := cgMainDisplayID()
|
||||
c := &CGCapturer{displayID: displayID, downscale: 1, hashSeed: maphash.MakeSeed()}
|
||||
|
||||
img, err := c.Capture()
|
||||
if err != nil {
|
||||
notifyScreenRecordingMissing()
|
||||
return nil, fmt.Errorf("probe capture: %w", err)
|
||||
}
|
||||
nativeW := img.Rect.Dx()
|
||||
nativeH := img.Rect.Dy()
|
||||
c.hasHash = false
|
||||
if nativeW == 0 || nativeH == 0 {
|
||||
return nil, errors.New("display dimensions are zero")
|
||||
}
|
||||
|
||||
logicalW := int(cgDisplayPixelsWide(displayID))
|
||||
logicalH := int(cgDisplayPixelsHigh(displayID))
|
||||
|
||||
// Enable 2:1 downscale on Retina unless explicitly disabled. Cuts pixel
|
||||
// count 4x, shrinking convert, diff, and wire data proportionally.
|
||||
if !retinaDownscaleDisabled() && nativeW >= 2*logicalW && nativeH >= 2*logicalH && nativeW%2 == 0 && nativeH%2 == 0 {
|
||||
c.downscale = 2
|
||||
}
|
||||
c.w = nativeW / c.downscale
|
||||
c.h = nativeH / c.downscale
|
||||
|
||||
log.Infof("macOS capturer ready: %dx%d (native %dx%d, logical %dx%d, downscale=%d, display=%d)",
|
||||
c.w, c.h, nativeW, nativeH, logicalW, logicalH, c.downscale, displayID)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func retinaDownscaleDisabled() bool {
|
||||
v := os.Getenv(EnvVNCDisableDownscale)
|
||||
if v == "" {
|
||||
return false
|
||||
}
|
||||
disabled, err := strconv.ParseBool(v)
|
||||
if err != nil {
|
||||
log.Warnf("parse %s: %v", EnvVNCDisableDownscale, err)
|
||||
return false
|
||||
}
|
||||
return disabled
|
||||
}
|
||||
|
||||
// Width returns the screen width.
|
||||
func (c *CGCapturer) Width() int { return c.w }
|
||||
|
||||
// Height returns the screen height.
|
||||
func (c *CGCapturer) Height() int { return c.h }
|
||||
|
||||
// Capture returns the current screen as an RGBA image.
|
||||
// CaptureInto writes a fresh frame directly into dst, skipping the
|
||||
// per-frame image.RGBA allocation that Capture() does. Returns
|
||||
// errFrameUnchanged when the screen hash matches the prior call.
|
||||
func (c *CGCapturer) CaptureInto(dst *image.RGBA) error {
|
||||
cgImage := cgDisplayCreateImage(c.displayID)
|
||||
if cgImage == 0 {
|
||||
return fmt.Errorf("CGDisplayCreateImage returned nil (screen recording permission?)")
|
||||
}
|
||||
defer cgImageRelease(cgImage)
|
||||
w := int(cgImageGetWidth(cgImage))
|
||||
h := int(cgImageGetHeight(cgImage))
|
||||
bytesPerRow := int(cgImageGetBytesPerRow(cgImage))
|
||||
bpp := int(cgImageGetBitsPerPixel(cgImage))
|
||||
provider := cgImageGetDataProvider(cgImage)
|
||||
if provider == 0 {
|
||||
return fmt.Errorf("CGImageGetDataProvider returned nil")
|
||||
}
|
||||
cfData := cgDataProviderCopyData(provider)
|
||||
if cfData == 0 {
|
||||
return fmt.Errorf("CGDataProviderCopyData returned nil")
|
||||
}
|
||||
defer cfRelease(cfData)
|
||||
dataLen := int(cfDataGetLength(cfData))
|
||||
dataPtr := cfDataGetBytePtr(cfData)
|
||||
if dataPtr == 0 || dataLen == 0 {
|
||||
return fmt.Errorf("empty image data")
|
||||
}
|
||||
src := unsafe.Slice((*byte)(unsafe.Pointer(dataPtr)), dataLen)
|
||||
hash := maphash.Bytes(c.hashSeed, src)
|
||||
if c.hasHash && hash == c.lastHash {
|
||||
return errFrameUnchanged
|
||||
}
|
||||
c.lastHash = hash
|
||||
c.hasHash = true
|
||||
|
||||
ds := c.downscale
|
||||
if ds < 1 {
|
||||
ds = 1
|
||||
}
|
||||
outW := w / ds
|
||||
outH := h / ds
|
||||
if dst.Rect.Dx() != outW || dst.Rect.Dy() != outH {
|
||||
return fmt.Errorf("dst size mismatch: dst=%dx%d capturer=%dx%d",
|
||||
dst.Rect.Dx(), dst.Rect.Dy(), outW, outH)
|
||||
}
|
||||
bytesPerPixel := bpp / 8
|
||||
if bytesPerPixel == 4 && ds == 1 {
|
||||
convertBGRAToRGBA(dst.Pix, dst.Stride, src, bytesPerRow, w, h)
|
||||
return nil
|
||||
}
|
||||
if bytesPerPixel == 4 && ds == 2 {
|
||||
convertBGRAToRGBADownscale2(dst.Pix, dst.Stride, src, bytesPerRow, outW, outH)
|
||||
return nil
|
||||
}
|
||||
for row := 0; row < outH; row++ {
|
||||
srcOff := row * ds * bytesPerRow
|
||||
dstOff := row * dst.Stride
|
||||
for col := 0; col < outW; col++ {
|
||||
si := srcOff + col*ds*bytesPerPixel
|
||||
di := dstOff + col*4
|
||||
dst.Pix[di+0] = src[si+2]
|
||||
dst.Pix[di+1] = src[si+1]
|
||||
dst.Pix[di+2] = src[si+0]
|
||||
dst.Pix[di+3] = 0xff
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CGCapturer) Capture() (*image.RGBA, error) {
|
||||
cgImage := cgDisplayCreateImage(c.displayID)
|
||||
if cgImage == 0 {
|
||||
return nil, fmt.Errorf("CGDisplayCreateImage returned nil (screen recording permission?)")
|
||||
}
|
||||
defer cgImageRelease(cgImage)
|
||||
|
||||
w := int(cgImageGetWidth(cgImage))
|
||||
h := int(cgImageGetHeight(cgImage))
|
||||
bytesPerRow := int(cgImageGetBytesPerRow(cgImage))
|
||||
bpp := int(cgImageGetBitsPerPixel(cgImage))
|
||||
|
||||
provider := cgImageGetDataProvider(cgImage)
|
||||
if provider == 0 {
|
||||
return nil, fmt.Errorf("CGImageGetDataProvider returned nil")
|
||||
}
|
||||
|
||||
cfData := cgDataProviderCopyData(provider)
|
||||
if cfData == 0 {
|
||||
return nil, fmt.Errorf("CGDataProviderCopyData returned nil")
|
||||
}
|
||||
defer cfRelease(cfData)
|
||||
|
||||
dataLen := int(cfDataGetLength(cfData))
|
||||
dataPtr := cfDataGetBytePtr(cfData)
|
||||
if dataPtr == 0 || dataLen == 0 {
|
||||
return nil, fmt.Errorf("empty image data")
|
||||
}
|
||||
|
||||
src := unsafe.Slice((*byte)(unsafe.Pointer(dataPtr)), dataLen)
|
||||
|
||||
hash := maphash.Bytes(c.hashSeed, src)
|
||||
if c.hasHash && hash == c.lastHash {
|
||||
return nil, errFrameUnchanged
|
||||
}
|
||||
c.lastHash = hash
|
||||
c.hasHash = true
|
||||
|
||||
ds := c.downscale
|
||||
if ds < 1 {
|
||||
ds = 1
|
||||
}
|
||||
outW := w / ds
|
||||
outH := h / ds
|
||||
img := image.NewRGBA(image.Rect(0, 0, outW, outH))
|
||||
|
||||
bytesPerPixel := bpp / 8
|
||||
switch {
|
||||
case bytesPerPixel == 4 && ds == 1:
|
||||
convertBGRAToRGBA(img.Pix, img.Stride, src, bytesPerRow, w, h)
|
||||
case bytesPerPixel == 4 && ds == 2:
|
||||
convertBGRAToRGBADownscale2(img.Pix, img.Stride, src, bytesPerRow, outW, outH)
|
||||
default:
|
||||
convertBGRAToRGBAGeneric(img.Pix, img.Stride, src, bytesPerRow, bgraDownscaleParams{outW: outW, outH: outH, bytesPerPixel: bytesPerPixel, ds: ds})
|
||||
}
|
||||
|
||||
return img, nil
|
||||
}
|
||||
|
||||
type bgraDownscaleParams struct {
|
||||
outW, outH, bytesPerPixel, ds int
|
||||
}
|
||||
|
||||
// convertBGRAToRGBAGeneric is the slow per-pixel fallback for non-4-bytes
|
||||
// or non-1/2 downscale formats. Always available regardless of the source
|
||||
// format quirks the fast paths optimize for.
|
||||
func convertBGRAToRGBAGeneric(dst []byte, dstStride int, src []byte, srcStride int, p bgraDownscaleParams) {
|
||||
for row := 0; row < p.outH; row++ {
|
||||
srcOff := row * p.ds * srcStride
|
||||
dstOff := row * dstStride
|
||||
for col := 0; col < p.outW; col++ {
|
||||
si := srcOff + col*p.ds*p.bytesPerPixel
|
||||
di := dstOff + col*4
|
||||
dst[di+0] = src[si+2]
|
||||
dst[di+1] = src[si+1]
|
||||
dst[di+2] = src[si+0]
|
||||
dst[di+3] = 0xff
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// convertBGRAToRGBADownscale2 averages every 2x2 BGRA block into one RGBA
|
||||
// output pixel, parallelised across GOMAXPROCS cores. outW and outH are the
|
||||
// destination dimensions (source is 2*outW by 2*outH).
|
||||
func convertBGRAToRGBADownscale2(dst []byte, dstStride int, src []byte, srcStride, outW, outH int) {
|
||||
workers := runtime.GOMAXPROCS(0)
|
||||
if workers > outH {
|
||||
workers = outH
|
||||
}
|
||||
if workers < 1 || outH < 32 {
|
||||
workers = 1
|
||||
}
|
||||
|
||||
convertRows := func(y0, y1 int) {
|
||||
for row := y0; row < y1; row++ {
|
||||
srcRow0 := 2 * row * srcStride
|
||||
srcRow1 := srcRow0 + srcStride
|
||||
dstOff := row * dstStride
|
||||
for col := 0; col < outW; col++ {
|
||||
s0 := srcRow0 + col*8
|
||||
s1 := srcRow1 + col*8
|
||||
b := (uint32(src[s0]) + uint32(src[s0+4]) + uint32(src[s1]) + uint32(src[s1+4])) >> 2
|
||||
g := (uint32(src[s0+1]) + uint32(src[s0+5]) + uint32(src[s1+1]) + uint32(src[s1+5])) >> 2
|
||||
r := (uint32(src[s0+2]) + uint32(src[s0+6]) + uint32(src[s1+2]) + uint32(src[s1+6])) >> 2
|
||||
di := dstOff + col*4
|
||||
dst[di+0] = byte(r)
|
||||
dst[di+1] = byte(g)
|
||||
dst[di+2] = byte(b)
|
||||
dst[di+3] = 0xff
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if workers == 1 {
|
||||
convertRows(0, outH)
|
||||
return
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
chunk := (outH + workers - 1) / workers
|
||||
for i := 0; i < workers; i++ {
|
||||
y0 := i * chunk
|
||||
y1 := y0 + chunk
|
||||
if y1 > outH {
|
||||
y1 = outH
|
||||
}
|
||||
if y0 >= y1 {
|
||||
break
|
||||
}
|
||||
wg.Add(1)
|
||||
go func(y0, y1 int) {
|
||||
defer wg.Done()
|
||||
convertRows(y0, y1)
|
||||
}(y0, y1)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// convertBGRAToRGBA swaps R/B channels using uint32 word operations, and
|
||||
// parallelises across GOMAXPROCS cores for large images.
|
||||
func convertBGRAToRGBA(dst []byte, dstStride int, src []byte, srcStride, w, h int) {
|
||||
workers := runtime.GOMAXPROCS(0)
|
||||
if workers > h {
|
||||
workers = h
|
||||
}
|
||||
if workers < 1 || h < 64 {
|
||||
workers = 1
|
||||
}
|
||||
|
||||
convertRows := func(y0, y1 int) {
|
||||
rowBytes := w * 4
|
||||
for row := y0; row < y1; row++ {
|
||||
dstRow := dst[row*dstStride : row*dstStride+rowBytes]
|
||||
srcRow := src[row*srcStride : row*srcStride+rowBytes]
|
||||
dstU := unsafe.Slice((*uint32)(unsafe.Pointer(&dstRow[0])), w)
|
||||
srcU := unsafe.Slice((*uint32)(unsafe.Pointer(&srcRow[0])), w)
|
||||
for i, p := range srcU {
|
||||
dstU[i] = (p & 0xff00ff00) | ((p & 0x000000ff) << 16) | ((p & 0x00ff0000) >> 16) | 0xff000000
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if workers == 1 {
|
||||
convertRows(0, h)
|
||||
return
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
chunk := (h + workers - 1) / workers
|
||||
for i := 0; i < workers; i++ {
|
||||
y0 := i * chunk
|
||||
y1 := y0 + chunk
|
||||
if y1 > h {
|
||||
y1 = h
|
||||
}
|
||||
if y0 >= y1 {
|
||||
break
|
||||
}
|
||||
wg.Add(1)
|
||||
go func(y0, y1 int) {
|
||||
defer wg.Done()
|
||||
convertRows(y0, y1)
|
||||
}(y0, y1)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// MacPoller wraps CGCapturer with a staleness-cached on-demand Capture:
|
||||
// sessions drive captures themselves from their encoder goroutine, so we
|
||||
// don't need a background ticker. The last result is cached for a short
|
||||
// window so concurrent sessions coalesce into one capture.
|
||||
//
|
||||
// The capturer is allocated lazily on first use and released when all
|
||||
// clients disconnect. Init is retried with backoff because the user may
|
||||
// grant Screen Recording permission while the server is already running.
|
||||
type MacPoller struct {
|
||||
mu sync.Mutex
|
||||
|
||||
capturer *CGCapturer
|
||||
w, h int
|
||||
|
||||
lastFrame *image.RGBA
|
||||
lastAt time.Time
|
||||
|
||||
clients atomic.Int32
|
||||
initFails int
|
||||
initBackoffUntil time.Time
|
||||
closed bool
|
||||
}
|
||||
|
||||
// macInitRetryBackoffFor returns the delay we wait between init attempts
|
||||
// after consecutive failures. Screen Recording permission is a one-shot
|
||||
// user grant, so after several failures we back off aggressively.
|
||||
func macInitRetryBackoffFor(fails int) time.Duration {
|
||||
switch {
|
||||
case fails > 15:
|
||||
return 30 * time.Second
|
||||
case fails > 5:
|
||||
return 10 * time.Second
|
||||
default:
|
||||
return 2 * time.Second
|
||||
}
|
||||
}
|
||||
|
||||
// NewMacPoller creates a lazy on-demand capturer for the macOS display.
|
||||
func NewMacPoller() *MacPoller {
|
||||
return &MacPoller{}
|
||||
}
|
||||
|
||||
// Wake is a no-op retained for API compatibility. With on-demand capture
|
||||
// there is no background retry loop to kick: init happens on the next
|
||||
// Capture/ClientConnect call.
|
||||
func (p *MacPoller) Wake() {
|
||||
// intentional no-op
|
||||
}
|
||||
|
||||
// ClientConnect increments the active client count and eagerly initialises
|
||||
// the capturer so the first FBUpdateRequest doesn't pay the init cost.
|
||||
func (p *MacPoller) ClientConnect() {
|
||||
if p.clients.Add(1) == 1 {
|
||||
p.mu.Lock()
|
||||
_ = p.ensureCapturerLocked()
|
||||
p.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// ClientDisconnect decrements the active client count. On the last
|
||||
// disconnect the capturer is released.
|
||||
func (p *MacPoller) ClientDisconnect() {
|
||||
if p.clients.Add(-1) == 0 {
|
||||
p.mu.Lock()
|
||||
p.capturer = nil
|
||||
p.lastFrame = nil
|
||||
p.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Close releases all resources.
|
||||
func (p *MacPoller) Close() {
|
||||
p.mu.Lock()
|
||||
p.closed = true
|
||||
p.capturer = nil
|
||||
p.lastFrame = nil
|
||||
p.mu.Unlock()
|
||||
}
|
||||
|
||||
// Width returns the screen width. Triggers lazy init if needed.
|
||||
func (p *MacPoller) Width() int {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
_ = p.ensureCapturerLocked()
|
||||
return p.w
|
||||
}
|
||||
|
||||
// Height returns the screen height. Triggers lazy init if needed.
|
||||
func (p *MacPoller) Height() int {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
_ = p.ensureCapturerLocked()
|
||||
return p.h
|
||||
}
|
||||
|
||||
// CaptureInto fills dst directly via the underlying capturer, bypassing
|
||||
// the freshness cache.
|
||||
func (p *MacPoller) CaptureInto(dst *image.RGBA) error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if err := p.ensureCapturerLocked(); err != nil {
|
||||
return err
|
||||
}
|
||||
err := p.capturer.CaptureInto(dst)
|
||||
if errors.Is(err, errFrameUnchanged) {
|
||||
// Caller (session) treats this as "no change"; the dst buffer
|
||||
// keeps its prior contents from the previous capture cycle so
|
||||
// the diff stays meaningful.
|
||||
return err
|
||||
}
|
||||
if err != nil {
|
||||
p.capturer = nil
|
||||
return fmt.Errorf("macos capture: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Capture returns a fresh frame, serving from the short-lived cache if a
|
||||
// previous caller captured within freshWindow. Handles the
|
||||
// errFrameUnchanged return from CGCapturer by reusing the cached frame.
|
||||
func (p *MacPoller) Capture() (*image.RGBA, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.lastFrame != nil && time.Since(p.lastAt) < freshWindow {
|
||||
return p.lastFrame, nil
|
||||
}
|
||||
if err := p.ensureCapturerLocked(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
img, err := p.capturer.Capture()
|
||||
if errors.Is(err, errFrameUnchanged) {
|
||||
if p.lastFrame != nil {
|
||||
p.lastAt = time.Now()
|
||||
return p.lastFrame, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if err != nil {
|
||||
// Drop the capturer so the next call retries init; the display stream
|
||||
// can die if the session changes or permissions are revoked.
|
||||
p.capturer = nil
|
||||
return nil, fmt.Errorf("macos capture: %w", err)
|
||||
}
|
||||
p.lastFrame = img
|
||||
p.lastAt = time.Now()
|
||||
return img, nil
|
||||
}
|
||||
|
||||
// ensureCapturerLocked initialises the underlying CGCapturer if needed.
|
||||
// Caller must hold p.mu.
|
||||
func (p *MacPoller) ensureCapturerLocked() error {
|
||||
if p.closed {
|
||||
return fmt.Errorf("poller closed")
|
||||
}
|
||||
if p.capturer != nil {
|
||||
return nil
|
||||
}
|
||||
if time.Now().Before(p.initBackoffUntil) {
|
||||
return fmt.Errorf("macOS capturer unavailable (retry scheduled)")
|
||||
}
|
||||
c, err := NewCGCapturer()
|
||||
if err != nil {
|
||||
p.initFails++
|
||||
p.initBackoffUntil = time.Now().Add(macInitRetryBackoffFor(p.initFails))
|
||||
if p.initFails == 1 || p.initFails%10 == 0 {
|
||||
log.Warnf("macOS capturer: %v (attempt %d)", err, p.initFails)
|
||||
} else {
|
||||
log.Debugf("macOS capturer: %v (attempt %d)", err, p.initFails)
|
||||
}
|
||||
return err
|
||||
}
|
||||
p.initFails = 0
|
||||
p.capturer = c
|
||||
p.w, p.h = c.Width(), c.Height()
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ ScreenCapturer = (*MacPoller)(nil)
|
||||
99
client/vnc/server/capture_dxgi_windows.go
Normal file
99
client/vnc/server/capture_dxgi_windows.go
Normal file
@@ -0,0 +1,99 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"image"
|
||||
|
||||
"github.com/kirides/go-d3d/d3d11"
|
||||
"github.com/kirides/go-d3d/outputduplication"
|
||||
)
|
||||
|
||||
// dxgiCapturer captures the desktop using DXGI Desktop Duplication.
|
||||
// Provides GPU-accelerated capture with native dirty rect tracking.
|
||||
// Only works from the interactive user session, not Session 0.
|
||||
//
|
||||
// Uses a double-buffer: DXGI writes into img, then we copy to the current
|
||||
// output buffer and hand it out. Alternating between two output buffers
|
||||
// avoids allocating a new image.RGBA per frame (~8MB at 1080p, 30fps).
|
||||
type dxgiCapturer struct {
|
||||
dup *outputduplication.OutputDuplicator
|
||||
device *d3d11.ID3D11Device
|
||||
ctx *d3d11.ID3D11DeviceContext
|
||||
img *image.RGBA
|
||||
out [2]*image.RGBA
|
||||
outIdx int
|
||||
width int
|
||||
height int
|
||||
}
|
||||
|
||||
func newDXGICapturer() (*dxgiCapturer, error) {
|
||||
device, deviceCtx, err := d3d11.NewD3D11Device()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create D3D11 device: %w", err)
|
||||
}
|
||||
|
||||
dup, err := outputduplication.NewIDXGIOutputDuplication(device, deviceCtx, 0)
|
||||
if err != nil {
|
||||
device.Release()
|
||||
deviceCtx.Release()
|
||||
return nil, fmt.Errorf("create output duplication: %w", err)
|
||||
}
|
||||
|
||||
w, h := screenSize()
|
||||
if w == 0 || h == 0 {
|
||||
dup.Release()
|
||||
device.Release()
|
||||
deviceCtx.Release()
|
||||
return nil, fmt.Errorf("screen dimensions are zero")
|
||||
}
|
||||
|
||||
rect := image.Rect(0, 0, w, h)
|
||||
c := &dxgiCapturer{
|
||||
dup: dup,
|
||||
device: device,
|
||||
ctx: deviceCtx,
|
||||
img: image.NewRGBA(rect),
|
||||
out: [2]*image.RGBA{image.NewRGBA(rect), image.NewRGBA(rect)},
|
||||
width: w,
|
||||
height: h,
|
||||
}
|
||||
|
||||
// Grab the initial frame with a longer timeout to ensure we have
|
||||
// a valid image before returning.
|
||||
_ = dup.GetImage(c.img, 2000)
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *dxgiCapturer) capture() (*image.RGBA, error) {
|
||||
err := c.dup.GetImage(c.img, 100)
|
||||
if err != nil && !errors.Is(err, outputduplication.ErrNoImageYet) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Copy into the next output buffer. The DesktopCapturer hands out the
|
||||
// returned pointer to VNC sessions that read pixels concurrently, so we
|
||||
// alternate between two pre-allocated buffers instead of allocating per frame.
|
||||
out := c.out[c.outIdx]
|
||||
c.outIdx ^= 1
|
||||
copy(out.Pix, c.img.Pix)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *dxgiCapturer) close() {
|
||||
if c.dup != nil {
|
||||
c.dup.Release()
|
||||
c.dup = nil
|
||||
}
|
||||
if c.ctx != nil {
|
||||
c.ctx.Release()
|
||||
c.ctx = nil
|
||||
}
|
||||
if c.device != nil {
|
||||
c.device.Release()
|
||||
c.device = nil
|
||||
}
|
||||
}
|
||||
148
client/vnc/server/capture_fb_freebsd.go
Normal file
148
client/vnc/server/capture_fb_freebsd.go
Normal file
@@ -0,0 +1,148 @@
|
||||
//go:build freebsd
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// FreeBSD vt(4) framebuffer ioctl numbers from sys/fbio.h.
|
||||
//
|
||||
// #define FBIOGTYPE _IOR('F', 0, struct fbtype)
|
||||
//
|
||||
// _IOR(g, n, t) on FreeBSD: dir=2 (read) <<30 | (sizeof(t) & 0x1fff)<<16
|
||||
// | (g<<8) | n. sizeof(struct fbtype)=24 → 0x40184600.
|
||||
const fbioGType = 0x40184600
|
||||
|
||||
func defaultFBPath() string { return "/dev/ttyv0" }
|
||||
|
||||
// fbType mirrors FreeBSD's struct fbtype.
|
||||
type fbType struct {
|
||||
FbType int32
|
||||
FbHeight int32
|
||||
FbWidth int32
|
||||
FbDepth int32
|
||||
FbCMSize int32
|
||||
FbSize int32
|
||||
}
|
||||
|
||||
// FBCapturer reads pixels from FreeBSD's vt(4) framebuffer device. The
|
||||
// vt(4) console exposes the active framebuffer via ttyv0 with FBIOGTYPE
|
||||
// for geometry and mmap for backing memory. Pixel layout is assumed to
|
||||
// be 32bpp BGRA (the common case for KMS-backed vt); fbtype doesn't
|
||||
// expose channel offsets, so we don't try to handle exotic layouts here.
|
||||
type FBCapturer struct {
|
||||
mu sync.Mutex
|
||||
path string
|
||||
fd int
|
||||
mmap []byte
|
||||
w, h int
|
||||
bpp int
|
||||
stride int
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
// NewFBCapturer opens the given vt(4) device and queries its geometry.
|
||||
func NewFBCapturer(path string) (*FBCapturer, error) {
|
||||
if path == "" {
|
||||
path = defaultFBPath()
|
||||
}
|
||||
fd, err := unix.Open(path, unix.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open %s: %w", path, err)
|
||||
}
|
||||
|
||||
var fbt fbType
|
||||
if _, _, e := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), fbioGType, uintptr(unsafe.Pointer(&fbt))); e != 0 {
|
||||
unix.Close(fd)
|
||||
return nil, fmt.Errorf("FBIOGTYPE: %v", e)
|
||||
}
|
||||
if fbt.FbDepth != 16 && fbt.FbDepth != 24 && fbt.FbDepth != 32 {
|
||||
unix.Close(fd)
|
||||
return nil, fmt.Errorf("unsupported framebuffer depth: %d", fbt.FbDepth)
|
||||
}
|
||||
if fbt.FbWidth <= 0 || fbt.FbHeight <= 0 || fbt.FbSize <= 0 {
|
||||
unix.Close(fd)
|
||||
return nil, fmt.Errorf("invalid framebuffer geometry: %dx%d size=%d", fbt.FbWidth, fbt.FbHeight, fbt.FbSize)
|
||||
}
|
||||
|
||||
mm, err := unix.Mmap(fd, 0, int(fbt.FbSize), unix.PROT_READ, unix.MAP_SHARED)
|
||||
if err != nil {
|
||||
unix.Close(fd)
|
||||
return nil, fmt.Errorf("mmap %s: %w (vt may not support mmap on this driver, e.g. virtio_gpu)", path, err)
|
||||
}
|
||||
|
||||
bpp := int(fbt.FbDepth)
|
||||
stride := int(fbt.FbWidth) * (bpp / 8)
|
||||
c := &FBCapturer{
|
||||
path: path,
|
||||
fd: fd, // valid fd >= 0; we use -1 as the closed sentinel
|
||||
mmap: mm,
|
||||
w: int(fbt.FbWidth),
|
||||
h: int(fbt.FbHeight),
|
||||
bpp: bpp,
|
||||
stride: stride,
|
||||
}
|
||||
log.Infof("framebuffer capturer ready: %s %dx%d bpp=%d (freebsd vt)", path, c.w, c.h, c.bpp)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Width returns the framebuffer width.
|
||||
func (c *FBCapturer) Width() int { return c.w }
|
||||
|
||||
// Height returns the framebuffer height.
|
||||
func (c *FBCapturer) Height() int { return c.h }
|
||||
|
||||
// Capture allocates a fresh image and fills it with the current
|
||||
// framebuffer contents.
|
||||
func (c *FBCapturer) Capture() (*image.RGBA, error) {
|
||||
img := image.NewRGBA(image.Rect(0, 0, c.w, c.h))
|
||||
if err := c.CaptureInto(img); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return img, nil
|
||||
}
|
||||
|
||||
// CaptureInto reads the framebuffer directly into dst.Pix. Assumes BGRA
|
||||
// for 32bpp; the FreeBSD fbtype struct doesn't expose channel offsets.
|
||||
func (c *FBCapturer) CaptureInto(dst *image.RGBA) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if dst.Rect.Dx() != c.w || dst.Rect.Dy() != c.h {
|
||||
return fmt.Errorf("dst size mismatch: dst=%dx%d fb=%dx%d",
|
||||
dst.Rect.Dx(), dst.Rect.Dy(), c.w, c.h)
|
||||
}
|
||||
switch c.bpp {
|
||||
case 32:
|
||||
// vt(4) on KMS framebuffers is BGRA: byte 0=B, 1=G, 2=R.
|
||||
swizzleBGRAtoRGBA(dst.Pix, c.mmap[:c.h*c.stride])
|
||||
case 24:
|
||||
swizzleFB24(dst.Pix, dst.Stride, c.mmap, c.stride, c.w, c.h)
|
||||
case 16:
|
||||
swizzleFB16RGB565(dst.Pix, dst.Stride, c.mmap, c.stride, c.w, c.h)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close releases the framebuffer mmap and file descriptor. Serialized with
|
||||
// CaptureInto via c.mu so an in-flight capture can't read freed memory.
|
||||
func (c *FBCapturer) Close() {
|
||||
c.closeOnce.Do(func() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.mmap != nil {
|
||||
_ = unix.Munmap(c.mmap)
|
||||
c.mmap = nil
|
||||
}
|
||||
if c.fd >= 0 {
|
||||
_ = unix.Close(c.fd)
|
||||
c.fd = -1
|
||||
}
|
||||
})
|
||||
}
|
||||
229
client/vnc/server/capture_fb_linux.go
Normal file
229
client/vnc/server/capture_fb_linux.go
Normal file
@@ -0,0 +1,229 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"image"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// Linux framebuffer ioctls (linux/fb.h).
|
||||
const (
|
||||
fbioGetVScreenInfo = 0x4600
|
||||
fbioGetFScreenInfo = 0x4602
|
||||
)
|
||||
|
||||
func defaultFBPath() string { return "/dev/fb0" }
|
||||
|
||||
// fbVarScreenInfo mirrors the kernel's fb_var_screeninfo. Only the
|
||||
// fields we use are mapped; the rest are absorbed into _padN.
|
||||
type fbVarScreenInfo struct {
|
||||
Xres, Yres uint32
|
||||
XresVirtual, YresVirtual uint32
|
||||
XOffset, YOffset uint32
|
||||
BitsPerPixel uint32
|
||||
Grayscale uint32
|
||||
RedOffset, RedLen, RedMSBR uint32
|
||||
GreenOffset, GreenLen, GreenMSBR uint32
|
||||
BlueOffset, BlueLen, BlueMSBR uint32
|
||||
TranspOffset, TranspLen, TranspM uint32
|
||||
NonStd uint32
|
||||
Activate uint32
|
||||
Height, Width uint32
|
||||
AccelFlags uint32
|
||||
PixClock uint32
|
||||
LeftMargin, RightMargin uint32
|
||||
UpperMargin, LowerMargin uint32
|
||||
HsyncLen, VsyncLen uint32
|
||||
Sync uint32
|
||||
Vmode uint32
|
||||
Rotate uint32
|
||||
Colorspace uint32
|
||||
_pad [4]uint32
|
||||
}
|
||||
|
||||
// fbFixScreenInfo mirrors fb_fix_screeninfo. We only need LineLength.
|
||||
type fbFixScreenInfo struct {
|
||||
IDStr [16]byte
|
||||
SmemStart uint64
|
||||
SmemLen uint32
|
||||
Type uint32
|
||||
TypeAux uint32
|
||||
Visual uint32
|
||||
XPanStep uint16
|
||||
YPanStep uint16
|
||||
YWrapStep uint16
|
||||
_pad0 uint16
|
||||
LineLength uint32
|
||||
MmioStart uint64
|
||||
MmioLen uint32
|
||||
Accel uint32
|
||||
Capabilities uint16
|
||||
_reserved [2]uint16
|
||||
}
|
||||
|
||||
// FBCapturer reads pixels straight from the Linux framebuffer device.
|
||||
// Used as a fallback when X11 isn't available, e.g. on a headless box at
|
||||
// the kernel console or the display manager's pre-login screen on machines
|
||||
// without an Xorg server. The framebuffer must be mmap()-able under our
|
||||
// process privileges (typically the netbird service runs as root).
|
||||
type FBCapturer struct {
|
||||
mu sync.Mutex
|
||||
path string
|
||||
fd int
|
||||
mmap []byte
|
||||
w, h int
|
||||
bpp int
|
||||
stride int
|
||||
rOff uint32
|
||||
gOff uint32
|
||||
bOff uint32
|
||||
rLen uint32
|
||||
gLen uint32
|
||||
bLen uint32
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
// NewFBCapturer opens the given framebuffer device (/dev/fbN) and
|
||||
// queries its current geometry + pixel format.
|
||||
func NewFBCapturer(path string) (*FBCapturer, error) {
|
||||
if path == "" {
|
||||
path = "/dev/fb0"
|
||||
}
|
||||
fd, err := unix.Open(path, unix.O_RDONLY, 0)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open %s: %w", path, err)
|
||||
}
|
||||
|
||||
var vinfo fbVarScreenInfo
|
||||
if _, _, e := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), fbioGetVScreenInfo, uintptr(unsafe.Pointer(&vinfo))); e != 0 {
|
||||
unix.Close(fd)
|
||||
return nil, fmt.Errorf("FBIOGET_VSCREENINFO: %v", e)
|
||||
}
|
||||
var finfo fbFixScreenInfo
|
||||
if _, _, e := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), fbioGetFScreenInfo, uintptr(unsafe.Pointer(&finfo))); e != 0 {
|
||||
unix.Close(fd)
|
||||
return nil, fmt.Errorf("FBIOGET_FSCREENINFO: %v", e)
|
||||
}
|
||||
|
||||
bpp := int(vinfo.BitsPerPixel)
|
||||
if bpp != 16 && bpp != 24 && bpp != 32 {
|
||||
unix.Close(fd)
|
||||
return nil, fmt.Errorf("unsupported framebuffer bpp: %d", bpp)
|
||||
}
|
||||
|
||||
size := int(finfo.LineLength) * int(vinfo.Yres)
|
||||
if size <= 0 {
|
||||
unix.Close(fd)
|
||||
return nil, fmt.Errorf("invalid framebuffer dimensions: stride=%d h=%d", finfo.LineLength, vinfo.Yres)
|
||||
}
|
||||
|
||||
mm, err := unix.Mmap(fd, 0, size, unix.PROT_READ, unix.MAP_SHARED)
|
||||
if err != nil {
|
||||
unix.Close(fd)
|
||||
return nil, fmt.Errorf("mmap %s: %w", path, err)
|
||||
}
|
||||
|
||||
c := &FBCapturer{
|
||||
path: path,
|
||||
fd: fd,
|
||||
mmap: mm,
|
||||
w: int(vinfo.Xres),
|
||||
h: int(vinfo.Yres),
|
||||
bpp: bpp,
|
||||
stride: int(finfo.LineLength),
|
||||
rOff: vinfo.RedOffset,
|
||||
gOff: vinfo.GreenOffset,
|
||||
bOff: vinfo.BlueOffset,
|
||||
rLen: vinfo.RedLen,
|
||||
gLen: vinfo.GreenLen,
|
||||
bLen: vinfo.BlueLen,
|
||||
}
|
||||
log.Infof("framebuffer capturer ready: %s %dx%d bpp=%d r=%d/%d g=%d/%d b=%d/%d",
|
||||
path, c.w, c.h, c.bpp, c.rOff, c.rLen, c.gOff, c.gLen, c.bOff, c.bLen)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Width returns the framebuffer width in pixels.
|
||||
func (c *FBCapturer) Width() int { return c.w }
|
||||
|
||||
// Height returns the framebuffer height in pixels.
|
||||
func (c *FBCapturer) Height() int { return c.h }
|
||||
|
||||
// Capture allocates a fresh image and fills it with the current
|
||||
// framebuffer contents.
|
||||
func (c *FBCapturer) Capture() (*image.RGBA, error) {
|
||||
img := image.NewRGBA(image.Rect(0, 0, c.w, c.h))
|
||||
if err := c.CaptureInto(img); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return img, nil
|
||||
}
|
||||
|
||||
// CaptureInto reads the framebuffer directly into dst.Pix.
|
||||
func (c *FBCapturer) CaptureInto(dst *image.RGBA) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if dst.Rect.Dx() != c.w || dst.Rect.Dy() != c.h {
|
||||
return fmt.Errorf("dst size mismatch: dst=%dx%d fb=%dx%d",
|
||||
dst.Rect.Dx(), dst.Rect.Dy(), c.w, c.h)
|
||||
}
|
||||
|
||||
switch c.bpp {
|
||||
case 32:
|
||||
swizzleFB32(dst.Pix, dst.Stride, c.mmap, c.stride, c.w, c.h, channelShifts{R: c.rOff, G: c.gOff, B: c.bOff})
|
||||
case 24:
|
||||
swizzleFB24(dst.Pix, dst.Stride, c.mmap, c.stride, c.w, c.h)
|
||||
case 16:
|
||||
swizzleFB16RGB565(dst.Pix, dst.Stride, c.mmap, c.stride, c.w, c.h)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close releases the framebuffer mmap and file descriptor. Serialized with
|
||||
// CaptureInto via c.mu so an in-flight capture can't read freed memory.
|
||||
func (c *FBCapturer) Close() {
|
||||
c.closeOnce.Do(func() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.mmap != nil {
|
||||
_ = unix.Munmap(c.mmap)
|
||||
c.mmap = nil
|
||||
}
|
||||
if c.fd >= 0 {
|
||||
_ = unix.Close(c.fd)
|
||||
c.fd = -1
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// channelShifts groups the bit offsets for the R/G/B channels in a packed
|
||||
// uint32 framebuffer pixel. Bundling avoids drowning per-row callers in a
|
||||
// 9-parameter signature.
|
||||
type channelShifts struct {
|
||||
R, G, B uint32
|
||||
}
|
||||
|
||||
// swizzleFB32 handles 32-bit framebuffers with arbitrary R/G/B channel
|
||||
// offsets. Pulls one pixel per uint32, then masks each channel into the
|
||||
// destination RGBA byte order.
|
||||
func swizzleFB32(dst []byte, dstStride int, src []byte, srcStride, w, h int, shifts channelShifts) {
|
||||
for y := 0; y < h; y++ {
|
||||
srcRow := src[y*srcStride : y*srcStride+w*4]
|
||||
dstRow := dst[y*dstStride:]
|
||||
for x := 0; x < w; x++ {
|
||||
pix := binary.LittleEndian.Uint32(srcRow[x*4 : x*4+4])
|
||||
dstRow[x*4+0] = byte(pix >> shifts.R)
|
||||
dstRow[x*4+1] = byte(pix >> shifts.G)
|
||||
dstRow[x*4+2] = byte(pix >> shifts.B)
|
||||
dstRow[x*4+3] = 0xff
|
||||
}
|
||||
}
|
||||
}
|
||||
149
client/vnc/server/capture_fb_unix.go
Normal file
149
client/vnc/server/capture_fb_unix.go
Normal file
@@ -0,0 +1,149 @@
|
||||
//go:build unix && !darwin && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"image"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// FBPoller wraps FBCapturer with the same lifecycle (ClientConnect /
|
||||
// ClientDisconnect, lazy init) as X11Poller, so it slots into the same
|
||||
// session plumbing without code changes upstream. The concrete
|
||||
// FBCapturer is platform-specific (capture_fb_linux.go / _freebsd.go);
|
||||
// this file owns the cross-platform glue.
|
||||
type FBPoller struct {
|
||||
mu sync.Mutex
|
||||
path string
|
||||
capturer *FBCapturer
|
||||
w, h int
|
||||
clients int32
|
||||
}
|
||||
|
||||
// NewFBPoller returns a poller that opens path on first use. Empty path
|
||||
// defaults to /dev/fb0 on Linux and /dev/ttyv0 on FreeBSD.
|
||||
func NewFBPoller(path string) *FBPoller {
|
||||
if path == "" {
|
||||
path = defaultFBPath()
|
||||
}
|
||||
return &FBPoller{path: path}
|
||||
}
|
||||
|
||||
// ClientConnect eagerly initialises the capturer on first connect.
|
||||
func (p *FBPoller) ClientConnect() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.clients++
|
||||
if p.clients == 1 {
|
||||
_ = p.ensureCapturerLocked()
|
||||
}
|
||||
}
|
||||
|
||||
// ClientDisconnect closes the capturer when the last client leaves.
|
||||
func (p *FBPoller) ClientDisconnect() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.clients--
|
||||
if p.clients <= 0 && p.capturer != nil {
|
||||
p.capturer.Close()
|
||||
p.capturer = nil
|
||||
}
|
||||
}
|
||||
|
||||
// Width returns the framebuffer width, doing lazy init if needed.
|
||||
func (p *FBPoller) Width() int {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
_ = p.ensureCapturerLocked()
|
||||
return p.w
|
||||
}
|
||||
|
||||
// Height returns the framebuffer height, doing lazy init if needed.
|
||||
func (p *FBPoller) Height() int {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
_ = p.ensureCapturerLocked()
|
||||
return p.h
|
||||
}
|
||||
|
||||
// Capture takes a fresh frame.
|
||||
func (p *FBPoller) Capture() (*image.RGBA, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if err := p.ensureCapturerLocked(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return p.capturer.Capture()
|
||||
}
|
||||
|
||||
// CaptureInto fills dst directly.
|
||||
func (p *FBPoller) CaptureInto(dst *image.RGBA) error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if err := p.ensureCapturerLocked(); err != nil {
|
||||
return err
|
||||
}
|
||||
return p.capturer.CaptureInto(dst)
|
||||
}
|
||||
|
||||
// Close releases all framebuffer resources.
|
||||
func (p *FBPoller) Close() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if p.capturer != nil {
|
||||
p.capturer.Close()
|
||||
p.capturer = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (p *FBPoller) ensureCapturerLocked() error {
|
||||
if p.capturer != nil {
|
||||
return nil
|
||||
}
|
||||
c, err := NewFBCapturer(p.path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.capturer = c
|
||||
p.w, p.h = c.Width(), c.Height()
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ ScreenCapturer = (*FBPoller)(nil)
|
||||
var _ captureIntoer = (*FBPoller)(nil)
|
||||
|
||||
// swizzleFB24 handles 24-bit packed framebuffers (B,G,R triplets).
|
||||
// Shared between Linux and FreeBSD framebuffer paths.
|
||||
func swizzleFB24(dst []byte, dstStride int, src []byte, srcStride, w, h int) {
|
||||
for y := 0; y < h; y++ {
|
||||
srcRow := src[y*srcStride : y*srcStride+w*3]
|
||||
dstRow := dst[y*dstStride:]
|
||||
for x := 0; x < w; x++ {
|
||||
b := srcRow[x*3+0]
|
||||
g := srcRow[x*3+1]
|
||||
r := srcRow[x*3+2]
|
||||
dstRow[x*4+0] = r
|
||||
dstRow[x*4+1] = g
|
||||
dstRow[x*4+2] = b
|
||||
dstRow[x*4+3] = 0xff
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// swizzleFB16RGB565 handles 16bpp RGB 565 framebuffers.
|
||||
func swizzleFB16RGB565(dst []byte, dstStride int, src []byte, srcStride, w, h int) {
|
||||
for y := 0; y < h; y++ {
|
||||
srcRow := src[y*srcStride : y*srcStride+w*2]
|
||||
dstRow := dst[y*dstStride:]
|
||||
for x := 0; x < w; x++ {
|
||||
pix := uint16(srcRow[x*2]) | uint16(srcRow[x*2+1])<<8
|
||||
r := byte((pix >> 11) & 0x1f)
|
||||
g := byte((pix >> 5) & 0x3f)
|
||||
b := byte(pix & 0x1f)
|
||||
dstRow[x*4+0] = (r << 3) | (r >> 2)
|
||||
dstRow[x*4+1] = (g << 2) | (g >> 4)
|
||||
dstRow[x*4+2] = (b << 3) | (b >> 2)
|
||||
dstRow[x*4+3] = 0xff
|
||||
}
|
||||
}
|
||||
}
|
||||
556
client/vnc/server/capture_windows.go
Normal file
556
client/vnc/server/capture_windows.go
Normal file
@@ -0,0 +1,556 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var (
|
||||
gdi32 = windows.NewLazySystemDLL("gdi32.dll")
|
||||
user32 = windows.NewLazySystemDLL("user32.dll")
|
||||
|
||||
procGetDC = user32.NewProc("GetDC")
|
||||
procReleaseDC = user32.NewProc("ReleaseDC")
|
||||
procCreateCompatDC = gdi32.NewProc("CreateCompatibleDC")
|
||||
procCreateDIBSection = gdi32.NewProc("CreateDIBSection")
|
||||
procSelectObject = gdi32.NewProc("SelectObject")
|
||||
procDeleteObject = gdi32.NewProc("DeleteObject")
|
||||
procDeleteDC = gdi32.NewProc("DeleteDC")
|
||||
procBitBlt = gdi32.NewProc("BitBlt")
|
||||
procGetSystemMetrics = user32.NewProc("GetSystemMetrics")
|
||||
|
||||
// Desktop switching for service/Session 0 capture.
|
||||
procOpenInputDesktop = user32.NewProc("OpenInputDesktop")
|
||||
procSetThreadDesktop = user32.NewProc("SetThreadDesktop")
|
||||
procCloseDesktop = user32.NewProc("CloseDesktop")
|
||||
procOpenWindowStation = user32.NewProc("OpenWindowStationW")
|
||||
procSetProcessWindowStation = user32.NewProc("SetProcessWindowStation")
|
||||
procCloseWindowStation = user32.NewProc("CloseWindowStation")
|
||||
procGetUserObjectInformationW = user32.NewProc("GetUserObjectInformationW")
|
||||
)
|
||||
|
||||
const uoiName = 2
|
||||
|
||||
const (
|
||||
smCxScreen = 0
|
||||
smCyScreen = 1
|
||||
srccopy = 0x00CC0020
|
||||
captureBlt = 0x40000000
|
||||
dibRgbColors = 0
|
||||
)
|
||||
|
||||
type bitmapInfoHeader struct {
|
||||
Size uint32
|
||||
Width int32
|
||||
Height int32
|
||||
Planes uint16
|
||||
BitCount uint16
|
||||
Compression uint32
|
||||
SizeImage uint32
|
||||
XPelsPerMeter int32
|
||||
YPelsPerMeter int32
|
||||
ClrUsed uint32
|
||||
ClrImportant uint32
|
||||
}
|
||||
|
||||
type bitmapInfo struct {
|
||||
Header bitmapInfoHeader
|
||||
}
|
||||
|
||||
// setupInteractiveWindowStation associates the current process with WinSta0,
|
||||
// the interactive window station. This is required for a SYSTEM service in
|
||||
// Session 0 to call OpenInputDesktop for screen capture and input injection.
|
||||
func setupInteractiveWindowStation() error {
|
||||
name, err := windows.UTF16PtrFromString("WinSta0")
|
||||
if err != nil {
|
||||
return fmt.Errorf("UTF16 WinSta0: %w", err)
|
||||
}
|
||||
hWinSta, _, err := procOpenWindowStation.Call(
|
||||
uintptr(unsafe.Pointer(name)),
|
||||
0,
|
||||
uintptr(windows.MAXIMUM_ALLOWED),
|
||||
)
|
||||
if hWinSta == 0 {
|
||||
return fmt.Errorf("OpenWindowStation(WinSta0): %w", err)
|
||||
}
|
||||
r, _, err := procSetProcessWindowStation.Call(hWinSta)
|
||||
if r == 0 {
|
||||
_, _, _ = procCloseWindowStation.Call(hWinSta)
|
||||
return fmt.Errorf("SetProcessWindowStation: %w", err)
|
||||
}
|
||||
log.Info("process window station set to WinSta0 (interactive)")
|
||||
return nil
|
||||
}
|
||||
|
||||
func screenSize() (int, int) {
|
||||
w, _, _ := procGetSystemMetrics.Call(uintptr(smCxScreen))
|
||||
h, _, _ := procGetSystemMetrics.Call(uintptr(smCyScreen))
|
||||
return int(w), int(h)
|
||||
}
|
||||
|
||||
func getDesktopName(hDesk uintptr) string {
|
||||
var buf [256]uint16
|
||||
var needed uint32
|
||||
_, _, _ = procGetUserObjectInformationW.Call(hDesk, uoiName,
|
||||
uintptr(unsafe.Pointer(&buf[0])), 512,
|
||||
uintptr(unsafe.Pointer(&needed)))
|
||||
return windows.UTF16ToString(buf[:])
|
||||
}
|
||||
|
||||
// switchToInputDesktop opens the desktop currently receiving user input
|
||||
// and sets it as the calling OS thread's desktop. Must be called from a
|
||||
// goroutine locked to its OS thread via runtime.LockOSThread().
|
||||
func switchToInputDesktop() (bool, string) {
|
||||
hDesk, _, _ := procOpenInputDesktop.Call(0, 0, uintptr(windows.MAXIMUM_ALLOWED))
|
||||
if hDesk == 0 {
|
||||
return false, ""
|
||||
}
|
||||
name := getDesktopName(hDesk)
|
||||
ret, _, _ := procSetThreadDesktop.Call(hDesk)
|
||||
_, _, _ = procCloseDesktop.Call(hDesk)
|
||||
return ret != 0, name
|
||||
}
|
||||
|
||||
// gdiCapturer captures the desktop screen using GDI BitBlt.
|
||||
// GDI objects (DC, DIBSection) are allocated once and reused across frames.
|
||||
type gdiCapturer struct {
|
||||
mu sync.Mutex
|
||||
width int
|
||||
height int
|
||||
|
||||
// Pre-allocated GDI resources, reused across captures.
|
||||
memDC uintptr
|
||||
bmp uintptr
|
||||
bits uintptr
|
||||
}
|
||||
|
||||
func newGDICapturer() (*gdiCapturer, error) {
|
||||
w, h := screenSize()
|
||||
if w == 0 || h == 0 {
|
||||
return nil, fmt.Errorf("screen dimensions are zero")
|
||||
}
|
||||
c := &gdiCapturer{width: w, height: h}
|
||||
if err := c.allocGDI(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// allocGDI pre-allocates the compatible DC and DIB section for reuse.
|
||||
func (c *gdiCapturer) allocGDI() error {
|
||||
screenDC, _, _ := procGetDC.Call(0)
|
||||
if screenDC == 0 {
|
||||
return fmt.Errorf("GetDC returned 0")
|
||||
}
|
||||
defer func() { _, _, _ = procReleaseDC.Call(0, screenDC) }()
|
||||
|
||||
memDC, _, _ := procCreateCompatDC.Call(screenDC)
|
||||
if memDC == 0 {
|
||||
return fmt.Errorf("CreateCompatibleDC returned 0")
|
||||
}
|
||||
|
||||
bi := bitmapInfo{
|
||||
Header: bitmapInfoHeader{
|
||||
Size: uint32(unsafe.Sizeof(bitmapInfoHeader{})),
|
||||
Width: int32(c.width),
|
||||
Height: -int32(c.height), // negative = top-down DIB
|
||||
Planes: 1,
|
||||
BitCount: 32,
|
||||
},
|
||||
}
|
||||
|
||||
var bits uintptr
|
||||
bmp, _, _ := procCreateDIBSection.Call(
|
||||
screenDC,
|
||||
uintptr(unsafe.Pointer(&bi)),
|
||||
dibRgbColors,
|
||||
uintptr(unsafe.Pointer(&bits)),
|
||||
0, 0,
|
||||
)
|
||||
if bmp == 0 || bits == 0 {
|
||||
_, _, _ = procDeleteDC.Call(memDC)
|
||||
return fmt.Errorf("CreateDIBSection returned 0")
|
||||
}
|
||||
|
||||
_, _, _ = procSelectObject.Call(memDC, bmp)
|
||||
|
||||
c.memDC = memDC
|
||||
c.bmp = bmp
|
||||
c.bits = bits
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *gdiCapturer) close() { c.freeGDI() }
|
||||
|
||||
// freeGDI releases pre-allocated GDI resources.
|
||||
func (c *gdiCapturer) freeGDI() {
|
||||
if c.bmp != 0 {
|
||||
_, _, _ = procDeleteObject.Call(c.bmp)
|
||||
c.bmp = 0
|
||||
}
|
||||
if c.memDC != 0 {
|
||||
_, _, _ = procDeleteDC.Call(c.memDC)
|
||||
c.memDC = 0
|
||||
}
|
||||
c.bits = 0
|
||||
}
|
||||
|
||||
func (c *gdiCapturer) capture() (*image.RGBA, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.memDC == 0 {
|
||||
return nil, fmt.Errorf("GDI resources not allocated")
|
||||
}
|
||||
|
||||
screenDC, _, _ := procGetDC.Call(0)
|
||||
if screenDC == 0 {
|
||||
return nil, fmt.Errorf("GetDC returned 0")
|
||||
}
|
||||
defer func() { _, _, _ = procReleaseDC.Call(0, screenDC) }()
|
||||
|
||||
// SRCCOPY|CAPTUREBLT: CAPTUREBLT forces inclusion of layered/topmost
|
||||
// windows in the capture and is required for GDI BitBlt to return live
|
||||
// pixels when the session is rendered through RDP / DWM-composited
|
||||
// surfaces. Without it BitBlt reads the backing-store DIB which is
|
||||
// often empty (all-black) on RDP and headless sessions.
|
||||
ret, _, _ := procBitBlt.Call(c.memDC, 0, 0, uintptr(c.width), uintptr(c.height),
|
||||
screenDC, 0, 0, srccopy|captureBlt)
|
||||
if ret == 0 {
|
||||
return nil, fmt.Errorf("BitBlt returned 0")
|
||||
}
|
||||
|
||||
n := c.width * c.height * 4
|
||||
raw := unsafe.Slice((*byte)(unsafe.Pointer(c.bits)), n)
|
||||
|
||||
// GDI gives BGRA, the RFB encoder expects RGBA (img.Pix layout).
|
||||
// Swap R and B in bulk using uint32 operations (one load + mask + shift
|
||||
// per pixel instead of three separate byte assignments).
|
||||
img := image.NewRGBA(image.Rect(0, 0, c.width, c.height))
|
||||
swizzleBGRAtoRGBA(img.Pix, raw)
|
||||
return img, nil
|
||||
}
|
||||
|
||||
// DesktopCapturer captures the interactive desktop, handling desktop transitions
|
||||
// (login screen, UAC prompts). A dedicated OS-locked goroutine continuously
|
||||
// captures frames on demand via a dedicated OS-locked goroutine (required
|
||||
// because DXGI's D3D11 device context is not thread-safe). Sessions drive
|
||||
// timing by calling Capture(); a short staleness cache coalesces concurrent
|
||||
// requests. Capture pauses automatically when no clients are connected.
|
||||
type DesktopCapturer struct {
|
||||
mu sync.Mutex
|
||||
w, h int
|
||||
|
||||
// lastFrame/lastAt implement a small staleness cache so multiple
|
||||
// near-simultaneous Capture calls share one DXGI round-trip.
|
||||
lastFrame *image.RGBA
|
||||
lastAt time.Time
|
||||
|
||||
// clients tracks the number of active VNC sessions. When zero, the
|
||||
// worker goroutine releases the underlying capturer.
|
||||
clients atomic.Int32
|
||||
|
||||
// reqCh carries capture requests from sessions to the OS-locked worker.
|
||||
reqCh chan captureReq
|
||||
// wake is signaled when a client connects and the worker should resume.
|
||||
wake chan struct{}
|
||||
// done is closed when Close is called, terminating the worker.
|
||||
done chan struct{}
|
||||
|
||||
// cursorState holds the latest cursor sprite sampled by the worker.
|
||||
// The worker calls GetCursorInfo every capture and decodes a new
|
||||
// sprite only when the HCURSOR changes.
|
||||
cursorState cursorState
|
||||
}
|
||||
|
||||
// captureReq is a single capture request awaiting a reply. Reply channel is
|
||||
// buffered to size 1 so the worker never blocks on a sender that's gone.
|
||||
type captureReq struct {
|
||||
reply chan captureReply
|
||||
}
|
||||
|
||||
type captureReply struct {
|
||||
img *image.RGBA
|
||||
err error
|
||||
}
|
||||
|
||||
// NewDesktopCapturer creates an on-demand capturer for the active desktop.
|
||||
func NewDesktopCapturer() *DesktopCapturer {
|
||||
c := &DesktopCapturer{
|
||||
wake: make(chan struct{}, 1),
|
||||
done: make(chan struct{}),
|
||||
reqCh: make(chan captureReq),
|
||||
}
|
||||
go c.worker()
|
||||
return c
|
||||
}
|
||||
|
||||
// ClientConnect increments the active client count, resuming capture if needed.
|
||||
func (c *DesktopCapturer) ClientConnect() {
|
||||
c.clients.Add(1)
|
||||
select {
|
||||
case c.wake <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// ClientDisconnect decrements the active client count.
|
||||
func (c *DesktopCapturer) ClientDisconnect() {
|
||||
c.clients.Add(-1)
|
||||
}
|
||||
|
||||
// Close stops the capture loop and releases resources.
|
||||
func (c *DesktopCapturer) Close() {
|
||||
select {
|
||||
case <-c.done:
|
||||
default:
|
||||
close(c.done)
|
||||
}
|
||||
}
|
||||
|
||||
// Width returns the current screen width, triggering a capture if the
|
||||
// worker hasn't initialised yet. validateCapturer depends on Width/Height
|
||||
// becoming non-zero promptly after ClientConnect so it doesn't reject
|
||||
// brand-new sessions.
|
||||
func (c *DesktopCapturer) Width() int {
|
||||
c.mu.Lock()
|
||||
w := c.w
|
||||
c.mu.Unlock()
|
||||
if w == 0 && c.clients.Load() > 0 {
|
||||
_, _ = c.Capture()
|
||||
c.mu.Lock()
|
||||
w = c.w
|
||||
c.mu.Unlock()
|
||||
}
|
||||
return w
|
||||
}
|
||||
|
||||
// Height returns the current screen height, triggering a capture if the
|
||||
// worker hasn't initialised yet (see Width). Returns 0 while no client is
|
||||
// connected so callers don't deadlock against a parked worker.
|
||||
func (c *DesktopCapturer) Height() int {
|
||||
c.mu.Lock()
|
||||
h := c.h
|
||||
c.mu.Unlock()
|
||||
if h == 0 && c.clients.Load() > 0 {
|
||||
_, _ = c.Capture()
|
||||
c.mu.Lock()
|
||||
h = c.h
|
||||
c.mu.Unlock()
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
// Capture returns a freshly captured frame, serving from a short staleness
|
||||
// cache when multiple sessions ask within freshWindow of each other. All
|
||||
// real DXGI/GDI work happens on the OS-locked worker goroutine.
|
||||
func (c *DesktopCapturer) Capture() (*image.RGBA, error) {
|
||||
c.mu.Lock()
|
||||
if c.lastFrame != nil && time.Since(c.lastAt) < freshWindow {
|
||||
img := c.lastFrame
|
||||
c.mu.Unlock()
|
||||
return img, nil
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
reply := make(chan captureReply, 1)
|
||||
select {
|
||||
case c.reqCh <- captureReq{reply: reply}:
|
||||
case <-c.done:
|
||||
return nil, fmt.Errorf("capturer closed")
|
||||
}
|
||||
select {
|
||||
case r := <-reply:
|
||||
if r.err != nil {
|
||||
return nil, r.err
|
||||
}
|
||||
c.mu.Lock()
|
||||
c.lastFrame = r.img
|
||||
c.lastAt = time.Now()
|
||||
c.mu.Unlock()
|
||||
return r.img, nil
|
||||
case <-c.done:
|
||||
return nil, fmt.Errorf("capturer closed")
|
||||
}
|
||||
}
|
||||
|
||||
// waitForClient blocks until a client connects or the capturer is closed.
|
||||
func (c *DesktopCapturer) waitForClient() bool {
|
||||
if c.clients.Load() > 0 {
|
||||
return true
|
||||
}
|
||||
select {
|
||||
case <-c.wake:
|
||||
return true
|
||||
case <-c.done:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// worker owns DXGI/GDI state on its OS-locked thread and services capture
|
||||
// requests from sessions. No background ticker: a capture happens only when
|
||||
// a session asks for one (throttled by Capture()'s staleness cache).
|
||||
func (c *DesktopCapturer) worker() {
|
||||
runtime.LockOSThread()
|
||||
|
||||
// When running as a Windows service (Session 0), we need to attach to the
|
||||
// interactive window station before OpenInputDesktop will succeed.
|
||||
if err := setupInteractiveWindowStation(); err != nil {
|
||||
log.Warnf("attach to interactive window station: %v", err)
|
||||
}
|
||||
|
||||
w := &captureWorker{c: c}
|
||||
defer w.closeCapturer()
|
||||
|
||||
for {
|
||||
if !c.waitForClient() {
|
||||
return
|
||||
}
|
||||
// Drop the capturer when all clients have disconnected so we don't
|
||||
// hold the DXGI duplication or GDI DC on an idle peer.
|
||||
if c.clients.Load() <= 0 {
|
||||
w.closeCapturer()
|
||||
continue
|
||||
}
|
||||
if !w.handleNextRequest() {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// frameCapturer is the per-backend interface used by the worker. DXGI and
|
||||
// GDI implementations both satisfy it.
|
||||
type frameCapturer interface {
|
||||
capture() (*image.RGBA, error)
|
||||
close()
|
||||
}
|
||||
|
||||
// captureWorker owns the worker goroutine's mutable state. Extracted into a
|
||||
// struct so the request/desktop/init logic can live on small methods and the
|
||||
// outer worker() stays a thin loop.
|
||||
type captureWorker struct {
|
||||
c *DesktopCapturer
|
||||
cap frameCapturer
|
||||
desktopFails int
|
||||
lastDesktop string
|
||||
nextInitRetry time.Time
|
||||
cursor cursorSampler
|
||||
}
|
||||
|
||||
// handleNextRequest waits for either shutdown or a capture request and runs
|
||||
// the request through prepCapturer/capture. Returns false when the worker
|
||||
// should exit.
|
||||
func (w *captureWorker) handleNextRequest() bool {
|
||||
select {
|
||||
case <-w.c.done:
|
||||
return false
|
||||
case req := <-w.c.reqCh:
|
||||
w.serveRequest(req)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (w *captureWorker) serveRequest(req captureReq) {
|
||||
fc, err := w.prepCapturer()
|
||||
if err != nil {
|
||||
req.reply <- captureReply{err: err}
|
||||
return
|
||||
}
|
||||
img, err := fc.capture()
|
||||
if err != nil {
|
||||
log.Debugf("capture: %v", err)
|
||||
w.closeCapturer()
|
||||
w.nextInitRetry = time.Now().Add(100 * time.Millisecond)
|
||||
req.reply <- captureReply{err: err}
|
||||
return
|
||||
}
|
||||
if snap, err := w.cursor.sample(); err != nil {
|
||||
w.c.cursorState.store(&cursorSnapshot{err: err})
|
||||
} else {
|
||||
w.c.cursorState.store(snap)
|
||||
}
|
||||
req.reply <- captureReply{img: img}
|
||||
}
|
||||
|
||||
// prepCapturer switches to the input desktop, handles desktop-change
|
||||
// teardown, and creates the underlying capturer on demand. Backoff state is
|
||||
// tracked across calls via w.nextInitRetry.
|
||||
func (w *captureWorker) prepCapturer() (frameCapturer, error) {
|
||||
if err := w.refreshDesktop(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if w.cap != nil {
|
||||
return w.cap, nil
|
||||
}
|
||||
if time.Now().Before(w.nextInitRetry) {
|
||||
return nil, fmt.Errorf("capturer init backing off")
|
||||
}
|
||||
fc, err := w.createCapturer()
|
||||
if err != nil {
|
||||
w.nextInitRetry = time.Now().Add(500 * time.Millisecond)
|
||||
return nil, err
|
||||
}
|
||||
w.cap = fc
|
||||
sw, sh := screenSize()
|
||||
w.c.mu.Lock()
|
||||
w.c.w, w.c.h = sw, sh
|
||||
w.c.mu.Unlock()
|
||||
log.Infof("screen capturer ready: %dx%d", sw, sh)
|
||||
return w.cap, nil
|
||||
}
|
||||
|
||||
// refreshDesktop tracks the active input desktop. When it changes (lock
|
||||
// screen, fast-user-switch) the existing capturer is dropped so the next
|
||||
// call rebuilds one against the new desktop.
|
||||
func (w *captureWorker) refreshDesktop() error {
|
||||
ok, desk := switchToInputDesktop()
|
||||
if !ok {
|
||||
w.desktopFails++
|
||||
if w.desktopFails == 1 || w.desktopFails%100 == 0 {
|
||||
log.Warnf("switchToInputDesktop failed (count=%d), no interactive desktop session?", w.desktopFails)
|
||||
}
|
||||
return fmt.Errorf("no interactive desktop")
|
||||
}
|
||||
if w.desktopFails > 0 {
|
||||
log.Infof("switchToInputDesktop recovered after %d failures, desktop=%q", w.desktopFails, desk)
|
||||
w.desktopFails = 0
|
||||
}
|
||||
if desk != w.lastDesktop {
|
||||
log.Infof("desktop changed: %q -> %q", w.lastDesktop, desk)
|
||||
w.lastDesktop = desk
|
||||
w.closeCapturer()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *captureWorker) createCapturer() (frameCapturer, error) {
|
||||
dc, err := newDXGICapturer()
|
||||
if err == nil {
|
||||
log.Info("using DXGI Desktop Duplication for capture")
|
||||
return dc, nil
|
||||
}
|
||||
log.Warnf("DXGI Desktop Duplication unavailable, falling back to slower GDI BitBlt: %v", err)
|
||||
gc, err := newGDICapturer()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.Info("using GDI BitBlt for capture")
|
||||
return gc, nil
|
||||
}
|
||||
|
||||
func (w *captureWorker) closeCapturer() {
|
||||
if w.cap != nil {
|
||||
w.cap.close()
|
||||
w.cap = nil
|
||||
}
|
||||
}
|
||||
533
client/vnc/server/capture_x11.go
Normal file
533
client/vnc/server/capture_x11.go
Normal file
@@ -0,0 +1,533 @@
|
||||
//go:build unix && !darwin && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/jezek/xgb"
|
||||
"github.com/jezek/xgb/xproto"
|
||||
)
|
||||
|
||||
const (
|
||||
// x11SocketDir is the well-known directory where X servers create
|
||||
// their abstract UNIX-domain sockets, named "X<display>". Used both
|
||||
// for auto-detecting an existing display and for placing/probing
|
||||
// sockets of virtual sessions we spawn.
|
||||
x11SocketDir = "/tmp/.X11-unix"
|
||||
|
||||
// envDisplay is the X11 display selector environment variable.
|
||||
envDisplay = "DISPLAY"
|
||||
// envXAuthority points X clients at the cookie file used to
|
||||
// authenticate against the running X server.
|
||||
envXAuthority = "XAUTHORITY"
|
||||
)
|
||||
|
||||
// X11Capturer captures the screen from an X11 display using the MIT-SHM extension.
|
||||
type X11Capturer struct {
|
||||
mu sync.Mutex
|
||||
conn *xgb.Conn
|
||||
screen *xproto.ScreenInfo
|
||||
w, h int
|
||||
shmID int
|
||||
shmAddr []byte
|
||||
shmSeg uint32
|
||||
useSHM bool
|
||||
// bufs double-buffers output images so the X11Poller's capture loop can
|
||||
// overwrite one while the session is still encoding the other. Before
|
||||
// this, a single reused buffer would race with the reader. Allocation
|
||||
// happens on first use and on geometry change.
|
||||
bufs [2]*image.RGBA
|
||||
cur int
|
||||
// cursor is the XFixes binding used to report the current sprite.
|
||||
// Allocated lazily on the first Cursor call. cursorInitErr latches
|
||||
// a permanent init failure so we stop retrying every frame.
|
||||
cursor *xfixesCursor
|
||||
cursorInitErr error
|
||||
}
|
||||
|
||||
// detectX11Display finds the active X11 display and sets DISPLAY/XAUTHORITY
|
||||
// environment variables if needed. This is required when running as a system
|
||||
// service where these vars aren't set.
|
||||
func detectX11Display() {
|
||||
if os.Getenv(envDisplay) != "" {
|
||||
return
|
||||
}
|
||||
|
||||
// Try /proc first (Linux), then ps fallback (FreeBSD and others).
|
||||
if detectX11FromProc() {
|
||||
return
|
||||
}
|
||||
if detectX11FromSockets() {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// detectX11FromProc scans /proc/*/cmdline for Xorg (Linux).
|
||||
func detectX11FromProc() bool {
|
||||
entries, err := os.ReadDir("/proc")
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
for _, e := range entries {
|
||||
if !e.IsDir() {
|
||||
continue
|
||||
}
|
||||
cmdline, err := os.ReadFile("/proc/" + e.Name() + "/cmdline")
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if display, auth := parseXorgArgs(splitCmdline(cmdline)); display != "" {
|
||||
setDisplayEnv(display, auth)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// detectX11FromSockets checks /tmp/.X11-unix/ for X sockets and uses ps
|
||||
// to find the auth file. Works on FreeBSD and other systems without /proc.
|
||||
func detectX11FromSockets() bool {
|
||||
entries, err := os.ReadDir(x11SocketDir)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Pick the lowest numeric display rather than the lexically first
|
||||
// entry, so X10 doesn't win over X2.
|
||||
minDisplay := -1
|
||||
for _, e := range entries {
|
||||
name := e.Name()
|
||||
if len(name) < 2 || name[0] != 'X' {
|
||||
continue
|
||||
}
|
||||
n, err := strconv.Atoi(name[1:])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if minDisplay < 0 || n < minDisplay {
|
||||
minDisplay = n
|
||||
}
|
||||
}
|
||||
if minDisplay < 0 {
|
||||
return false
|
||||
}
|
||||
display := ":" + strconv.Itoa(minDisplay)
|
||||
os.Setenv(envDisplay, display)
|
||||
auth := findXorgAuthFromPS()
|
||||
if auth != "" {
|
||||
os.Setenv(envXAuthority, auth)
|
||||
log.Infof("auto-detected DISPLAY=%s (from socket) XAUTHORITY=%s (from ps)", display, auth)
|
||||
} else {
|
||||
log.Infof("auto-detected DISPLAY=%s (from socket)", display)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// findXorgAuthFromPS runs ps to find Xorg and extract its -auth argument.
|
||||
func findXorgAuthFromPS() string {
|
||||
out, err := exec.Command("ps", "auxww").Output()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
for _, line := range strings.Split(string(out), "\n") {
|
||||
if !strings.Contains(line, "Xorg") && !strings.Contains(line, "/X ") {
|
||||
continue
|
||||
}
|
||||
fields := strings.Fields(line)
|
||||
for i, f := range fields {
|
||||
if f == "-auth" && i+1 < len(fields) {
|
||||
return fields[i+1]
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func parseXorgArgs(args []string) (display, auth string) {
|
||||
if len(args) == 0 {
|
||||
return "", ""
|
||||
}
|
||||
base := args[0]
|
||||
if !(base == "Xorg" || base == "X" || len(base) > 0 && base[len(base)-1] == 'X' ||
|
||||
strings.Contains(base, "/Xorg") || strings.Contains(base, "/X")) {
|
||||
return "", ""
|
||||
}
|
||||
for i, arg := range args[1:] {
|
||||
if len(arg) > 0 && arg[0] == ':' {
|
||||
display = arg
|
||||
}
|
||||
if arg == "-auth" && i+2 < len(args) {
|
||||
auth = args[i+2]
|
||||
}
|
||||
}
|
||||
return display, auth
|
||||
}
|
||||
|
||||
func setDisplayEnv(display, auth string) {
|
||||
os.Setenv(envDisplay, display)
|
||||
if auth != "" {
|
||||
os.Setenv(envXAuthority, auth)
|
||||
log.Infof("auto-detected DISPLAY=%s XAUTHORITY=%s", display, auth)
|
||||
return
|
||||
}
|
||||
log.Infof("auto-detected DISPLAY=%s", display)
|
||||
}
|
||||
|
||||
func splitCmdline(data []byte) []string {
|
||||
var args []string
|
||||
for _, b := range splitNull(data) {
|
||||
if len(b) > 0 {
|
||||
args = append(args, string(b))
|
||||
}
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
func splitNull(data []byte) [][]byte {
|
||||
var parts [][]byte
|
||||
start := 0
|
||||
for i, b := range data {
|
||||
if b == 0 {
|
||||
parts = append(parts, data[start:i])
|
||||
start = i + 1
|
||||
}
|
||||
}
|
||||
if start < len(data) {
|
||||
parts = append(parts, data[start:])
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
// NewX11Capturer connects to the X11 display and sets up shared memory capture.
|
||||
func NewX11Capturer(display string) (*X11Capturer, error) {
|
||||
if display == "" {
|
||||
detectX11Display()
|
||||
display = os.Getenv(envDisplay)
|
||||
}
|
||||
if display == "" {
|
||||
return nil, fmt.Errorf("DISPLAY not set and no Xorg process found")
|
||||
}
|
||||
|
||||
conn, err := xgb.NewConnDisplay(display)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connect to X11 display %s: %w", display, err)
|
||||
}
|
||||
|
||||
setup := xproto.Setup(conn)
|
||||
if len(setup.Roots) == 0 {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("no X11 screens")
|
||||
}
|
||||
screen := setup.Roots[0]
|
||||
|
||||
c := &X11Capturer{
|
||||
conn: conn,
|
||||
screen: &screen,
|
||||
w: int(screen.WidthInPixels),
|
||||
h: int(screen.HeightInPixels),
|
||||
}
|
||||
|
||||
if err := c.initSHM(); err != nil {
|
||||
log.Debugf("X11 SHM not available, using slow GetImage: %v", err)
|
||||
}
|
||||
|
||||
log.Infof("X11 capturer ready: %dx%d (display=%s, shm=%v)", c.w, c.h, display, c.useSHM)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// initSHM is implemented in capture_x11_shm_linux.go (requires SysV SHM).
|
||||
// On platforms without SysV SHM (FreeBSD), a stub returns an error and
|
||||
// the capturer falls back to GetImage.
|
||||
|
||||
// Width returns the screen width.
|
||||
func (c *X11Capturer) Width() int { return c.w }
|
||||
|
||||
// Height returns the screen height.
|
||||
func (c *X11Capturer) Height() int { return c.h }
|
||||
|
||||
// Capture returns the current screen as an RGBA image.
|
||||
func (c *X11Capturer) Capture() (*image.RGBA, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.useSHM {
|
||||
return c.captureSHM()
|
||||
}
|
||||
return c.captureGetImage()
|
||||
}
|
||||
|
||||
// CaptureInto fills the caller's destination buffer in one pass. The
|
||||
// source path (SHM or fallback GetImage) writes directly into dst.Pix
|
||||
// instead of going through the X11Capturer's internal double-buffer,
|
||||
// saving one full-frame memcpy per capture.
|
||||
func (c *X11Capturer) CaptureInto(dst *image.RGBA) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if dst.Rect.Dx() != c.w || dst.Rect.Dy() != c.h {
|
||||
return fmt.Errorf("dst size mismatch: dst=%dx%d capturer=%dx%d",
|
||||
dst.Rect.Dx(), dst.Rect.Dy(), c.w, c.h)
|
||||
}
|
||||
if c.useSHM {
|
||||
return c.captureSHMInto(dst)
|
||||
}
|
||||
return c.captureGetImageInto(dst)
|
||||
}
|
||||
|
||||
func (c *X11Capturer) captureGetImageInto(dst *image.RGBA) error {
|
||||
cookie := xproto.GetImage(c.conn, xproto.ImageFormatZPixmap,
|
||||
xproto.Drawable(c.screen.Root),
|
||||
0, 0, uint16(c.w), uint16(c.h), 0xFFFFFFFF)
|
||||
reply, err := cookie.Reply()
|
||||
if err != nil {
|
||||
return fmt.Errorf("GetImage: %w", err)
|
||||
}
|
||||
n := c.w * c.h * 4
|
||||
if len(reply.Data) < n {
|
||||
return fmt.Errorf("GetImage returned %d bytes, expected %d", len(reply.Data), n)
|
||||
}
|
||||
swizzleBGRAtoRGBA(dst.Pix, reply.Data)
|
||||
return nil
|
||||
}
|
||||
|
||||
// captureSHM is implemented in capture_x11_shm_linux.go.
|
||||
|
||||
func (c *X11Capturer) captureGetImage() (*image.RGBA, error) {
|
||||
cookie := xproto.GetImage(c.conn, xproto.ImageFormatZPixmap,
|
||||
xproto.Drawable(c.screen.Root),
|
||||
0, 0, uint16(c.w), uint16(c.h), 0xFFFFFFFF)
|
||||
|
||||
reply, err := cookie.Reply()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GetImage: %w", err)
|
||||
}
|
||||
|
||||
data := reply.Data
|
||||
n := c.w * c.h * 4
|
||||
if len(data) < n {
|
||||
return nil, fmt.Errorf("GetImage returned %d bytes, expected %d", len(data), n)
|
||||
}
|
||||
|
||||
img := c.nextBuffer()
|
||||
swizzleBGRAtoRGBA(img.Pix, data)
|
||||
return img, nil
|
||||
}
|
||||
|
||||
// nextBuffer returns the *image.RGBA the next capture should fill, advancing
|
||||
// the double-buffer index. Reallocates on geometry change.
|
||||
func (c *X11Capturer) nextBuffer() *image.RGBA {
|
||||
c.cur ^= 1
|
||||
b := c.bufs[c.cur]
|
||||
if b == nil || b.Rect.Dx() != c.w || b.Rect.Dy() != c.h {
|
||||
b = image.NewRGBA(image.Rect(0, 0, c.w, c.h))
|
||||
c.bufs[c.cur] = b
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// Close releases X11 resources.
|
||||
func (c *X11Capturer) Close() {
|
||||
c.closeSHM()
|
||||
c.conn.Close()
|
||||
}
|
||||
|
||||
// closeSHM is implemented in capture_x11_shm_linux.go.
|
||||
|
||||
// X11Poller wraps X11Capturer with a staleness-cached on-demand Capture:
|
||||
// sessions drive captures themselves through the encoder goroutine, so we
|
||||
// don't need a background ticker. The last result is cached for a short
|
||||
// window so concurrent sessions coalesce into one capture.
|
||||
//
|
||||
// The capturer is allocated lazily on first use and released when all
|
||||
// clients disconnect, so an idle peer holds no X connection or SHM segment.
|
||||
type X11Poller struct {
|
||||
mu sync.Mutex
|
||||
|
||||
capturer *X11Capturer
|
||||
w, h int
|
||||
// closed at Close so callers can stop waiting on retry backoff.
|
||||
done chan struct{}
|
||||
|
||||
// lastFrame/lastAt implement a small cache: multiple near-simultaneous
|
||||
// Capture calls (multi-client, or input-coalesced) return the same
|
||||
// frame instead of hammering the X server.
|
||||
lastFrame *image.RGBA
|
||||
lastAt time.Time
|
||||
|
||||
// initBackoffUntil throttles capturer re-init when the X server is
|
||||
// unavailable or flapping.
|
||||
initBackoffUntil time.Time
|
||||
|
||||
clients atomic.Int32
|
||||
display string
|
||||
}
|
||||
|
||||
// initRetryBackoff gates capturer re-init attempts after a failure so we
|
||||
// don't spin on X server errors.
|
||||
const initRetryBackoff = 2 * time.Second
|
||||
|
||||
// NewX11Poller creates a lazy on-demand capturer for the given X display.
|
||||
func NewX11Poller(display string) *X11Poller {
|
||||
return &X11Poller{
|
||||
display: display,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// ClientConnect increments the active client count. The first client triggers
|
||||
// eager capturer initialisation so that the first FBUpdateRequest doesn't
|
||||
// pay the X11 connect + SHM attach latency.
|
||||
func (p *X11Poller) ClientConnect() {
|
||||
if p.clients.Add(1) == 1 {
|
||||
p.mu.Lock()
|
||||
_ = p.ensureCapturerLocked()
|
||||
p.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// ClientDisconnect decrements the active client count. On the last
|
||||
// disconnect we close the underlying capturer so idle peers cost nothing.
|
||||
func (p *X11Poller) ClientDisconnect() {
|
||||
if p.clients.Add(-1) == 0 {
|
||||
p.mu.Lock()
|
||||
if p.capturer != nil {
|
||||
p.capturer.Close()
|
||||
p.capturer = nil
|
||||
p.lastFrame = nil
|
||||
}
|
||||
p.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Close releases all resources. Subsequent Capture calls will fail.
|
||||
func (p *X11Poller) Close() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
select {
|
||||
case <-p.done:
|
||||
default:
|
||||
close(p.done)
|
||||
}
|
||||
if p.capturer != nil {
|
||||
p.capturer.Close()
|
||||
p.capturer = nil
|
||||
}
|
||||
}
|
||||
|
||||
// Width returns the screen width. Triggers lazy init if needed.
|
||||
func (p *X11Poller) Width() int {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
_ = p.ensureCapturerLocked()
|
||||
return p.w
|
||||
}
|
||||
|
||||
// Height returns the screen height. Triggers lazy init if needed.
|
||||
func (p *X11Poller) Height() int {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
_ = p.ensureCapturerLocked()
|
||||
return p.h
|
||||
}
|
||||
|
||||
// Cursor satisfies cursorSource by forwarding to the lazily-initialised
|
||||
// X11Capturer. Asking for the cursor on an idle poller triggers the same
|
||||
// lazy X11 connection setup as a capture would.
|
||||
func (p *X11Poller) Cursor() (*image.RGBA, int, int, uint64, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if err := p.ensureCapturerLocked(); err != nil {
|
||||
return nil, 0, 0, 0, err
|
||||
}
|
||||
return p.capturer.Cursor()
|
||||
}
|
||||
|
||||
// CursorPos satisfies cursorPositionSource by forwarding to the X11Capturer.
|
||||
func (p *X11Poller) CursorPos() (int, int, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if err := p.ensureCapturerLocked(); err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
return p.capturer.CursorPos()
|
||||
}
|
||||
|
||||
// Capture returns a fresh frame, serving from the short-lived cache if a
|
||||
// previous caller captured within freshWindow.
|
||||
func (p *X11Poller) Capture() (*image.RGBA, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.lastFrame != nil && time.Since(p.lastAt) < freshWindow {
|
||||
return p.lastFrame, nil
|
||||
}
|
||||
if err := p.ensureCapturerLocked(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
img, err := p.capturer.Capture()
|
||||
if err != nil {
|
||||
// Drop the capturer so the next call re-inits; the X connection may
|
||||
// have died (e.g. Xorg restart).
|
||||
p.capturer.Close()
|
||||
p.capturer = nil
|
||||
p.initBackoffUntil = time.Now().Add(initRetryBackoff)
|
||||
return nil, fmt.Errorf("x11 capture: %w", err)
|
||||
}
|
||||
p.lastFrame = img
|
||||
p.lastAt = time.Now()
|
||||
return img, nil
|
||||
}
|
||||
|
||||
// CaptureInto fills dst directly via the underlying capturer, bypassing
|
||||
// the freshness cache. The session's prevFrame/curFrame swap means each
|
||||
// session needs its own buffer anyway, so caching wouldn't help.
|
||||
func (p *X11Poller) CaptureInto(dst *image.RGBA) error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if err := p.ensureCapturerLocked(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := p.capturer.CaptureInto(dst); err != nil {
|
||||
p.capturer.Close()
|
||||
p.capturer = nil
|
||||
p.initBackoffUntil = time.Now().Add(initRetryBackoff)
|
||||
return fmt.Errorf("x11 capture: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureCapturerLocked initialises the underlying X11Capturer if not
|
||||
// already open. Caller must hold p.mu.
|
||||
func (p *X11Poller) ensureCapturerLocked() error {
|
||||
if p.capturer != nil {
|
||||
return nil
|
||||
}
|
||||
select {
|
||||
case <-p.done:
|
||||
return fmt.Errorf("x11 capturer closed")
|
||||
default:
|
||||
}
|
||||
if time.Now().Before(p.initBackoffUntil) {
|
||||
return fmt.Errorf("x11 capturer unavailable (retry scheduled)")
|
||||
}
|
||||
c, err := NewX11Capturer(p.display)
|
||||
if err != nil {
|
||||
p.initBackoffUntil = time.Now().Add(initRetryBackoff)
|
||||
log.Debugf("X11 capturer: %v", err)
|
||||
return err
|
||||
}
|
||||
p.capturer = c
|
||||
p.w, p.h = c.Width(), c.Height()
|
||||
return nil
|
||||
}
|
||||
96
client/vnc/server/capture_x11_shm_linux.go
Normal file
96
client/vnc/server/capture_x11_shm_linux.go
Normal file
@@ -0,0 +1,96 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
|
||||
"github.com/jezek/xgb/shm"
|
||||
"github.com/jezek/xgb/xproto"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func (c *X11Capturer) initSHM() error {
|
||||
if err := shm.Init(c.conn); err != nil {
|
||||
return fmt.Errorf("init SHM extension: %w", err)
|
||||
}
|
||||
|
||||
size := c.w * c.h * 4
|
||||
id, err := unix.SysvShmGet(unix.IPC_PRIVATE, size, unix.IPC_CREAT|0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("shmget: %w", err)
|
||||
}
|
||||
|
||||
addr, err := unix.SysvShmAttach(id, 0, 0)
|
||||
if err != nil {
|
||||
if _, ctlErr := unix.SysvShmCtl(id, unix.IPC_RMID, nil); ctlErr != nil {
|
||||
log.Debugf("shmctl IPC_RMID on attach failure: %v", ctlErr)
|
||||
}
|
||||
return fmt.Errorf("shmat: %w", err)
|
||||
}
|
||||
|
||||
if _, err := unix.SysvShmCtl(id, unix.IPC_RMID, nil); err != nil {
|
||||
log.Debugf("shmctl IPC_RMID: %v", err)
|
||||
}
|
||||
|
||||
seg, err := shm.NewSegId(c.conn)
|
||||
if err != nil {
|
||||
if detachErr := unix.SysvShmDetach(addr); detachErr != nil {
|
||||
log.Debugf("shmdt on new-seg failure: %v", detachErr)
|
||||
}
|
||||
return fmt.Errorf("new SHM seg: %w", err)
|
||||
}
|
||||
|
||||
if err := shm.AttachChecked(c.conn, seg, uint32(id), false).Check(); err != nil {
|
||||
if detachErr := unix.SysvShmDetach(addr); detachErr != nil {
|
||||
log.Debugf("shmdt on attach-checked failure: %v", detachErr)
|
||||
}
|
||||
return fmt.Errorf("SHM attach to X: %w", err)
|
||||
}
|
||||
|
||||
c.shmID = id
|
||||
c.shmAddr = addr
|
||||
c.shmSeg = uint32(seg)
|
||||
c.useSHM = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *X11Capturer) captureSHM() (*image.RGBA, error) {
|
||||
if err := c.fillSHM(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
img := c.nextBuffer()
|
||||
swizzleBGRAtoRGBA(img.Pix, c.shmAddr[:c.w*c.h*4])
|
||||
return img, nil
|
||||
}
|
||||
|
||||
// captureSHMInto runs a single SHM GetImage and swizzles directly into the
|
||||
// caller-provided destination, skipping the internal double-buffer.
|
||||
func (c *X11Capturer) captureSHMInto(dst *image.RGBA) error {
|
||||
if err := c.fillSHM(); err != nil {
|
||||
return err
|
||||
}
|
||||
swizzleBGRAtoRGBA(dst.Pix, c.shmAddr[:c.w*c.h*4])
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *X11Capturer) fillSHM() error {
|
||||
cookie := shm.GetImage(c.conn, xproto.Drawable(c.screen.Root),
|
||||
0, 0, uint16(c.w), uint16(c.h), 0xFFFFFFFF,
|
||||
xproto.ImageFormatZPixmap, shm.Seg(c.shmSeg), 0)
|
||||
if _, err := cookie.Reply(); err != nil {
|
||||
return fmt.Errorf("SHM GetImage: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *X11Capturer) closeSHM() {
|
||||
if c.useSHM {
|
||||
shm.Detach(c.conn, shm.Seg(c.shmSeg))
|
||||
if err := unix.SysvShmDetach(c.shmAddr); err != nil {
|
||||
log.Debugf("shmdt on close: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
24
client/vnc/server/capture_x11_shm_stub.go
Normal file
24
client/vnc/server/capture_x11_shm_stub.go
Normal file
@@ -0,0 +1,24 @@
|
||||
//go:build freebsd
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
)
|
||||
|
||||
func (c *X11Capturer) initSHM() error {
|
||||
return fmt.Errorf("SysV SHM not available on this platform")
|
||||
}
|
||||
|
||||
func (c *X11Capturer) captureSHM() (*image.RGBA, error) {
|
||||
return nil, fmt.Errorf("SHM capture not available on this platform")
|
||||
}
|
||||
|
||||
func (c *X11Capturer) captureSHMInto(_ *image.RGBA) error {
|
||||
return fmt.Errorf("SHM capture not available on this platform")
|
||||
}
|
||||
|
||||
func (c *X11Capturer) closeSHM() {
|
||||
// no SHM to close on this platform
|
||||
}
|
||||
77
client/vnc/server/coalesce_test.go
Normal file
77
client/vnc/server/coalesce_test.go
Normal file
@@ -0,0 +1,77 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCoalesceRects(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
in [][4]int
|
||||
want [][4]int
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
in: nil,
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "single",
|
||||
in: [][4]int{{0, 0, 64, 64}},
|
||||
want: [][4]int{{0, 0, 64, 64}},
|
||||
},
|
||||
{
|
||||
name: "horizontal_run",
|
||||
in: [][4]int{{0, 0, 64, 64}, {64, 0, 64, 64}, {128, 0, 64, 64}},
|
||||
want: [][4]int{{0, 0, 192, 64}},
|
||||
},
|
||||
{
|
||||
name: "vertical_run",
|
||||
in: [][4]int{{0, 0, 64, 64}, {0, 64, 64, 64}, {0, 128, 64, 64}},
|
||||
want: [][4]int{{0, 0, 64, 192}},
|
||||
},
|
||||
{
|
||||
name: "block_2x2",
|
||||
in: [][4]int{
|
||||
{0, 0, 64, 64}, {64, 0, 64, 64},
|
||||
{0, 64, 64, 64}, {64, 64, 64, 64},
|
||||
},
|
||||
want: [][4]int{{0, 0, 128, 128}},
|
||||
},
|
||||
{
|
||||
name: "no_merge_gap",
|
||||
in: [][4]int{{0, 0, 64, 64}, {192, 0, 64, 64}},
|
||||
want: [][4]int{{0, 0, 64, 64}, {192, 0, 64, 64}},
|
||||
},
|
||||
{
|
||||
name: "two_disjoint_columns",
|
||||
in: [][4]int{
|
||||
{0, 0, 64, 64}, {192, 0, 64, 64},
|
||||
{0, 64, 64, 64}, {192, 64, 64, 64},
|
||||
},
|
||||
want: [][4]int{{0, 0, 64, 128}, {192, 0, 64, 128}},
|
||||
},
|
||||
{
|
||||
name: "misaligned_widths_no_vertical_merge",
|
||||
in: [][4]int{
|
||||
{0, 0, 128, 64},
|
||||
{0, 64, 64, 64},
|
||||
},
|
||||
want: [][4]int{{0, 0, 128, 64}, {0, 64, 64, 64}},
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := coalesceRects(tc.in)
|
||||
if len(got) == 0 && len(tc.want) == 0 {
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tc.want) {
|
||||
t.Fatalf("got %v want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
191
client/vnc/server/copyrect.go
Normal file
191
client/vnc/server/copyrect.go
Normal file
@@ -0,0 +1,191 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"hash/maphash"
|
||||
"image"
|
||||
)
|
||||
|
||||
// copyRectDetector finds tiles in the current frame that match the content
|
||||
// of some tile-aligned region of the previous frame, so we can emit them as
|
||||
// CopyRect rectangles (16 wire bytes) instead of re-encoding the pixels.
|
||||
//
|
||||
// The detector keeps two structures:
|
||||
// - tileHash, a flat slice of one hash per tile-aligned position, used as
|
||||
// the source of truth for the previous frame's tile content.
|
||||
// - prevTiles, a hash → position lookup used during findTileMatch.
|
||||
//
|
||||
// updateDirty rehashes only the tiles that changed this frame, so the
|
||||
// steady-state cost is proportional to the dirty set, not the framebuffer.
|
||||
// A full rebuild from scratch is only done on the first frame or when the
|
||||
// detector has not yet been initialized for the current resolution.
|
||||
//
|
||||
// Limitations:
|
||||
// - Only tile-aligned source positions are considered. Sub-tile-aligned
|
||||
// moves (e.g. window dragged by 7 pixels) are not detected. This still
|
||||
// covers the common case of vertical/horizontal scrolling, which always
|
||||
// produces tile-aligned matches at the tile granularity.
|
||||
// - 64-bit maphash collisions are assumed not to happen. The probability
|
||||
// for any single frame's hash universe is ~2^-32 * tileCount² which is
|
||||
// vanishingly small at typical resolutions; if we ever observe one we
|
||||
// can fall back to a full memcmp verification.
|
||||
type copyRectDetector struct {
|
||||
seed maphash.Seed
|
||||
tileSize int
|
||||
w, h int
|
||||
cols, rows int
|
||||
// tileHash[ty*cols + tx] is the current hash of the tile at (tx, ty)
|
||||
// in the previous frame. Lookup uses this to detect stale prevTiles
|
||||
// entries: incremental updates may leave hash→pos entries pointing
|
||||
// at a tile whose content has since changed.
|
||||
tileHash []uint64
|
||||
// prevTiles maps a tile hash to a (x, y) origin in the previous frame.
|
||||
prevTiles map[uint64][2]int
|
||||
// hash is reused across hash computations to keep the per-tile lookup
|
||||
// path allocation-free.
|
||||
hash maphash.Hash
|
||||
}
|
||||
|
||||
func newCopyRectDetector(tileSize int) *copyRectDetector {
|
||||
d := ©RectDetector{
|
||||
seed: maphash.MakeSeed(),
|
||||
tileSize: tileSize,
|
||||
prevTiles: make(map[uint64][2]int),
|
||||
}
|
||||
d.hash.SetSeed(d.seed)
|
||||
return d
|
||||
}
|
||||
|
||||
// resize ensures the per-tile tables match the given framebuffer size.
|
||||
// Called from rebuild before each full hash sweep.
|
||||
func (d *copyRectDetector) resize(w, h int) {
|
||||
if d.w == w && d.h == h && d.tileHash != nil {
|
||||
return
|
||||
}
|
||||
d.w, d.h = w, h
|
||||
d.cols = w / d.tileSize
|
||||
d.rows = h / d.tileSize
|
||||
d.tileHash = make([]uint64, d.cols*d.rows)
|
||||
}
|
||||
|
||||
// hashTile computes the 64-bit maphash of one tile-aligned tile of frame.
|
||||
func (d *copyRectDetector) hashTile(frame *image.RGBA, tx, ty int) uint64 {
|
||||
d.hash.Reset()
|
||||
ts := d.tileSize
|
||||
stride := frame.Stride
|
||||
rowBytes := ts * 4
|
||||
base := ty*stride + tx*4
|
||||
for row := 0; row < ts; row++ {
|
||||
off := base + row*stride
|
||||
_, _ = d.hash.Write(frame.Pix[off : off+rowBytes])
|
||||
}
|
||||
return d.hash.Sum64()
|
||||
}
|
||||
|
||||
// rebuild discards everything and rehashes the whole frame. O(w*h). Use
|
||||
// for the first frame or after the detector has been resized. Steady-state
|
||||
// updates should go through updateDirty instead.
|
||||
func (d *copyRectDetector) rebuild(frame *image.RGBA, w, h int) {
|
||||
d.resize(w, h)
|
||||
if d.prevTiles == nil {
|
||||
d.prevTiles = make(map[uint64][2]int)
|
||||
} else {
|
||||
clear(d.prevTiles)
|
||||
}
|
||||
ts := d.tileSize
|
||||
for ty := 0; ty+ts <= h; ty += ts {
|
||||
for tx := 0; tx+ts <= w; tx += ts {
|
||||
sum := d.hashTile(frame, tx, ty)
|
||||
d.tileHash[(ty/ts)*d.cols+(tx/ts)] = sum
|
||||
if _, exists := d.prevTiles[sum]; !exists {
|
||||
d.prevTiles[sum] = [2]int{tx, ty}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// updateDirty rehashes only the tiles named in dirty (each entry is
|
||||
// [x, y, w, h] with w and h equal to tileSize). O(len(dirty)) work, which
|
||||
// in the common case is a tiny fraction of the whole framebuffer.
|
||||
//
|
||||
// The prevTiles map is replaced on collision rather than first-wins so a
|
||||
// newly-hashed tile claims the slot. Old, stale entries pointing at tiles
|
||||
// that no longer carry that hash are filtered at lookup time via tileHash.
|
||||
func (d *copyRectDetector) updateDirty(frame *image.RGBA, w, h int, dirty [][4]int) {
|
||||
if d.w != w || d.h != h || d.tileHash == nil {
|
||||
d.rebuild(frame, w, h)
|
||||
return
|
||||
}
|
||||
ts := d.tileSize
|
||||
for _, r := range dirty {
|
||||
if r[2] != ts || r[3] != ts {
|
||||
continue
|
||||
}
|
||||
tx, ty := r[0], r[1]
|
||||
if tx+ts > w || ty+ts > h {
|
||||
continue
|
||||
}
|
||||
sum := d.hashTile(frame, tx, ty)
|
||||
d.tileHash[(ty/ts)*d.cols+(tx/ts)] = sum
|
||||
// Latest-wins on collision: ensures the most recent owner of this
|
||||
// hash is the one we'll return on lookup. The previous owner's
|
||||
// entry, if any, gets shadowed; if its content has changed it's
|
||||
// stale anyway and findTileMatch's verification will skip it.
|
||||
d.prevTiles[sum] = [2]int{tx, ty}
|
||||
}
|
||||
}
|
||||
|
||||
// findTileMatch hashes the current-frame tile at (dstX, dstY) and looks up
|
||||
// its hash in the previous-frame map. Returns (srcX, srcY, true) when a
|
||||
// matching tile-aligned tile exists at a different position whose stored
|
||||
// hash still equals the requested hash (so the result is not stale).
|
||||
func (d *copyRectDetector) findTileMatch(cur *image.RGBA, dstX, dstY int) (int, int, bool) {
|
||||
if len(d.prevTiles) == 0 || d.tileHash == nil {
|
||||
return 0, 0, false
|
||||
}
|
||||
ts := d.tileSize
|
||||
if dstX+ts > cur.Rect.Dx() || dstY+ts > cur.Rect.Dy() {
|
||||
return 0, 0, false
|
||||
}
|
||||
sum := d.hashTile(cur, dstX, dstY)
|
||||
pos, ok := d.prevTiles[sum]
|
||||
if !ok {
|
||||
return 0, 0, false
|
||||
}
|
||||
if pos[0] == dstX && pos[1] == dstY {
|
||||
return 0, 0, false
|
||||
}
|
||||
// Reject stale entries: the position the map points at must still
|
||||
// carry the same hash according to our per-tile array.
|
||||
if d.tileHash[(pos[1]/ts)*d.cols+(pos[0]/ts)] != sum {
|
||||
return 0, 0, false
|
||||
}
|
||||
return pos[0], pos[1], true
|
||||
}
|
||||
|
||||
// extractCopyRectTiles examines the diff-produced (per-tile) dirty list and
|
||||
// pulls out any tiles whose current-frame content matches a prev-frame tile
|
||||
// at a different position. Returns the CopyRect candidates and the residual
|
||||
// dirty tiles that still need pixel encoding.
|
||||
type copyRectMove struct {
|
||||
srcX, srcY int
|
||||
dstX, dstY int
|
||||
}
|
||||
|
||||
func (d *copyRectDetector) extractCopyRectTiles(cur *image.RGBA, dirtyTiles [][4]int) (moves []copyRectMove, remaining [][4]int) {
|
||||
ts := d.tileSize
|
||||
remaining = dirtyTiles[:0:cap(dirtyTiles)]
|
||||
for _, r := range dirtyTiles {
|
||||
if r[2] == ts && r[3] == ts {
|
||||
if sx, sy, ok := d.findTileMatch(cur, r[0], r[1]); ok {
|
||||
moves = append(moves, copyRectMove{
|
||||
srcX: sx, srcY: sy, dstX: r[0], dstY: r[1],
|
||||
})
|
||||
continue
|
||||
}
|
||||
}
|
||||
remaining = append(remaining, r)
|
||||
}
|
||||
return moves, remaining
|
||||
}
|
||||
162
client/vnc/server/copyrect_test.go
Normal file
162
client/vnc/server/copyrect_test.go
Normal file
@@ -0,0 +1,162 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"image"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// fillTile paints a tileSize×tileSize block of img at (x,y) with the colour
|
||||
// derived from (r,g,b) so the test can construct distinct-content tiles.
|
||||
func fillTile(img *image.RGBA, x, y, ts int, r, g, b byte) {
|
||||
for row := 0; row < ts; row++ {
|
||||
off := (y+row)*img.Stride + x*4
|
||||
for col := 0; col < ts; col++ {
|
||||
img.Pix[off+col*4+0] = r
|
||||
img.Pix[off+col*4+1] = g
|
||||
img.Pix[off+col*4+2] = b
|
||||
img.Pix[off+col*4+3] = 0xff
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// copyTile copies a tileSize×tileSize block from src(sx,sy) to dst(dx,dy).
|
||||
func copyTile(dst, src *image.RGBA, sx, sy, dx, dy, ts int) {
|
||||
for row := 0; row < ts; row++ {
|
||||
srcOff := (sy+row)*src.Stride + sx*4
|
||||
dstOff := (dy+row)*dst.Stride + dx*4
|
||||
copy(dst.Pix[dstOff:dstOff+ts*4], src.Pix[srcOff:srcOff+ts*4])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyRectDetector_DetectsVerticalScroll(t *testing.T) {
|
||||
const w, h = 256, 192 // 4×3 tiles at 64px
|
||||
const ts = 64
|
||||
|
||||
prev := image.NewRGBA(image.Rect(0, 0, w, h))
|
||||
cur := image.NewRGBA(image.Rect(0, 0, w, h))
|
||||
|
||||
// prev: 12 tiles each with a unique colour.
|
||||
for ty := 0; ty < 3; ty++ {
|
||||
for tx := 0; tx < 4; tx++ {
|
||||
fillTile(prev, tx*ts, ty*ts, ts, byte(tx*40), byte(ty*60), 0x80)
|
||||
}
|
||||
}
|
||||
// cur: simulate a single-tile-row scroll upward, every tile copied from
|
||||
// the row below in prev, top row is new content.
|
||||
for ty := 0; ty < 2; ty++ {
|
||||
for tx := 0; tx < 4; tx++ {
|
||||
copyTile(cur, prev, tx*ts, (ty+1)*ts, tx*ts, ty*ts, ts)
|
||||
}
|
||||
}
|
||||
// Bottom row of cur: new colour, not a match.
|
||||
for tx := 0; tx < 4; tx++ {
|
||||
fillTile(cur, tx*ts, 2*ts, ts, 0xff, 0xff, 0xff)
|
||||
}
|
||||
|
||||
d := newCopyRectDetector(ts)
|
||||
d.rebuild(prev, w, h)
|
||||
|
||||
tiles := diffTiles(prev, cur, w, h, ts)
|
||||
moves, remaining := d.extractCopyRectTiles(cur, tiles)
|
||||
|
||||
// Expect 8 CopyRect moves (top two rows) and 4 residual tiles (bottom row).
|
||||
if len(moves) != 8 {
|
||||
t.Fatalf("moves: want 8, got %d", len(moves))
|
||||
}
|
||||
if len(remaining) != 4 {
|
||||
t.Fatalf("remaining: want 4, got %d", len(remaining))
|
||||
}
|
||||
// Spot-check one move: cur (0, 0) should map to prev (0, 64).
|
||||
var found bool
|
||||
for _, m := range moves {
|
||||
if m.dstX == 0 && m.dstY == 0 {
|
||||
if m.srcX != 0 || m.srcY != ts {
|
||||
t.Fatalf("move at (0,0): src=(%d,%d), want (0,%d)", m.srcX, m.srcY, ts)
|
||||
}
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatalf("no move for dst (0,0)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyRectDetector_RejectsSelfMatch(t *testing.T) {
|
||||
const w, h = 128, 128
|
||||
const ts = 64
|
||||
|
||||
prev := image.NewRGBA(image.Rect(0, 0, w, h))
|
||||
cur := image.NewRGBA(image.Rect(0, 0, w, h))
|
||||
|
||||
// prev: 4 tiles, all unique
|
||||
fillTile(prev, 0, 0, ts, 0x10, 0x20, 0x30)
|
||||
fillTile(prev, ts, 0, ts, 0x40, 0x50, 0x60)
|
||||
fillTile(prev, 0, ts, ts, 0x70, 0x80, 0x90)
|
||||
fillTile(prev, ts, ts, ts, 0xa0, 0xb0, 0xc0)
|
||||
|
||||
// cur: tile (0,0) unchanged, others changed but content same as prev's (0,0).
|
||||
fillTile(cur, 0, 0, ts, 0x10, 0x20, 0x30) // self-match
|
||||
fillTile(cur, ts, 0, ts, 0xff, 0xff, 0xff)
|
||||
fillTile(cur, 0, ts, ts, 0xff, 0xff, 0xff)
|
||||
fillTile(cur, ts, ts, ts, 0xff, 0xff, 0xff)
|
||||
|
||||
d := newCopyRectDetector(ts)
|
||||
d.rebuild(prev, w, h)
|
||||
|
||||
// Tile (0,0) is not in the dirty list (it's unchanged) so it should not
|
||||
// produce a move even though its hash matches prev (0,0).
|
||||
tiles := diffTiles(prev, cur, w, h, ts)
|
||||
moves, _ := d.extractCopyRectTiles(cur, tiles)
|
||||
for _, m := range moves {
|
||||
if m.dstX == 0 && m.dstY == 0 {
|
||||
t.Fatalf("unexpected move at (0,0)")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyRectDetector_PassThroughWhenNoMatch(t *testing.T) {
|
||||
const w, h = 64, 64
|
||||
const ts = 64
|
||||
|
||||
prev := image.NewRGBA(image.Rect(0, 0, w, h))
|
||||
cur := image.NewRGBA(image.Rect(0, 0, w, h))
|
||||
fillTile(prev, 0, 0, ts, 0x11, 0x22, 0x33)
|
||||
fillTile(cur, 0, 0, ts, 0xaa, 0xbb, 0xcc) // wholly different
|
||||
|
||||
d := newCopyRectDetector(ts)
|
||||
d.rebuild(prev, w, h)
|
||||
tiles := diffTiles(prev, cur, w, h, ts)
|
||||
moves, remaining := d.extractCopyRectTiles(cur, tiles)
|
||||
|
||||
if len(moves) != 0 {
|
||||
t.Fatalf("expected 0 moves, got %d", len(moves))
|
||||
}
|
||||
if len(remaining) != 1 {
|
||||
t.Fatalf("expected 1 residual tile, got %d", len(remaining))
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeCopyRectBody_Layout(t *testing.T) {
|
||||
got := encodeCopyRectBody(100, 200, 300, 400, 64, 48)
|
||||
if len(got) != 16 {
|
||||
t.Fatalf("CopyRect body length: want 16, got %d", len(got))
|
||||
}
|
||||
// Dest position
|
||||
if got[0] != 0x01 || got[1] != 0x2c || got[2] != 0x01 || got[3] != 0x90 {
|
||||
t.Fatalf("bad dest bytes: % x", got[0:4])
|
||||
}
|
||||
// Width, height
|
||||
if got[4] != 0 || got[5] != 64 || got[6] != 0 || got[7] != 48 {
|
||||
t.Fatalf("bad size bytes: % x", got[4:8])
|
||||
}
|
||||
// Encoding = 1
|
||||
if got[11] != 0x01 {
|
||||
t.Fatalf("bad encoding byte: 0x%02x", got[11])
|
||||
}
|
||||
// Source position
|
||||
if got[12] != 0 || got[13] != 100 || got[14] != 0 || got[15] != 200 {
|
||||
t.Fatalf("bad src bytes: % x", got[12:16])
|
||||
}
|
||||
}
|
||||
194
client/vnc/server/cursor_darwin.go
Normal file
194
client/vnc/server/cursor_darwin.go
Normal file
@@ -0,0 +1,194 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"hash/maphash"
|
||||
"image"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
darwinCursorOnce sync.Once
|
||||
cgsCreateCursor func() uintptr
|
||||
darwinCursorErr error
|
||||
)
|
||||
|
||||
// initDarwinCursor binds a private symbol that returns the current
|
||||
// system cursor image. The classic CGSCreateCurrentCursorImage moved
|
||||
// from CoreGraphics to SkyLight around macOS 13 and is gone entirely
|
||||
// in Sequoia; we probe both frameworks for any of the historical
|
||||
// names so this keeps working on whichever release the binding still
|
||||
// exists. Without a hit the remote-cursor compositing path becomes a
|
||||
// no-op and we log the candidates we tried.
|
||||
func initDarwinCursor() {
|
||||
darwinCursorOnce.Do(func() {
|
||||
libs := []string{
|
||||
"/System/Library/PrivateFrameworks/SkyLight.framework/SkyLight",
|
||||
"/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics",
|
||||
}
|
||||
names := []string{
|
||||
"CGSCreateCurrentCursorImage",
|
||||
"CGSCopyCurrentCursorImage",
|
||||
"CGSCurrentCursorImage",
|
||||
"CGSHardwareCursorActiveImage",
|
||||
}
|
||||
var tried []string
|
||||
for _, path := range libs {
|
||||
h, err := purego.Dlopen(path, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
tried = append(tried, fmt.Sprintf("dlopen %s: %v", path, err))
|
||||
continue
|
||||
}
|
||||
for _, name := range names {
|
||||
sym, err := purego.Dlsym(h, name)
|
||||
if err != nil {
|
||||
tried = append(tried, fmt.Sprintf("%s!%s missing", path, name))
|
||||
continue
|
||||
}
|
||||
purego.RegisterFunc(&cgsCreateCursor, sym)
|
||||
log.Infof("macOS cursor: bound %s from %s", name, path)
|
||||
return
|
||||
}
|
||||
}
|
||||
darwinCursorErr = fmt.Errorf("no cursor image symbol available; tried: %v", tried)
|
||||
})
|
||||
}
|
||||
|
||||
// cgCursor holds the cached macOS cursor sprite and bumps a serial when
|
||||
// the bytes change. Hotspot is left at (0, 0): the public Cocoa hot-spot
|
||||
// query lives on NSCursor which is process-local and not reachable from
|
||||
// our purego-based bindings; the visual cost is a small misalignment for
|
||||
// non-arrow cursors (I-beam, crosshair, etc.).
|
||||
type cgCursor struct {
|
||||
mu sync.Mutex
|
||||
hashSeed maphash.Seed
|
||||
lastSum uint64
|
||||
cached *image.RGBA
|
||||
serial uint64
|
||||
}
|
||||
|
||||
func newCGCursor() *cgCursor {
|
||||
initDarwinCursor()
|
||||
return &cgCursor{hashSeed: maphash.MakeSeed()}
|
||||
}
|
||||
|
||||
// Cursor returns the current cursor sprite as RGBA. Errors that come from
|
||||
// missing private symbols are sticky; transient empty-image responses are
|
||||
// reported as such so the encoder skips this cycle.
|
||||
func (c *cgCursor) Cursor() (*image.RGBA, int, int, uint64, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if darwinCursorErr != nil {
|
||||
return nil, 0, 0, 0, darwinCursorErr
|
||||
}
|
||||
if cgsCreateCursor == nil {
|
||||
return nil, 0, 0, 0, fmt.Errorf("CGSCreateCurrentCursorImage unavailable")
|
||||
}
|
||||
cgImage := cgsCreateCursor()
|
||||
if cgImage == 0 {
|
||||
return nil, 0, 0, 0, fmt.Errorf("no cursor image available")
|
||||
}
|
||||
defer cgImageRelease(cgImage)
|
||||
|
||||
w := int(cgImageGetWidth(cgImage))
|
||||
h := int(cgImageGetHeight(cgImage))
|
||||
if w <= 0 || h <= 0 {
|
||||
return nil, 0, 0, 0, fmt.Errorf("cursor has zero extent")
|
||||
}
|
||||
bytesPerRow := int(cgImageGetBytesPerRow(cgImage))
|
||||
bpp := int(cgImageGetBitsPerPixel(cgImage))
|
||||
if bpp != 32 {
|
||||
return nil, 0, 0, 0, fmt.Errorf("unsupported cursor bpp: %d", bpp)
|
||||
}
|
||||
provider := cgImageGetDataProvider(cgImage)
|
||||
if provider == 0 {
|
||||
return nil, 0, 0, 0, fmt.Errorf("cursor data provider missing")
|
||||
}
|
||||
cfData := cgDataProviderCopyData(provider)
|
||||
if cfData == 0 {
|
||||
return nil, 0, 0, 0, fmt.Errorf("cursor data copy failed")
|
||||
}
|
||||
defer cfRelease(cfData)
|
||||
dataLen := int(cfDataGetLength(cfData))
|
||||
dataPtr := cfDataGetBytePtr(cfData)
|
||||
if dataPtr == 0 || dataLen == 0 {
|
||||
return nil, 0, 0, 0, fmt.Errorf("cursor data empty")
|
||||
}
|
||||
src := unsafe.Slice((*byte)(unsafe.Pointer(dataPtr)), dataLen)
|
||||
|
||||
sum := maphash.Bytes(c.hashSeed, src)
|
||||
if c.cached != nil && sum == c.lastSum {
|
||||
return c.cached, 0, 0, c.serial, nil
|
||||
}
|
||||
|
||||
img := image.NewRGBA(image.Rect(0, 0, w, h))
|
||||
for y := 0; y < h; y++ {
|
||||
srcOff := y * bytesPerRow
|
||||
dstOff := y * w * 4
|
||||
for x := 0; x < w; x++ {
|
||||
si := srcOff + x*4
|
||||
di := dstOff + x*4
|
||||
img.Pix[di+0] = src[si+2]
|
||||
img.Pix[di+1] = src[si+1]
|
||||
img.Pix[di+2] = src[si+0]
|
||||
img.Pix[di+3] = src[si+3]
|
||||
}
|
||||
}
|
||||
|
||||
c.lastSum = sum
|
||||
c.cached = img
|
||||
c.serial++
|
||||
return img, 0, 0, c.serial, nil
|
||||
}
|
||||
|
||||
// Cursor on CGCapturer satisfies cursorSource. The cgCursor wrapper is
|
||||
// allocated lazily so a build that never asks for the cursor pays no cost.
|
||||
func (c *CGCapturer) Cursor() (*image.RGBA, int, int, uint64, error) {
|
||||
c.cursorOnce.Do(func() {
|
||||
c.cursor = newCGCursor()
|
||||
})
|
||||
return c.cursor.Cursor()
|
||||
}
|
||||
|
||||
// CursorPos returns the current global mouse location via CGEventCreate /
|
||||
// CGEventGetLocation. Coordinates are screen pixels in the main display.
|
||||
func (c *CGCapturer) CursorPos() (int, int, error) {
|
||||
if cgEventCreate == nil || cgEventGetLocation == nil {
|
||||
return 0, 0, fmt.Errorf("CGEvent location APIs unavailable")
|
||||
}
|
||||
ev := cgEventCreate(0)
|
||||
if ev == 0 {
|
||||
return 0, 0, fmt.Errorf("CGEventCreate returned nil")
|
||||
}
|
||||
defer cfRelease(ev)
|
||||
pt := cgEventGetLocation(ev)
|
||||
return int(pt.X), int(pt.Y), nil
|
||||
}
|
||||
|
||||
// Cursor on MacPoller forwards to the lazy CGCapturer. ensureCapturerLocked
|
||||
// returns an error when Screen Recording permission has not been granted;
|
||||
// in that case there is no usable cursor source either.
|
||||
func (p *MacPoller) Cursor() (*image.RGBA, int, int, uint64, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if err := p.ensureCapturerLocked(); err != nil {
|
||||
return nil, 0, 0, 0, err
|
||||
}
|
||||
return p.capturer.Cursor()
|
||||
}
|
||||
|
||||
// CursorPos forwards to the lazy CGCapturer.
|
||||
func (p *MacPoller) CursorPos() (int, int, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if err := p.ensureCapturerLocked(); err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
return p.capturer.CursorPos()
|
||||
}
|
||||
407
client/vnc/server/cursor_windows.go
Normal file
407
client/vnc/server/cursor_windows.go
Normal file
@@ -0,0 +1,407 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var (
|
||||
procGetCursorInfo = user32.NewProc("GetCursorInfo")
|
||||
procGetIconInfo = user32.NewProc("GetIconInfo")
|
||||
procGetObjectW = gdi32.NewProc("GetObjectW")
|
||||
procGetDIBits = gdi32.NewProc("GetDIBits")
|
||||
)
|
||||
|
||||
const (
|
||||
cursorShowing = 0x00000001
|
||||
diRgbColors = 0
|
||||
biRgb = 0
|
||||
dibSectionBytes = 40 // sizeof(BITMAPINFOHEADER)
|
||||
)
|
||||
|
||||
// hiddenHandle is a sentinel stored in cursorSampler.lastHandle while
|
||||
// Windows reports the cursor as hidden. It is not a valid HCURSOR value;
|
||||
// real handles never collide with this constant.
|
||||
const hiddenHandle = windows.Handle(^uintptr(0))
|
||||
|
||||
// transparentCursorImage returns a 1x1 fully transparent sprite. The
|
||||
// client renders this as "no cursor"; emitting it explicitly lets us
|
||||
// recover when an app un-hides the cursor a moment later.
|
||||
func transparentCursorImage() *image.RGBA {
|
||||
return image.NewRGBA(image.Rect(0, 0, 1, 1))
|
||||
}
|
||||
|
||||
type winPoint struct {
|
||||
X, Y int32
|
||||
}
|
||||
|
||||
type winCursorInfo struct {
|
||||
Size uint32
|
||||
Flags uint32
|
||||
Cursor windows.Handle
|
||||
PtPos winPoint
|
||||
}
|
||||
|
||||
type winIconInfo struct {
|
||||
FIcon int32
|
||||
XHotspot uint32
|
||||
YHotspot uint32
|
||||
HbmMask windows.Handle
|
||||
HbmColor windows.Handle
|
||||
}
|
||||
|
||||
type winBitmap struct {
|
||||
BmType int32
|
||||
BmWidth int32
|
||||
BmHeight int32
|
||||
BmWidthBytes int32
|
||||
BmPlanes uint16
|
||||
BmBitsPixel uint16
|
||||
BmBits uintptr
|
||||
}
|
||||
|
||||
type winBitmapInfoHeader struct {
|
||||
BiSize uint32
|
||||
BiWidth int32
|
||||
BiHeight int32
|
||||
BiPlanes uint16
|
||||
BiBitCount uint16
|
||||
BiCompression uint32
|
||||
BiSizeImage uint32
|
||||
BiXPelsPerMeter int32
|
||||
BiYPelsPerMeter int32
|
||||
BiClrUsed uint32
|
||||
BiClrImportant uint32
|
||||
}
|
||||
|
||||
// cursorSnapshot is the captured cursor state shared between the worker
|
||||
// (which polls the OS) and the session encoder (which reads it).
|
||||
type cursorSnapshot struct {
|
||||
img *image.RGBA
|
||||
hotX int
|
||||
hotY int
|
||||
posX int
|
||||
posY int
|
||||
hasPos bool
|
||||
serial uint64
|
||||
err error
|
||||
}
|
||||
|
||||
// cursorSampler captures the foreground process's cursor sprite via Win32
|
||||
// APIs. It must be called from a goroutine attached to the same window
|
||||
// station and desktop as the user session (the capture worker does this
|
||||
// via switchToInputDesktop). lastHandle dedupes per-shape work so we only
|
||||
// touch GDI when Windows hands us a new cursor.
|
||||
type cursorSampler struct {
|
||||
lastHandle windows.Handle
|
||||
serial uint64
|
||||
snapshot *cursorSnapshot
|
||||
}
|
||||
|
||||
// sample queries the current cursor and decodes a new sprite when Windows
|
||||
// reports a different HCURSOR than last time. Returns the current snapshot
|
||||
// regardless of whether anything changed; callers diff by serial.
|
||||
func (s *cursorSampler) sample() (*cursorSnapshot, error) {
|
||||
var ci winCursorInfo
|
||||
ci.Size = uint32(unsafe.Sizeof(ci))
|
||||
r, _, err := procGetCursorInfo.Call(uintptr(unsafe.Pointer(&ci)))
|
||||
if r == 0 {
|
||||
return nil, fmt.Errorf("GetCursorInfo: %w", err)
|
||||
}
|
||||
if ci.Flags&cursorShowing == 0 || ci.Cursor == 0 {
|
||||
// Cursor temporarily hidden by an app (text fields toggle it on
|
||||
// focus). Emit a 1x1 transparent sprite so the client renders no
|
||||
// cursor and stay armed for the next handle change rather than
|
||||
// treating this as a hard failure that would latch us off for
|
||||
// the session.
|
||||
if s.lastHandle == hiddenHandle {
|
||||
s.snapshot.posX = int(ci.PtPos.X)
|
||||
s.snapshot.posY = int(ci.PtPos.Y)
|
||||
s.snapshot.hasPos = true
|
||||
return s.snapshot, nil
|
||||
}
|
||||
s.lastHandle = hiddenHandle
|
||||
s.serial++
|
||||
s.snapshot = &cursorSnapshot{
|
||||
img: transparentCursorImage(),
|
||||
posX: int(ci.PtPos.X),
|
||||
posY: int(ci.PtPos.Y),
|
||||
hasPos: true,
|
||||
serial: s.serial,
|
||||
}
|
||||
return s.snapshot, nil
|
||||
}
|
||||
if ci.Cursor == s.lastHandle && s.snapshot != nil {
|
||||
s.snapshot.posX = int(ci.PtPos.X)
|
||||
s.snapshot.posY = int(ci.PtPos.Y)
|
||||
s.snapshot.hasPos = true
|
||||
return s.snapshot, nil
|
||||
}
|
||||
img, hotX, hotY, err := decodeCursor(ci.Cursor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.lastHandle = ci.Cursor
|
||||
s.serial++
|
||||
s.snapshot = &cursorSnapshot{
|
||||
img: img,
|
||||
hotX: hotX,
|
||||
hotY: hotY,
|
||||
posX: int(ci.PtPos.X),
|
||||
posY: int(ci.PtPos.Y),
|
||||
hasPos: true,
|
||||
serial: s.serial,
|
||||
}
|
||||
return s.snapshot, nil
|
||||
}
|
||||
|
||||
// decodeCursor extracts the sprite at hCur as RGBA along with the hotspot.
|
||||
// Color cursors are read from the colour bitmap with the AND mask combined
|
||||
// in for alpha. Monochrome cursors collapse the two halves of the mask
|
||||
// bitmap into a single visible sprite where the AND bit drives alpha.
|
||||
func decodeCursor(hCur windows.Handle) (*image.RGBA, int, int, error) {
|
||||
var info winIconInfo
|
||||
r, _, err := procGetIconInfo.Call(uintptr(hCur), uintptr(unsafe.Pointer(&info)))
|
||||
if r == 0 {
|
||||
return nil, 0, 0, fmt.Errorf("GetIconInfo: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if info.HbmMask != 0 {
|
||||
_, _, _ = procDeleteObject.Call(uintptr(info.HbmMask))
|
||||
}
|
||||
if info.HbmColor != 0 {
|
||||
_, _, _ = procDeleteObject.Call(uintptr(info.HbmColor))
|
||||
}
|
||||
}()
|
||||
hotX, hotY := int(info.XHotspot), int(info.YHotspot)
|
||||
if info.HbmColor != 0 {
|
||||
img, err := decodeColorCursor(info.HbmColor, info.HbmMask)
|
||||
if err != nil {
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
return img, hotX, hotY, nil
|
||||
}
|
||||
img, err := decodeMonoCursor(info.HbmMask)
|
||||
if err != nil {
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
return img, hotX, hotY, nil
|
||||
}
|
||||
|
||||
// readBitmap returns the BITMAP descriptor for hbm.
|
||||
func readBitmap(hbm windows.Handle) (winBitmap, error) {
|
||||
var bm winBitmap
|
||||
r, _, err := procGetObjectW.Call(uintptr(hbm), unsafe.Sizeof(bm), uintptr(unsafe.Pointer(&bm)))
|
||||
if r == 0 {
|
||||
return winBitmap{}, fmt.Errorf("GetObject: %w", err)
|
||||
}
|
||||
return bm, nil
|
||||
}
|
||||
|
||||
// dibCopy reads hbm as 32bpp top-down BGRA into a freshly allocated slice
|
||||
// matching w*h*4 bytes. The bitmap may be selected into the screen DC so
|
||||
// we use a memory DC to keep the call cheap.
|
||||
func dibCopy(hbm windows.Handle, w, h int32) ([]byte, error) {
|
||||
hdcScreen, _, _ := procGetDC.Call(0)
|
||||
if hdcScreen == 0 {
|
||||
return nil, fmt.Errorf("GetDC: failed")
|
||||
}
|
||||
defer func() { _, _, _ = procReleaseDC.Call(0, hdcScreen) }()
|
||||
hdcMem, _, _ := procCreateCompatDC.Call(hdcScreen)
|
||||
if hdcMem == 0 {
|
||||
return nil, fmt.Errorf("CreateCompatibleDC: failed")
|
||||
}
|
||||
defer func() { _, _, _ = procDeleteDC.Call(hdcMem) }()
|
||||
|
||||
var bih winBitmapInfoHeader
|
||||
bih.BiSize = dibSectionBytes
|
||||
bih.BiWidth = w
|
||||
bih.BiHeight = -h // top-down
|
||||
bih.BiPlanes = 1
|
||||
bih.BiBitCount = 32
|
||||
bih.BiCompression = biRgb
|
||||
|
||||
buf := make([]byte, int(w)*int(h)*4)
|
||||
r, _, err := procGetDIBits.Call(
|
||||
hdcMem,
|
||||
uintptr(hbm),
|
||||
0,
|
||||
uintptr(h),
|
||||
uintptr(unsafe.Pointer(&buf[0])),
|
||||
uintptr(unsafe.Pointer(&bih)),
|
||||
diRgbColors,
|
||||
)
|
||||
if r == 0 {
|
||||
return nil, fmt.Errorf("GetDIBits: %w", err)
|
||||
}
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
// decodeColorCursor reads a 32bpp colour cursor and folds the AND mask into
|
||||
// the alpha channel when the colour bitmap leaves it zero.
|
||||
func decodeColorCursor(hbmColor, hbmMask windows.Handle) (*image.RGBA, error) {
|
||||
bm, err := readBitmap(hbmColor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
w, h := bm.BmWidth, bm.BmHeight
|
||||
color, err := dibCopy(hbmColor, w, h)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var mask []byte
|
||||
if hbmMask != 0 {
|
||||
mask, _ = dibCopy(hbmMask, w, h)
|
||||
}
|
||||
hasAlpha := colorHasAlpha(color)
|
||||
img := image.NewRGBA(image.Rect(0, 0, int(w), int(h)))
|
||||
for y := int32(0); y < h; y++ {
|
||||
for x := int32(0); x < w; x++ {
|
||||
si := (y*w + x) * 4
|
||||
b := color[si]
|
||||
g := color[si+1]
|
||||
r := color[si+2]
|
||||
a := pixelAlpha(color[si+3], si, mask, hasAlpha)
|
||||
// Premultiply so the shared compositor can use the same
|
||||
// formula on every platform (X11 XFixes and macOS CG return
|
||||
// premultiplied bytes natively).
|
||||
if a != 255 && a != 0 {
|
||||
r = byte(uint32(r) * uint32(a) / 255)
|
||||
g = byte(uint32(g) * uint32(a) / 255)
|
||||
b = byte(uint32(b) * uint32(a) / 255)
|
||||
} else if a == 0 {
|
||||
r, g, b = 0, 0, 0
|
||||
}
|
||||
img.Pix[si+0] = r
|
||||
img.Pix[si+1] = g
|
||||
img.Pix[si+2] = b
|
||||
img.Pix[si+3] = a
|
||||
}
|
||||
}
|
||||
return img, nil
|
||||
}
|
||||
|
||||
// colorHasAlpha reports whether any pixel of a 32bpp BGRA buffer has a
|
||||
// non-zero alpha. Cursors authored without alpha leave the channel at 0
|
||||
// and rely on hbmMask for transparency.
|
||||
func colorHasAlpha(color []byte) bool {
|
||||
for i := 0; i < len(color); i += 4 {
|
||||
if color[i+3] != 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// pixelAlpha returns the effective alpha for a colour-cursor pixel. When
|
||||
// the source bitmap already has alpha we trust it; otherwise the AND mask
|
||||
// decides (1 = transparent, 0 = opaque). The 32bpp DIB stores each AND
|
||||
// bit as a 4-byte entry; the first byte carries the effective value.
|
||||
func pixelAlpha(colorA byte, si int32, mask []byte, hasAlpha bool) byte {
|
||||
if hasAlpha {
|
||||
return colorA
|
||||
}
|
||||
if mask != nil && mask[si] != 0 {
|
||||
return 0
|
||||
}
|
||||
return 255
|
||||
}
|
||||
|
||||
// decodeMonoCursor handles legacy 1bpp cursors where hbmMask is twice as
|
||||
// tall as the visible sprite: rows [0..h) are the AND mask and rows [h..2h)
|
||||
// are the XOR mask. We render the visible half into RGBA, treating
|
||||
// AND-mask=1 as transparent and the XOR bit as a black/white pixel.
|
||||
func decodeMonoCursor(hbmMask windows.Handle) (*image.RGBA, error) {
|
||||
bm, err := readBitmap(hbmMask)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
w, fullH := bm.BmWidth, bm.BmHeight
|
||||
if fullH%2 != 0 {
|
||||
return nil, fmt.Errorf("unexpected mono cursor shape: %dx%d", w, fullH)
|
||||
}
|
||||
h := fullH / 2
|
||||
data, err := dibCopy(hbmMask, w, fullH)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
img := image.NewRGBA(image.Rect(0, 0, int(w), int(h)))
|
||||
for y := int32(0); y < h; y++ {
|
||||
for x := int32(0); x < w; x++ {
|
||||
and := data[(y*w+x)*4]
|
||||
xor := data[((y+h)*w+x)*4]
|
||||
di := (y*w + x) * 4
|
||||
if and != 0 {
|
||||
img.Pix[di+3] = 0
|
||||
continue
|
||||
}
|
||||
c := byte(0)
|
||||
if xor != 0 {
|
||||
c = 255
|
||||
}
|
||||
img.Pix[di+0] = c
|
||||
img.Pix[di+1] = c
|
||||
img.Pix[di+2] = c
|
||||
img.Pix[di+3] = 255
|
||||
}
|
||||
}
|
||||
return img, nil
|
||||
}
|
||||
|
||||
// cursorState is the latest snapshot shared between the worker and
|
||||
// session readers.
|
||||
type cursorState struct {
|
||||
mu sync.Mutex
|
||||
snapshot *cursorSnapshot
|
||||
}
|
||||
|
||||
func (s *cursorState) store(snap *cursorSnapshot) {
|
||||
s.mu.Lock()
|
||||
s.snapshot = snap
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
func (s *cursorState) load() *cursorSnapshot {
|
||||
s.mu.Lock()
|
||||
snap := s.snapshot
|
||||
s.mu.Unlock()
|
||||
return snap
|
||||
}
|
||||
|
||||
// Cursor satisfies cursorSource by returning the latest snapshot the
|
||||
// capture worker decoded. The "no sample yet" and "cursor hidden" cases
|
||||
// return img=nil with no error so callers skip emission this cycle
|
||||
// without latching the source off for the rest of the session.
|
||||
func (c *DesktopCapturer) Cursor() (*image.RGBA, int, int, uint64, error) {
|
||||
snap := c.cursorState.load()
|
||||
if snap == nil {
|
||||
return nil, 0, 0, 0, nil
|
||||
}
|
||||
if snap.err != nil {
|
||||
return nil, 0, 0, 0, snap.err
|
||||
}
|
||||
return snap.img, snap.hotX, snap.hotY, snap.serial, nil
|
||||
}
|
||||
|
||||
// CursorPos returns the cursor screen position observed by the worker on
|
||||
// its last sample. Errors out if the worker hasn't yet captured a frame
|
||||
// or the most recent sample failed.
|
||||
func (c *DesktopCapturer) CursorPos() (int, int, error) {
|
||||
snap := c.cursorState.load()
|
||||
if snap == nil {
|
||||
return 0, 0, fmt.Errorf("cursor position not sampled yet")
|
||||
}
|
||||
if snap.err != nil {
|
||||
return 0, 0, snap.err
|
||||
}
|
||||
if !snap.hasPos {
|
||||
return 0, 0, fmt.Errorf("cursor position unavailable")
|
||||
}
|
||||
return snap.posX, snap.posY, nil
|
||||
}
|
||||
127
client/vnc/server/cursor_x11.go
Normal file
127
client/vnc/server/cursor_x11.go
Normal file
@@ -0,0 +1,127 @@
|
||||
//go:build unix && !darwin && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
"sync"
|
||||
|
||||
"github.com/jezek/xgb"
|
||||
"github.com/jezek/xgb/xfixes"
|
||||
)
|
||||
|
||||
// xfixesCursor reports the current X cursor sprite via the XFixes extension.
|
||||
// CursorSerial changes whenever the server picks a different cursor, so
|
||||
// callers can cache by serial without comparing pixels.
|
||||
type xfixesCursor struct {
|
||||
mu sync.Mutex
|
||||
conn *xgb.Conn
|
||||
// lastPosX/lastPosY hold the cursor screen position observed on the
|
||||
// most recent successful GetCursorImage. cursorPositionSource readers
|
||||
// share this value so we do not pay a second X round-trip per frame.
|
||||
lastPosX, lastPosY int
|
||||
hasPos bool
|
||||
// lastImg, lastHotX, lastHotY, lastSerial cache the most recent good
|
||||
// GetCursorImage result so transient failures (cursor hidden, server
|
||||
// briefly unresponsive) reuse the previous sprite instead of going
|
||||
// dark. Without this the encoder's compositing path drops to no-op as
|
||||
// soon as the cursor becomes momentarily unavailable.
|
||||
lastImg *image.RGBA
|
||||
lastHotX int
|
||||
lastHotY int
|
||||
lastSerial uint64
|
||||
}
|
||||
|
||||
// newXFixesCursor initialises the XFixes extension on conn. Returns an
|
||||
// error if the extension is unavailable; callers can fall back to no
|
||||
// cursor emission instead of asking on every frame.
|
||||
func newXFixesCursor(conn *xgb.Conn) (*xfixesCursor, error) {
|
||||
if err := xfixes.Init(conn); err != nil {
|
||||
return nil, fmt.Errorf("xfixes init: %w", err)
|
||||
}
|
||||
if _, err := xfixes.QueryVersion(conn, 4, 0).Reply(); err != nil {
|
||||
return nil, fmt.Errorf("xfixes query version: %w", err)
|
||||
}
|
||||
return &xfixesCursor{conn: conn}, nil
|
||||
}
|
||||
|
||||
// Cursor returns the current cursor sprite as RGBA along with its hotspot
|
||||
// and serial. Callers should treat an unchanged serial as "no update". On
|
||||
// a transient GetCursorImage failure the last cached sprite is returned
|
||||
// so compositing keeps painting the cursor instead of disappearing.
|
||||
func (c *xfixesCursor) Cursor() (*image.RGBA, int, int, uint64, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
reply, err := xfixes.GetCursorImage(c.conn).Reply()
|
||||
if err != nil {
|
||||
if c.lastImg != nil {
|
||||
return c.lastImg, c.lastHotX, c.lastHotY, c.lastSerial, nil
|
||||
}
|
||||
return nil, 0, 0, 0, fmt.Errorf("xfixes GetCursorImage: %w", err)
|
||||
}
|
||||
c.lastPosX, c.lastPosY, c.hasPos = int(reply.X), int(reply.Y), true
|
||||
w, h := int(reply.Width), int(reply.Height)
|
||||
if w <= 0 || h <= 0 {
|
||||
if c.lastImg != nil {
|
||||
return c.lastImg, c.lastHotX, c.lastHotY, c.lastSerial, nil
|
||||
}
|
||||
return nil, 0, 0, 0, fmt.Errorf("cursor has zero extent")
|
||||
}
|
||||
if len(reply.CursorImage) < w*h {
|
||||
if c.lastImg != nil {
|
||||
return c.lastImg, c.lastHotX, c.lastHotY, c.lastSerial, nil
|
||||
}
|
||||
return nil, 0, 0, 0, fmt.Errorf("cursor pixel buffer truncated: %d < %d", len(reply.CursorImage), w*h)
|
||||
}
|
||||
img := image.NewRGBA(image.Rect(0, 0, w, h))
|
||||
// XFixes packs each pixel as a uint32 in ARGB order with premultiplied
|
||||
// alpha. Unpack into the standard RGBA byte layout.
|
||||
for i, p := range reply.CursorImage[:w*h] {
|
||||
o := i * 4
|
||||
img.Pix[o+0] = byte(p >> 16)
|
||||
img.Pix[o+1] = byte(p >> 8)
|
||||
img.Pix[o+2] = byte(p)
|
||||
img.Pix[o+3] = byte(p >> 24)
|
||||
}
|
||||
c.lastImg = img
|
||||
c.lastHotX = int(reply.Xhot)
|
||||
c.lastHotY = int(reply.Yhot)
|
||||
c.lastSerial = uint64(reply.CursorSerial)
|
||||
return img, c.lastHotX, c.lastHotY, c.lastSerial, nil
|
||||
}
|
||||
|
||||
// Cursor on X11Capturer satisfies cursorSource. The XFixes binding is
|
||||
// created lazily on the same X connection used for screen capture; the
|
||||
// first init failure is latched so we stop asking on every frame.
|
||||
func (x *X11Capturer) Cursor() (*image.RGBA, int, int, uint64, error) {
|
||||
x.mu.Lock()
|
||||
if x.cursor == nil && x.cursorInitErr == nil {
|
||||
x.cursor, x.cursorInitErr = newXFixesCursor(x.conn)
|
||||
}
|
||||
cur := x.cursor
|
||||
initErr := x.cursorInitErr
|
||||
x.mu.Unlock()
|
||||
if initErr != nil {
|
||||
return nil, 0, 0, 0, initErr
|
||||
}
|
||||
return cur.Cursor()
|
||||
}
|
||||
|
||||
// CursorPos on X11Capturer returns the screen position from the most
|
||||
// recent successful Cursor() call. Sessions call Cursor() once per encode
|
||||
// cycle, so this stays current without a second X round-trip.
|
||||
func (x *X11Capturer) CursorPos() (int, int, error) {
|
||||
x.mu.Lock()
|
||||
cur := x.cursor
|
||||
x.mu.Unlock()
|
||||
if cur == nil {
|
||||
return 0, 0, fmt.Errorf("cursor source not initialised")
|
||||
}
|
||||
cur.mu.Lock()
|
||||
defer cur.mu.Unlock()
|
||||
if !cur.hasPos {
|
||||
return 0, 0, fmt.Errorf("cursor position not sampled yet")
|
||||
}
|
||||
return cur.lastPosX, cur.lastPosY, nil
|
||||
}
|
||||
159
client/vnc/server/extclipboard.go
Normal file
159
client/vnc/server/extclipboard.go
Normal file
@@ -0,0 +1,159 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/zlib"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// ExtendedClipboard is an RFB community extension (pseudo-encoding
|
||||
// 0xC0A1E5CE) that replaces legacy CutText with a Caps/Notify/Request/
|
||||
// Provide/Peek handshake. Wins versus legacy CutText:
|
||||
// - UTF-8 text format (legacy is Latin-1).
|
||||
// - Pull-based: a Notify announces "I have new content", the peer fetches
|
||||
// via Request only when it actually needs the data. Saves bandwidth on
|
||||
// high-latency transports versus pushing every change.
|
||||
// - zlib-compressed payloads.
|
||||
// - Caps negotiation so each side knows the other's per-format max size.
|
||||
//
|
||||
// The extension reuses message opcodes 3 (ServerCutText) and 6 (ClientCutText)
|
||||
// and signals "extended" by encoding the length field as a negative int32;
|
||||
// the absolute value is the payload size in bytes. The first 4 bytes of
|
||||
// payload are a flags word: top byte is the action, low 16 bits are the
|
||||
// format mask.
|
||||
const pseudoEncExtendedClipboard = -1063131698 // 0xC0A1E5CE as int32
|
||||
|
||||
const (
|
||||
extClipActionCaps uint32 = 0x01000000
|
||||
extClipActionRequest uint32 = 0x02000000
|
||||
extClipActionPeek uint32 = 0x04000000
|
||||
extClipActionNotify uint32 = 0x08000000
|
||||
extClipActionProvide uint32 = 0x10000000
|
||||
extClipActionMask uint32 = 0x1F000000
|
||||
|
||||
extClipFormatText uint32 = 0x00000001
|
||||
extClipFormatRTF uint32 = 0x00000002
|
||||
extClipFormatHTML uint32 = 0x00000004
|
||||
extClipFormatDIB uint32 = 0x00000008
|
||||
extClipFormatFiles uint32 = 0x00000010
|
||||
extClipFormatMask uint32 = 0x0000FFFF
|
||||
|
||||
// extClipMaxText caps our accepted text payload. Mirrors the legacy
|
||||
// maxCutTextBytes (1 MiB); advertised in Caps and enforced on Provide.
|
||||
extClipMaxText = maxCutTextBytes
|
||||
|
||||
// extClipMaxPayload bounds the raw on-wire payload we will read for an
|
||||
// extended CutText message. Includes flags header, length prefixes, NUL,
|
||||
// and zlib framing overhead on top of the text body.
|
||||
extClipMaxPayload = extClipMaxText + 1024
|
||||
)
|
||||
|
||||
// buildExtClipCaps emits the Caps payload. The flags word advertises every
|
||||
// action we support in the high byte (Caps + Request + Peek + Notify +
|
||||
// Provide) and every format we accept in the low 16 bits. Clients use
|
||||
// these action bits to decide whether to auto-Request on Notify; without
|
||||
// Request in our Caps a conforming client silently drops our Notify
|
||||
// messages. After the flags word we emit one uint32 max size per format
|
||||
// bit set, in ascending bit order.
|
||||
func buildExtClipCaps() []byte {
|
||||
flags := extClipActionCaps | extClipActionRequest | extClipActionPeek |
|
||||
extClipActionNotify | extClipActionProvide | extClipFormatText
|
||||
payload := make([]byte, 4+4)
|
||||
binary.BigEndian.PutUint32(payload[0:4], flags)
|
||||
binary.BigEndian.PutUint32(payload[4:8], uint32(extClipMaxText))
|
||||
return payload
|
||||
}
|
||||
|
||||
// buildExtClipNotify emits a Notify announcing that we have new clipboard
|
||||
// content available in the given format mask. No data is shipped; the peer
|
||||
// pulls via Request when it actually needs to paste.
|
||||
func buildExtClipNotify(formats uint32) []byte {
|
||||
payload := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(payload, extClipActionNotify|formats)
|
||||
return payload
|
||||
}
|
||||
|
||||
// buildExtClipRequest emits a Request asking the peer to send Provide for
|
||||
// the given format mask. Sent in response to an inbound Notify.
|
||||
func buildExtClipRequest(formats uint32) []byte {
|
||||
payload := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(payload, extClipActionRequest|formats)
|
||||
return payload
|
||||
}
|
||||
|
||||
// buildExtClipProvideText emits a Provide carrying UTF-8 text. The inner
|
||||
// stream (4-byte length including the trailing NUL, then UTF-8 bytes, then
|
||||
// NUL) is zlib-compressed; each Provide uses an independent zlib context
|
||||
// per the extension spec. Rejects oversized input so a caller bug can't
|
||||
// produce a payload larger than the size advertised in our Caps.
|
||||
func buildExtClipProvideText(text string) ([]byte, error) {
|
||||
if len(text)+1 > extClipMaxText {
|
||||
return nil, fmt.Errorf("clipboard text exceeds extClipMaxText (%d > %d)", len(text)+1, extClipMaxText)
|
||||
}
|
||||
body := make([]byte, 0, 4+len(text)+1)
|
||||
var lenBuf [4]byte
|
||||
binary.BigEndian.PutUint32(lenBuf[:], uint32(len(text)+1))
|
||||
body = append(body, lenBuf[:]...)
|
||||
body = append(body, text...)
|
||||
body = append(body, 0)
|
||||
|
||||
var compressed bytes.Buffer
|
||||
zw := zlib.NewWriter(&compressed)
|
||||
if _, err := zw.Write(body); err != nil {
|
||||
return nil, fmt.Errorf("zlib write: %w", err)
|
||||
}
|
||||
if err := zw.Close(); err != nil {
|
||||
return nil, fmt.Errorf("zlib close: %w", err)
|
||||
}
|
||||
|
||||
payload := make([]byte, 4+compressed.Len())
|
||||
binary.BigEndian.PutUint32(payload[0:4], extClipActionProvide|extClipFormatText)
|
||||
copy(payload[4:], compressed.Bytes())
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
// parseExtClipProvideText decompresses a Provide payload (the bytes after
|
||||
// the 4-byte flags header) and returns the UTF-8 text record if the text
|
||||
// format bit is set. Records for other formats are skipped. The trailing
|
||||
// NUL byte the spec appends to text records is stripped.
|
||||
func parseExtClipProvideText(flags uint32, payload []byte) (string, error) {
|
||||
zr, err := zlib.NewReader(bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("zlib reader: %w", err)
|
||||
}
|
||||
defer zr.Close()
|
||||
|
||||
limited := io.LimitReader(zr, int64(extClipMaxText)+16)
|
||||
var text string
|
||||
for bit := uint32(1); bit <= extClipFormatFiles; bit <<= 1 {
|
||||
if flags&bit == 0 {
|
||||
continue
|
||||
}
|
||||
var sizeBuf [4]byte
|
||||
if _, err := io.ReadFull(limited, sizeBuf[:]); err != nil {
|
||||
if bit == extClipFormatText && err == io.EOF {
|
||||
return "", nil
|
||||
}
|
||||
return "", fmt.Errorf("read record size: %w", err)
|
||||
}
|
||||
size := binary.BigEndian.Uint32(sizeBuf[:])
|
||||
if size > uint32(extClipMaxText) {
|
||||
return "", fmt.Errorf("record too large: %d", size)
|
||||
}
|
||||
rec := make([]byte, size)
|
||||
if _, err := io.ReadFull(limited, rec); err != nil {
|
||||
return "", fmt.Errorf("read record: %w", err)
|
||||
}
|
||||
if bit == extClipFormatText {
|
||||
if len(rec) > 0 && rec[len(rec)-1] == 0 {
|
||||
rec = rec[:len(rec)-1]
|
||||
}
|
||||
text = string(rec)
|
||||
}
|
||||
}
|
||||
return text, nil
|
||||
}
|
||||
102
client/vnc/server/extclipboard_test.go
Normal file
102
client/vnc/server/extclipboard_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBuildExtClipCaps(t *testing.T) {
|
||||
payload := buildExtClipCaps()
|
||||
require.Len(t, payload, 8, "Caps with one format should be 4 bytes flags + 4 bytes size")
|
||||
|
||||
flags := binary.BigEndian.Uint32(payload[0:4])
|
||||
// Clients check individual action bits in our Caps to decide whether to
|
||||
// auto-Request on Notify, so all supported actions must be advertised.
|
||||
assert.NotZero(t, flags&extClipActionCaps, "Caps action bit must be set")
|
||||
assert.NotZero(t, flags&extClipActionRequest, "Request action bit must be set")
|
||||
assert.NotZero(t, flags&extClipActionPeek, "Peek action bit must be set")
|
||||
assert.NotZero(t, flags&extClipActionNotify, "Notify action bit must be set")
|
||||
assert.NotZero(t, flags&extClipActionProvide, "Provide action bit must be set")
|
||||
assert.Equal(t, extClipFormatText, flags&extClipFormatMask, "should advertise text format")
|
||||
|
||||
maxSize := binary.BigEndian.Uint32(payload[4:8])
|
||||
assert.Equal(t, uint32(extClipMaxText), maxSize, "should advertise extClipMaxText")
|
||||
}
|
||||
|
||||
func TestBuildExtClipNotify(t *testing.T) {
|
||||
payload := buildExtClipNotify(extClipFormatText)
|
||||
require.Len(t, payload, 4)
|
||||
flags := binary.BigEndian.Uint32(payload)
|
||||
assert.Equal(t, extClipActionNotify, flags&extClipActionMask)
|
||||
assert.Equal(t, extClipFormatText, flags&extClipFormatMask)
|
||||
}
|
||||
|
||||
func TestBuildExtClipRequest(t *testing.T) {
|
||||
payload := buildExtClipRequest(extClipFormatText)
|
||||
require.Len(t, payload, 4)
|
||||
flags := binary.BigEndian.Uint32(payload)
|
||||
assert.Equal(t, extClipActionRequest, flags&extClipActionMask)
|
||||
assert.Equal(t, extClipFormatText, flags&extClipFormatMask)
|
||||
}
|
||||
|
||||
func TestExtClipProvideRoundTripASCII(t *testing.T) {
|
||||
const original = "hello world"
|
||||
payload, err := buildExtClipProvideText(original)
|
||||
require.NoError(t, err)
|
||||
|
||||
flags := binary.BigEndian.Uint32(payload[0:4])
|
||||
require.Equal(t, extClipActionProvide, flags&extClipActionMask)
|
||||
require.Equal(t, extClipFormatText, flags&extClipFormatMask)
|
||||
|
||||
text, err := parseExtClipProvideText(flags, payload[4:])
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, original, text)
|
||||
}
|
||||
|
||||
func TestExtClipProvideRoundTripUTF8(t *testing.T) {
|
||||
original := "héllo 🦀 世界"
|
||||
payload, err := buildExtClipProvideText(original)
|
||||
require.NoError(t, err)
|
||||
|
||||
flags := binary.BigEndian.Uint32(payload[0:4])
|
||||
text, err := parseExtClipProvideText(flags, payload[4:])
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, original, text, "UTF-8 should round-trip without mangling")
|
||||
}
|
||||
|
||||
func TestExtClipProvideRoundTripEmpty(t *testing.T) {
|
||||
payload, err := buildExtClipProvideText("")
|
||||
require.NoError(t, err)
|
||||
|
||||
flags := binary.BigEndian.Uint32(payload[0:4])
|
||||
text, err := parseExtClipProvideText(flags, payload[4:])
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, text)
|
||||
}
|
||||
|
||||
func TestExtClipProvideRoundTripLarge(t *testing.T) {
|
||||
original := strings.Repeat("abcd", 200000) // 800 KiB, below cap
|
||||
payload, err := buildExtClipProvideText(original)
|
||||
require.NoError(t, err)
|
||||
assert.Less(t, len(payload), len(original)/2,
|
||||
"highly repetitive text should compress significantly")
|
||||
|
||||
flags := binary.BigEndian.Uint32(payload[0:4])
|
||||
text, err := parseExtClipProvideText(flags, payload[4:])
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, original, text)
|
||||
}
|
||||
|
||||
func TestParseExtClipProvideTextRejectsOversized(t *testing.T) {
|
||||
var fakePayload [4]byte
|
||||
// 4 bytes of zlib-compressed garbage won't decode; we want to ensure we
|
||||
// don't panic, not that we accept it.
|
||||
_, err := parseExtClipProvideText(extClipActionProvide|extClipFormatText, fakePayload[:])
|
||||
assert.Error(t, err)
|
||||
}
|
||||
865
client/vnc/server/input_darwin.go
Normal file
865
client/vnc/server/input_darwin.go
Normal file
@@ -0,0 +1,865 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Core Graphics event constants.
|
||||
const (
|
||||
kCGEventSourceStateCombinedSessionState int32 = 0
|
||||
|
||||
kCGEventLeftMouseDown int32 = 1
|
||||
kCGEventLeftMouseUp int32 = 2
|
||||
kCGEventRightMouseDown int32 = 3
|
||||
kCGEventRightMouseUp int32 = 4
|
||||
kCGEventMouseMoved int32 = 5
|
||||
kCGEventLeftMouseDragged int32 = 6
|
||||
kCGEventRightMouseDragged int32 = 7
|
||||
kCGEventKeyDown int32 = 10
|
||||
kCGEventKeyUp int32 = 11
|
||||
kCGEventFlagsChanged int32 = 12
|
||||
kCGEventOtherMouseDown int32 = 25
|
||||
kCGEventOtherMouseUp int32 = 26
|
||||
|
||||
kCGMouseButtonLeft int32 = 0
|
||||
kCGMouseButtonRight int32 = 1
|
||||
kCGMouseButtonCenter int32 = 2
|
||||
|
||||
kCGHIDEventTap int32 = 0
|
||||
|
||||
// kCGEventFlagMaskSecondaryFn is the CGEventFlags bit Apple sets when
|
||||
// a key was activated via the Fn modifier on internal keyboards. The
|
||||
// navigation cluster (ForwardDelete, Home, End, PageUp, PageDown,
|
||||
// Help/Insert, arrows) lives in the Fn-shifted region of an Apple
|
||||
// keyboard, so synthesising those keycodes without this bit leaves the
|
||||
// system in a confused "Fn implied" state where the next plain
|
||||
// letter is treated as a menu accelerator.
|
||||
kCGEventFlagMaskSecondaryFn uint64 = 0x00800000
|
||||
|
||||
// kCGMouseEventClickState (event field 1) tells macOS how many
|
||||
// consecutive clicks of this button have happened. Without it, a
|
||||
// double click looks like two independent single clicks and apps
|
||||
// never see the dblclick (window-bar maximize, text word-select, ...).
|
||||
kCGMouseEventClickState int32 = 1
|
||||
|
||||
// doubleClickWindow is the upper bound on the gap between two
|
||||
// down events that still counts as a multi-click. macOS reads the
|
||||
// user's setting from CGEventSourceGetDoubleClickInterval; 500ms is
|
||||
// the default and works as a safe injection-side ceiling.
|
||||
doubleClickWindow = 500 * time.Millisecond
|
||||
|
||||
// IOKit power management constants.
|
||||
kIOPMUserActiveLocal int32 = 0
|
||||
kIOPMAssertionLevelOn uint32 = 255
|
||||
kCFStringEncodingUTF8 uint32 = 0x08000100
|
||||
)
|
||||
|
||||
var darwinInputOnce sync.Once
|
||||
|
||||
var (
|
||||
cgEventSourceCreate func(int32) uintptr
|
||||
cgEventCreateKeyboardEvent func(uintptr, uint16, bool) uintptr
|
||||
// CGEventCreateMouseEvent takes CGPoint as two separate float64 args.
|
||||
// purego can't handle array/struct types but individual float64s work.
|
||||
cgEventCreateMouseEvent func(uintptr, int32, float64, float64, int32) uintptr
|
||||
cgEventPost func(int32, uintptr)
|
||||
cgEventSetIntegerValueField func(uintptr, int32, int64)
|
||||
cgEventSetFlags func(uintptr, uint64)
|
||||
cgEventSetType func(uintptr, int32)
|
||||
cgEventCreateForInput func(uintptr) uintptr
|
||||
|
||||
// CGEventCreateScrollWheelEvent is variadic, call via SyscallN.
|
||||
cgEventCreateScrollWheelEventAddr uintptr
|
||||
|
||||
axIsProcessTrusted func() bool
|
||||
// axIsProcessTrustedWithOptions takes a CFDictionary; when the dict's
|
||||
// kAXTrustedCheckOptionPrompt key is true, macOS shows the native
|
||||
// Accessibility prompt with an "Open System Settings" button the
|
||||
// first time the process asks. The bare AXIsProcessTrusted variant is
|
||||
// a silent check that never prompts.
|
||||
axIsProcessTrustedWithOptions func(uintptr) bool
|
||||
// cfDictionaryCreate builds the options dictionary above.
|
||||
cfDictionaryCreate func(uintptr, *uintptr, *uintptr, int64, uintptr, uintptr) uintptr
|
||||
// cfBooleanTrue is the global CF boolean we cache from a Dlsym lookup.
|
||||
cfBooleanTrue uintptr
|
||||
// axTrustedCheckOptionPromptCFStr is the option key for the dict.
|
||||
axTrustedCheckOptionPromptCFStr uintptr
|
||||
// kCFTypeDictionaryKey/Value CallBacks: standard CF retain/release
|
||||
// callback tables. Required so the dict properly manages refcounts on
|
||||
// the CFString key and CFBoolean value.
|
||||
kCFTypeDictionaryKeyCallBacksAddr uintptr
|
||||
kCFTypeDictionaryValueCallBacksAddr uintptr
|
||||
|
||||
// IOKit power-management bindings used to wake the display and inhibit
|
||||
// idle sleep while a VNC client is driving input.
|
||||
iopmAssertionDeclareUserActivity func(uintptr, int32, *uint32) int32
|
||||
iopmAssertionCreateWithName func(uintptr, uint32, uintptr, *uint32) int32
|
||||
iopmAssertionRelease func(uint32) int32
|
||||
cfStringCreateWithCString func(uintptr, string, uint32) uintptr
|
||||
|
||||
// Cached CFStrings for assertion name and idle-sleep type.
|
||||
pmAssertionNameCFStr uintptr
|
||||
pmPreventIdleDisplayCFStr uintptr
|
||||
|
||||
// Assertion IDs. userActivityID is reused across input events so repeated
|
||||
// calls refresh the same assertion rather than create new ones.
|
||||
pmMu sync.Mutex
|
||||
userActivityID uint32
|
||||
preventSleepID uint32
|
||||
preventSleepHeld bool
|
||||
// preventSleepRef tracks the refcount of held assertions across
|
||||
// concurrent injectors and sessions.
|
||||
preventSleepRef int
|
||||
|
||||
darwinInputReady bool
|
||||
darwinEventSource uintptr
|
||||
)
|
||||
|
||||
func initDarwinInput() {
|
||||
darwinInputOnce.Do(func() {
|
||||
cg, err := purego.Dlopen("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
log.Debugf("load CoreGraphics for input: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
purego.RegisterLibFunc(&cgEventSourceCreate, cg, "CGEventSourceCreate")
|
||||
purego.RegisterLibFunc(&cgEventCreateKeyboardEvent, cg, "CGEventCreateKeyboardEvent")
|
||||
purego.RegisterLibFunc(&cgEventCreateMouseEvent, cg, "CGEventCreateMouseEvent")
|
||||
purego.RegisterLibFunc(&cgEventPost, cg, "CGEventPost")
|
||||
purego.RegisterLibFunc(&cgEventSetIntegerValueField, cg, "CGEventSetIntegerValueField")
|
||||
purego.RegisterLibFunc(&cgEventSetFlags, cg, "CGEventSetFlags")
|
||||
purego.RegisterLibFunc(&cgEventSetType, cg, "CGEventSetType")
|
||||
purego.RegisterLibFunc(&cgEventCreateForInput, cg, "CGEventCreate")
|
||||
|
||||
sym, err := purego.Dlsym(cg, "CGEventCreateScrollWheelEvent")
|
||||
if err == nil {
|
||||
cgEventCreateScrollWheelEventAddr = sym
|
||||
}
|
||||
|
||||
if ax, err := purego.Dlopen("/System/Library/Frameworks/ApplicationServices.framework/ApplicationServices", purego.RTLD_NOW|purego.RTLD_GLOBAL); err == nil {
|
||||
if sym, err := purego.Dlsym(ax, "AXIsProcessTrusted"); err == nil {
|
||||
purego.RegisterFunc(&axIsProcessTrusted, sym)
|
||||
}
|
||||
if sym, err := purego.Dlsym(ax, "AXIsProcessTrustedWithOptions"); err == nil {
|
||||
purego.RegisterFunc(&axIsProcessTrustedWithOptions, sym)
|
||||
}
|
||||
}
|
||||
|
||||
// initPowerAssertions registers cfStringCreateWithCString, which
|
||||
// initCFDictionarySymbols then uses to build the AX prompt key.
|
||||
initPowerAssertions()
|
||||
initCFDictionarySymbols()
|
||||
|
||||
darwinInputReady = true
|
||||
})
|
||||
}
|
||||
|
||||
// initCFDictionarySymbols loads the CF symbols needed to build the
|
||||
// options dictionary for AXIsProcessTrustedWithOptions. Best-effort:
|
||||
// failure here just leaves axIsProcessTrustedWithOptions unusable and we
|
||||
// fall back to the silent check.
|
||||
func initCFDictionarySymbols() {
|
||||
cf, err := purego.Dlopen("/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
log.Debugf("load CoreFoundation for AX prompt dict: %v", err)
|
||||
return
|
||||
}
|
||||
if sym, err := purego.Dlsym(cf, "CFDictionaryCreate"); err == nil {
|
||||
purego.RegisterFunc(&cfDictionaryCreate, sym)
|
||||
}
|
||||
if sym, err := purego.Dlsym(cf, "kCFTypeDictionaryKeyCallBacks"); err == nil {
|
||||
kCFTypeDictionaryKeyCallBacksAddr = sym
|
||||
}
|
||||
if sym, err := purego.Dlsym(cf, "kCFTypeDictionaryValueCallBacks"); err == nil {
|
||||
kCFTypeDictionaryValueCallBacksAddr = sym
|
||||
}
|
||||
if sym, err := purego.Dlsym(cf, "kCFBooleanTrue"); err == nil {
|
||||
// kCFBooleanTrue is a pointer-to-pointer (CFBooleanRef stored at the
|
||||
// symbol address). Dereference once to get the actual CFBoolean.
|
||||
cfBooleanTrue = *(*uintptr)(unsafe.Pointer(sym))
|
||||
}
|
||||
if cfStringCreateWithCString != nil {
|
||||
axTrustedCheckOptionPromptCFStr = cfStringCreateWithCString(0, "AXTrustedCheckOptionPrompt", kCFStringEncodingUTF8)
|
||||
}
|
||||
}
|
||||
|
||||
func initPowerAssertions() {
|
||||
iokit, err := purego.Dlopen("/System/Library/Frameworks/IOKit.framework/IOKit", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
log.Debugf("load IOKit: %v", err)
|
||||
return
|
||||
}
|
||||
cf, err := purego.Dlopen("/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
log.Debugf("load CoreFoundation for power assertions: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
purego.RegisterLibFunc(&cfStringCreateWithCString, cf, "CFStringCreateWithCString")
|
||||
purego.RegisterLibFunc(&iopmAssertionDeclareUserActivity, iokit, "IOPMAssertionDeclareUserActivity")
|
||||
purego.RegisterLibFunc(&iopmAssertionCreateWithName, iokit, "IOPMAssertionCreateWithName")
|
||||
purego.RegisterLibFunc(&iopmAssertionRelease, iokit, "IOPMAssertionRelease")
|
||||
|
||||
pmAssertionNameCFStr = cfStringCreateWithCString(0, "NetBird VNC input", kCFStringEncodingUTF8)
|
||||
pmPreventIdleDisplayCFStr = cfStringCreateWithCString(0, "PreventUserIdleDisplaySleep", kCFStringEncodingUTF8)
|
||||
}
|
||||
|
||||
// wakeDisplay declares user activity so macOS treats the synthesized input as
|
||||
// real HID activity, waking the display if it is asleep. Called on every key
|
||||
// and pointer event; the kernel coalesces repeated calls cheaply.
|
||||
func wakeDisplay() {
|
||||
if iopmAssertionDeclareUserActivity == nil || pmAssertionNameCFStr == 0 {
|
||||
return
|
||||
}
|
||||
pmMu.Lock()
|
||||
defer pmMu.Unlock()
|
||||
id := userActivityID
|
||||
r := iopmAssertionDeclareUserActivity(pmAssertionNameCFStr, kIOPMUserActiveLocal, &id)
|
||||
if r != 0 {
|
||||
log.Tracef("IOPMAssertionDeclareUserActivity returned %d", r)
|
||||
return
|
||||
}
|
||||
userActivityID = id
|
||||
}
|
||||
|
||||
// holdPreventIdleSleep creates an assertion that keeps the display from going
|
||||
// idle-to-sleep while a VNC session is active. Reference-counted so multiple
|
||||
// concurrent sessions don't yank the assertion when one of them releases.
|
||||
func holdPreventIdleSleep() {
|
||||
if iopmAssertionCreateWithName == nil || pmPreventIdleDisplayCFStr == 0 || pmAssertionNameCFStr == 0 {
|
||||
return
|
||||
}
|
||||
pmMu.Lock()
|
||||
defer pmMu.Unlock()
|
||||
preventSleepRef++
|
||||
if preventSleepRef > 1 {
|
||||
return
|
||||
}
|
||||
var id uint32
|
||||
r := iopmAssertionCreateWithName(pmPreventIdleDisplayCFStr, kIOPMAssertionLevelOn, pmAssertionNameCFStr, &id)
|
||||
if r != 0 {
|
||||
log.Debugf("IOPMAssertionCreateWithName returned %d", r)
|
||||
// Reset the refcount on failure so a later successful hold can take it.
|
||||
preventSleepRef = 0
|
||||
return
|
||||
}
|
||||
preventSleepID = id
|
||||
preventSleepHeld = true
|
||||
}
|
||||
|
||||
// releasePreventIdleSleep decrements the assertion refcount and only drops
|
||||
// the actual IOKit assertion on the final release.
|
||||
func releasePreventIdleSleep() {
|
||||
if iopmAssertionRelease == nil {
|
||||
return
|
||||
}
|
||||
pmMu.Lock()
|
||||
defer pmMu.Unlock()
|
||||
if !preventSleepHeld || preventSleepRef == 0 {
|
||||
return
|
||||
}
|
||||
preventSleepRef--
|
||||
if preventSleepRef > 0 {
|
||||
return
|
||||
}
|
||||
if r := iopmAssertionRelease(preventSleepID); r != 0 {
|
||||
log.Debugf("IOPMAssertionRelease returned %d", r)
|
||||
}
|
||||
preventSleepHeld = false
|
||||
preventSleepID = 0
|
||||
}
|
||||
|
||||
func ensureEventSource() uintptr {
|
||||
if darwinEventSource != 0 {
|
||||
return darwinEventSource
|
||||
}
|
||||
darwinEventSource = cgEventSourceCreate(kCGEventSourceStateCombinedSessionState)
|
||||
return darwinEventSource
|
||||
}
|
||||
|
||||
// MacInputInjector injects keyboard and mouse events via Core Graphics.
|
||||
type MacInputInjector struct {
|
||||
lastButtons uint16
|
||||
pbcopyPath string
|
||||
pbpastePath string
|
||||
// clickCount[i] / clickAt[i] track the multi-click sequence for
|
||||
// button i (0=left, 1=right, 2=middle). macOS apps reconstruct
|
||||
// double/triple click semantics from the kCGMouseEventClickState
|
||||
// field on each posted event, not from event timing.
|
||||
clickCount [5]int64
|
||||
clickAt [5]time.Time
|
||||
}
|
||||
|
||||
// NewMacInputInjector creates a macOS input injector.
|
||||
func NewMacInputInjector() (*MacInputInjector, error) {
|
||||
initDarwinInput()
|
||||
if !darwinInputReady {
|
||||
return nil, fmt.Errorf("CoreGraphics not available for input injection")
|
||||
}
|
||||
checkMacPermissions()
|
||||
|
||||
m := &MacInputInjector{}
|
||||
if path, err := exec.LookPath("pbcopy"); err == nil {
|
||||
m.pbcopyPath = path
|
||||
}
|
||||
if path, err := exec.LookPath("pbpaste"); err == nil {
|
||||
m.pbpastePath = path
|
||||
}
|
||||
if m.pbcopyPath == "" || m.pbpastePath == "" {
|
||||
log.Debugf("clipboard tools not found (pbcopy=%q, pbpaste=%q)", m.pbcopyPath, m.pbpastePath)
|
||||
}
|
||||
|
||||
holdPreventIdleSleep()
|
||||
|
||||
log.Info("macOS input injector ready")
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// checkMacPermissions probes Accessibility access. Prefers the prompting
|
||||
// variant of AXIsProcessTrusted: when the process is not yet trusted,
|
||||
// macOS shows its native "would like to control your computer" dialog
|
||||
// with an "Open System Settings" button. The silent variant is the
|
||||
// fallback when the prompting symbol or its CF dictionary plumbing
|
||||
// couldn't be loaded.
|
||||
func checkMacPermissions() {
|
||||
if !axProcessIsTrusted() {
|
||||
log.Warn("Accessibility permission not granted. Input injection will not work. " +
|
||||
"Approve the prompt or grant in System Settings > Privacy & Security > Accessibility.")
|
||||
openPrivacyPane("Privacy_Accessibility")
|
||||
}
|
||||
}
|
||||
|
||||
// axProcessIsTrusted asks macOS whether netbird has Accessibility access,
|
||||
// and triggers the native prompt the first time when not trusted. Returns
|
||||
// the current trust status either way.
|
||||
func axProcessIsTrusted() bool {
|
||||
if axIsProcessTrustedWithOptions != nil &&
|
||||
cfDictionaryCreate != nil &&
|
||||
axTrustedCheckOptionPromptCFStr != 0 &&
|
||||
cfBooleanTrue != 0 &&
|
||||
kCFTypeDictionaryKeyCallBacksAddr != 0 &&
|
||||
kCFTypeDictionaryValueCallBacksAddr != 0 {
|
||||
keys := [1]uintptr{axTrustedCheckOptionPromptCFStr}
|
||||
values := [1]uintptr{cfBooleanTrue}
|
||||
dict := cfDictionaryCreate(0, &keys[0], &values[0], 1,
|
||||
kCFTypeDictionaryKeyCallBacksAddr,
|
||||
kCFTypeDictionaryValueCallBacksAddr)
|
||||
if dict != 0 {
|
||||
return axIsProcessTrustedWithOptions(dict)
|
||||
}
|
||||
}
|
||||
if axIsProcessTrusted != nil {
|
||||
return axIsProcessTrusted()
|
||||
}
|
||||
// Symbol load failed entirely. Assume trusted so we don't spam the
|
||||
// log every cycle; capture/inject calls will report concrete errors
|
||||
// if access really is missing.
|
||||
return true
|
||||
}
|
||||
|
||||
// openPrivacyPane opens the relevant pane of System Settings so the user
|
||||
// can toggle the permission without navigating manually. The
|
||||
// x-apple.systempreferences URL scheme works on every macOS release from
|
||||
// 10.10 onward; the per-pane anchor (Privacy_Accessibility, Privacy_ScreenCapture)
|
||||
// is what System Settings/Preferences uses to land on the right row.
|
||||
func openPrivacyPane(pane string) {
|
||||
url := "x-apple.systempreferences:com.apple.preference.security?" + pane
|
||||
if err := exec.Command("open", url).Start(); err != nil {
|
||||
log.Debugf("open privacy pane %s: %v", pane, err)
|
||||
}
|
||||
}
|
||||
|
||||
// InjectKey simulates a key press or release.
|
||||
func (m *MacInputInjector) InjectKey(keysym uint32, down bool) {
|
||||
wakeDisplay()
|
||||
src := ensureEventSource()
|
||||
if src == 0 {
|
||||
return
|
||||
}
|
||||
keycode := keysymToMacKeycode(keysym)
|
||||
if keycode == 0xFFFF {
|
||||
return
|
||||
}
|
||||
m.postMacKey(src, keycode, down)
|
||||
}
|
||||
|
||||
// InjectKeyScancode injects using the QEMU scancode, mapped via the
|
||||
// qemuToMacVK table to Apple's virtual-keycode space. Apple uses an
|
||||
// entirely different scheme from PC AT scancodes, so the table is the
|
||||
// authoritative bridge. On miss we fall back to the keysym path.
|
||||
func (m *MacInputInjector) InjectKeyScancode(scancode, keysym uint32, down bool) {
|
||||
wakeDisplay()
|
||||
src := ensureEventSource()
|
||||
if src == 0 {
|
||||
return
|
||||
}
|
||||
vk, ok := qemuToMacVK[scancode]
|
||||
if !ok {
|
||||
// Fall back to the keysym path so unmapped keys still work.
|
||||
m.InjectKey(keysym, down)
|
||||
return
|
||||
}
|
||||
m.postMacKey(src, vk, down)
|
||||
}
|
||||
|
||||
// postMacKey emits a single key down/up event via Core Graphics. For
|
||||
// keycodes that live in the Fn-shifted region of an Apple keyboard we
|
||||
// also emit explicit flagsChanged events around the keypress: posting
|
||||
// the Fn flag on the key event alone leaves macOS's modifier state
|
||||
// machine without a matching transition, which manifests as "Fn stays
|
||||
// active" for the next key (e.g. the next letter activates a menu
|
||||
// accelerator).
|
||||
func (m *MacInputInjector) postMacKey(src uintptr, keycode uint16, down bool) {
|
||||
fnShifted := isFnShiftedKeycode(keycode)
|
||||
if fnShifted && down {
|
||||
postFnFlagsChanged(src, true)
|
||||
}
|
||||
event := cgEventCreateKeyboardEvent(src, keycode, down)
|
||||
if event == 0 {
|
||||
if fnShifted && !down {
|
||||
postFnFlagsChanged(src, false)
|
||||
}
|
||||
return
|
||||
}
|
||||
if fnShifted && cgEventSetFlags != nil {
|
||||
cgEventSetFlags(event, kCGEventFlagMaskSecondaryFn)
|
||||
}
|
||||
cgEventPost(kCGHIDEventTap, event)
|
||||
cfRelease(event)
|
||||
if fnShifted && !down {
|
||||
postFnFlagsChanged(src, false)
|
||||
}
|
||||
}
|
||||
|
||||
// postFnFlagsChanged emits a synthetic Fn modifier transition so the
|
||||
// system updates its global modifier state to match the key events we
|
||||
// post for the navigation cluster. Without this, posting a Fn-flagged
|
||||
// key event leaves macOS thinking Fn is still held after the key is
|
||||
// released.
|
||||
func postFnFlagsChanged(src uintptr, fnOn bool) {
|
||||
if cgEventCreateForInput == nil || cgEventSetType == nil || cgEventSetFlags == nil {
|
||||
return
|
||||
}
|
||||
event := cgEventCreateForInput(src)
|
||||
if event == 0 {
|
||||
return
|
||||
}
|
||||
cgEventSetType(event, kCGEventFlagsChanged)
|
||||
var flags uint64
|
||||
if fnOn {
|
||||
flags = kCGEventFlagMaskSecondaryFn
|
||||
}
|
||||
cgEventSetFlags(event, flags)
|
||||
cgEventPost(kCGHIDEventTap, event)
|
||||
cfRelease(event)
|
||||
}
|
||||
|
||||
// fnShiftedKeycodes are the Apple navigation/edit keys that hardware produces
|
||||
// with the Fn modifier held.
|
||||
var fnShiftedKeycodes = map[uint16]struct{}{
|
||||
0x72: {}, // Help / Insert
|
||||
0x73: {}, // Home
|
||||
0x74: {}, // PageUp
|
||||
0x75: {}, // ForwardDelete
|
||||
0x77: {}, // End
|
||||
0x79: {}, // PageDown
|
||||
0x7B: {}, // Left
|
||||
0x7C: {}, // Right
|
||||
0x7D: {}, // Down
|
||||
0x7E: {}, // Up
|
||||
}
|
||||
|
||||
// isFnShiftedKeycode reports whether keycode is one of the Apple
|
||||
// navigation/edit keys that hardware produces with the Fn modifier held.
|
||||
func isFnShiftedKeycode(keycode uint16) bool {
|
||||
_, ok := fnShiftedKeycodes[keycode]
|
||||
return ok
|
||||
}
|
||||
|
||||
// InjectPointer simulates mouse movement and button events.
|
||||
func (m *MacInputInjector) InjectPointer(buttonMask uint16, px, py, serverW, serverH int) {
|
||||
wakeDisplay()
|
||||
if serverW == 0 || serverH == 0 {
|
||||
return
|
||||
}
|
||||
src := ensureEventSource()
|
||||
if src == 0 {
|
||||
return
|
||||
}
|
||||
x, y := scalePxToLogical(px, py, serverW, serverH)
|
||||
m.dispatchPointer(src, buttonMask, x, y)
|
||||
m.lastButtons = buttonMask
|
||||
}
|
||||
|
||||
// scalePxToLogical converts framebuffer coordinates (physical pixels) into
|
||||
// the logical points CGEventCreateMouseEvent expects. Falls back to a 1:1
|
||||
// mapping if the display API is unavailable.
|
||||
func scalePxToLogical(px, py, serverW, serverH int) (float64, float64) {
|
||||
x, y := float64(px), float64(py)
|
||||
if cgDisplayPixelsWide == nil || cgMainDisplayID == nil {
|
||||
return x, y
|
||||
}
|
||||
displayID := cgMainDisplayID()
|
||||
logicalW := int(cgDisplayPixelsWide(displayID))
|
||||
logicalH := int(cgDisplayPixelsHigh(displayID))
|
||||
if logicalW <= 0 || logicalH <= 0 {
|
||||
return x, y
|
||||
}
|
||||
return float64(px) * float64(logicalW) / float64(serverW),
|
||||
float64(py) * float64(logicalH) / float64(serverH)
|
||||
}
|
||||
|
||||
func (m *MacInputInjector) dispatchPointer(src uintptr, buttonMask uint16, x, y float64) {
|
||||
leftDown := buttonMask&0x01 != 0
|
||||
rightDown := buttonMask&0x04 != 0
|
||||
middleDown := buttonMask&0x02 != 0
|
||||
m.postMoveOrDrag(src, leftDown, rightDown, x, y)
|
||||
m.postButtonTransitions(src, buttonMask, x, y)
|
||||
m.postScrollWheel(src, buttonMask)
|
||||
_ = middleDown
|
||||
}
|
||||
|
||||
func (m *MacInputInjector) postMoveOrDrag(src uintptr, leftDown, rightDown bool, x, y float64) {
|
||||
switch {
|
||||
case leftDown:
|
||||
m.postMouse(src, kCGEventLeftMouseDragged, x, y, kCGMouseButtonLeft)
|
||||
case rightDown:
|
||||
m.postMouse(src, kCGEventRightMouseDragged, x, y, kCGMouseButtonRight)
|
||||
default:
|
||||
m.postMouse(src, kCGEventMouseMoved, x, y, kCGMouseButtonLeft)
|
||||
}
|
||||
}
|
||||
|
||||
// postButtonTransitions emits the up/down events for each button whose
|
||||
// state changed against m.lastButtons, computing the click count so
|
||||
// macOS recognises double / triple clicks.
|
||||
func (m *MacInputInjector) postButtonTransitions(src uintptr, buttonMask uint16, x, y float64) {
|
||||
emit := func(curBit, prevBit uint16, down, up int32, button int32, idx int) {
|
||||
cur := buttonMask&curBit != 0
|
||||
prev := m.lastButtons&prevBit != 0
|
||||
if cur && !prev {
|
||||
now := time.Now()
|
||||
if !m.clickAt[idx].IsZero() && now.Sub(m.clickAt[idx]) <= doubleClickWindow {
|
||||
m.clickCount[idx]++
|
||||
} else {
|
||||
m.clickCount[idx] = 1
|
||||
}
|
||||
m.clickAt[idx] = now
|
||||
m.postMouseClick(src, down, x, y, button, m.clickCount[idx])
|
||||
} else if !cur && prev {
|
||||
count := m.clickCount[idx]
|
||||
if count == 0 {
|
||||
count = 1
|
||||
}
|
||||
m.postMouseClick(src, up, x, y, button, count)
|
||||
}
|
||||
}
|
||||
emit(0x01, 0x01, kCGEventLeftMouseDown, kCGEventLeftMouseUp, kCGMouseButtonLeft, 0)
|
||||
emit(0x04, 0x04, kCGEventRightMouseDown, kCGEventRightMouseUp, kCGMouseButtonRight, 1)
|
||||
emit(0x02, 0x02, kCGEventOtherMouseDown, kCGEventOtherMouseUp, kCGMouseButtonCenter, 2)
|
||||
// CG mouse-button numbers 3 (back) and 4 (forward) are emitted as
|
||||
// "other" events; macOS apps that swallow Browser nav (Finder, web
|
||||
// views) react to these directly.
|
||||
emit(1<<7, 1<<7, kCGEventOtherMouseDown, kCGEventOtherMouseUp, 3, 3)
|
||||
emit(1<<8, 1<<8, kCGEventOtherMouseDown, kCGEventOtherMouseUp, 4, 4)
|
||||
}
|
||||
|
||||
func (m *MacInputInjector) postScrollWheel(src uintptr, buttonMask uint16) {
|
||||
if buttonMask&0x08 != 0 {
|
||||
m.postScroll(src, scrollPixelsPerWheelTick)
|
||||
}
|
||||
if buttonMask&0x10 != 0 {
|
||||
m.postScroll(src, -scrollPixelsPerWheelTick)
|
||||
}
|
||||
}
|
||||
|
||||
// scrollPixelsPerWheelTick is the pixel delta we post for one VNC wheel
|
||||
// button event. Browser-based RFB clients typically emit one press+release
|
||||
// per ~10 px of host wheel/trackpad motion, so a real gesture arrives as
|
||||
// many small events; ~20 px per event keeps the resulting macOS scroll
|
||||
// fluid without overshooting on a single notch.
|
||||
const scrollPixelsPerWheelTick int32 = 22
|
||||
|
||||
func (m *MacInputInjector) postMouse(src uintptr, eventType int32, x, y float64, button int32) {
|
||||
if cgEventCreateMouseEvent == nil {
|
||||
return
|
||||
}
|
||||
event := cgEventCreateMouseEvent(src, eventType, x, y, button)
|
||||
if event == 0 {
|
||||
return
|
||||
}
|
||||
cgEventPost(kCGHIDEventTap, event)
|
||||
cfRelease(event)
|
||||
}
|
||||
|
||||
// postMouseClick stamps the click count on the event before posting it.
|
||||
// Without this stamp macOS treats every press as a fresh single click.
|
||||
func (m *MacInputInjector) postMouseClick(src uintptr, eventType int32, x, y float64, button int32, clickCount int64) {
|
||||
if cgEventCreateMouseEvent == nil {
|
||||
return
|
||||
}
|
||||
event := cgEventCreateMouseEvent(src, eventType, x, y, button)
|
||||
if event == 0 {
|
||||
return
|
||||
}
|
||||
if cgEventSetIntegerValueField != nil && clickCount > 1 {
|
||||
cgEventSetIntegerValueField(event, kCGMouseEventClickState, clickCount)
|
||||
}
|
||||
cgEventPost(kCGHIDEventTap, event)
|
||||
cfRelease(event)
|
||||
}
|
||||
|
||||
func (m *MacInputInjector) postScroll(src uintptr, deltaY int32) {
|
||||
if cgEventCreateScrollWheelEventAddr == 0 {
|
||||
return
|
||||
}
|
||||
// CGEventCreateScrollWheelEvent(source, units, wheelCount, wheel1delta).
|
||||
// Pixel units (0) feel smoother given the small per-event deltas typical
|
||||
// of RFB wheel events than line units (1) where each event jumps a
|
||||
// whole line. Variadic C function, pass via SyscallN.
|
||||
r1, _, _ := purego.SyscallN(cgEventCreateScrollWheelEventAddr,
|
||||
src, 0, 1, uintptr(uint32(deltaY)))
|
||||
if r1 == 0 {
|
||||
return
|
||||
}
|
||||
cgEventPost(kCGHIDEventTap, r1)
|
||||
cfRelease(r1)
|
||||
}
|
||||
|
||||
// SetClipboard sets the macOS clipboard using pbcopy.
|
||||
func (m *MacInputInjector) SetClipboard(text string) {
|
||||
if m.pbcopyPath == "" {
|
||||
return
|
||||
}
|
||||
cmd := exec.Command(m.pbcopyPath)
|
||||
cmd.Stdin = strings.NewReader(text)
|
||||
if err := cmd.Run(); err != nil {
|
||||
log.Tracef("set clipboard via pbcopy: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TypeText synthesizes the given text as keystrokes via Core Graphics.
|
||||
// Lets a client push host clipboard content to the focused remote app
|
||||
// even when the app doesn't honor pbpaste-style clipboard sync (e.g.
|
||||
// login screens, locked-down apps). ASCII printable runes only; others
|
||||
// are skipped.
|
||||
func (m *MacInputInjector) TypeText(text string) {
|
||||
wakeDisplay()
|
||||
src := ensureEventSource()
|
||||
if src == 0 {
|
||||
return
|
||||
}
|
||||
const maxChars = 4096
|
||||
count := 0
|
||||
for _, r := range text {
|
||||
if count >= maxChars {
|
||||
break
|
||||
}
|
||||
count++
|
||||
typeRune(src, r)
|
||||
}
|
||||
}
|
||||
|
||||
// typeRune emits the press/release events for a single ASCII rune, framing
|
||||
// the keystroke with Shift-down/up when required by the keysym.
|
||||
func typeRune(src uintptr, r rune) {
|
||||
const shiftKey = uint16(0x38) // kVK_Shift
|
||||
keysym, shift, ok := keysymForASCIIRune(r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
keycode := keysymToMacKeycode(keysym)
|
||||
if keycode == 0xFFFF {
|
||||
return
|
||||
}
|
||||
if shift {
|
||||
postKey(src, shiftKey, true)
|
||||
}
|
||||
postKey(src, keycode, true)
|
||||
postKey(src, keycode, false)
|
||||
if shift {
|
||||
postKey(src, shiftKey, false)
|
||||
}
|
||||
}
|
||||
|
||||
func postKey(src uintptr, keycode uint16, down bool) {
|
||||
e := cgEventCreateKeyboardEvent(src, keycode, down)
|
||||
if e == 0 {
|
||||
return
|
||||
}
|
||||
cgEventPost(kCGHIDEventTap, e)
|
||||
cfRelease(e)
|
||||
}
|
||||
|
||||
// GetClipboard reads the macOS clipboard using pbpaste.
|
||||
func (m *MacInputInjector) GetClipboard() string {
|
||||
if m.pbpastePath == "" {
|
||||
return ""
|
||||
}
|
||||
out, err := exec.Command(m.pbpastePath).Output()
|
||||
if err != nil {
|
||||
// pbpaste exits 1 when the pasteboard has no string flavour.
|
||||
return ""
|
||||
}
|
||||
return string(out)
|
||||
}
|
||||
|
||||
// Close releases the idle-sleep assertion held for the injector's lifetime.
|
||||
func (m *MacInputInjector) Close() {
|
||||
releasePreventIdleSleep()
|
||||
}
|
||||
|
||||
func keysymToMacKeycode(keysym uint32) uint16 {
|
||||
if keysym >= 0x61 && keysym <= 0x7a {
|
||||
return asciiToMacKey[keysym-0x61]
|
||||
}
|
||||
if keysym >= 0x41 && keysym <= 0x5a {
|
||||
return asciiToMacKey[keysym-0x41]
|
||||
}
|
||||
if keysym >= 0x30 && keysym <= 0x39 {
|
||||
return digitToMacKey[keysym-0x30]
|
||||
}
|
||||
if code, ok := specialKeyMap[keysym]; ok {
|
||||
return code
|
||||
}
|
||||
return 0xFFFF
|
||||
}
|
||||
|
||||
var asciiToMacKey = [26]uint16{
|
||||
0x00, 0x0B, 0x08, 0x02, 0x0E, 0x03, 0x05, 0x04,
|
||||
0x22, 0x26, 0x28, 0x25, 0x2E, 0x2D, 0x1F, 0x23,
|
||||
0x0C, 0x0F, 0x01, 0x11, 0x20, 0x09, 0x0D, 0x07,
|
||||
0x10, 0x06,
|
||||
}
|
||||
|
||||
var digitToMacKey = [10]uint16{
|
||||
0x1D, 0x12, 0x13, 0x14, 0x15, 0x17, 0x16, 0x1A, 0x1C, 0x19,
|
||||
}
|
||||
|
||||
var specialKeyMap = map[uint32]uint16{
|
||||
// Whitespace and editing
|
||||
0x0020: 0x31, // space
|
||||
0xff08: 0x33, // BackSpace
|
||||
0xff09: 0x30, // Tab
|
||||
0xff0d: 0x24, // Return
|
||||
0xff1b: 0x35, // Escape
|
||||
0xffff: 0x75, // Delete (forward)
|
||||
|
||||
// Navigation
|
||||
0xff50: 0x73, // Home
|
||||
0xff51: 0x7B, // Left
|
||||
0xff52: 0x7E, // Up
|
||||
0xff53: 0x7C, // Right
|
||||
0xff54: 0x7D, // Down
|
||||
0xff55: 0x74, // Page_Up
|
||||
0xff56: 0x79, // Page_Down
|
||||
0xff57: 0x77, // End
|
||||
0xff63: 0x72, // Insert (Help on Mac)
|
||||
|
||||
// Modifiers
|
||||
0xffe1: 0x38, // Shift_L
|
||||
0xffe2: 0x3C, // Shift_R
|
||||
0xffe3: 0x3B, // Control_L
|
||||
0xffe4: 0x3E, // Control_R
|
||||
0xffe5: 0x39, // Caps_Lock
|
||||
0xffe9: 0x3A, // Alt_L (Option)
|
||||
0xffea: 0x3D, // Alt_R (Option)
|
||||
0xffe7: 0x37, // Meta_L (Command)
|
||||
0xffe8: 0x36, // Meta_R (Command)
|
||||
0xffeb: 0x37, // Super_L (Command)
|
||||
0xffec: 0x36, // Super_R (Command)
|
||||
|
||||
// Mode_switch / ISO_Level3_Shift (for macOS Option remap on layouts)
|
||||
0xff7e: 0x3A, // Mode_switch -> Option
|
||||
0xfe03: 0x3D, // ISO_Level3_Shift -> Right Option
|
||||
|
||||
// Function keys
|
||||
0xffbe: 0x7A, // F1
|
||||
0xffbf: 0x78, // F2
|
||||
0xffc0: 0x63, // F3
|
||||
0xffc1: 0x76, // F4
|
||||
0xffc2: 0x60, // F5
|
||||
0xffc3: 0x61, // F6
|
||||
0xffc4: 0x62, // F7
|
||||
0xffc5: 0x64, // F8
|
||||
0xffc6: 0x65, // F9
|
||||
0xffc7: 0x6D, // F10
|
||||
0xffc8: 0x67, // F11
|
||||
0xffc9: 0x6F, // F12
|
||||
0xffca: 0x69, // F13
|
||||
0xffcb: 0x6B, // F14
|
||||
0xffcc: 0x71, // F15
|
||||
0xffcd: 0x6A, // F16
|
||||
0xffce: 0x40, // F17
|
||||
0xffcf: 0x4F, // F18
|
||||
0xffd0: 0x50, // F19
|
||||
0xffd1: 0x5A, // F20
|
||||
|
||||
// Punctuation (US keyboard layout, keysym = ASCII code)
|
||||
0x002d: 0x1B, // minus -
|
||||
0x003d: 0x18, // equal =
|
||||
0x005b: 0x21, // bracketleft [
|
||||
0x005d: 0x1E, // bracketright ]
|
||||
0x005c: 0x2A, // backslash
|
||||
0x003b: 0x29, // semicolon ;
|
||||
0x0027: 0x27, // apostrophe '
|
||||
0x0060: 0x32, // grave `
|
||||
0x002c: 0x2B, // comma ,
|
||||
0x002e: 0x2F, // period .
|
||||
0x002f: 0x2C, // slash /
|
||||
|
||||
// Shifted punctuation (clients sometimes send these as separate keysyms)
|
||||
0x005f: 0x1B, // underscore _ (shift+minus)
|
||||
0x002b: 0x18, // plus + (shift+equal)
|
||||
0x007b: 0x21, // braceleft { (shift+[)
|
||||
0x007d: 0x1E, // braceright } (shift+])
|
||||
0x007c: 0x2A, // bar | (shift+\)
|
||||
0x003a: 0x29, // colon : (shift+;)
|
||||
0x0022: 0x27, // quotedbl " (shift+')
|
||||
0x007e: 0x32, // tilde ~ (shift+`)
|
||||
0x003c: 0x2B, // less < (shift+,)
|
||||
0x003e: 0x2F, // greater > (shift+.)
|
||||
0x003f: 0x2C, // question ? (shift+/)
|
||||
0x0021: 0x12, // exclam ! (shift+1)
|
||||
0x0040: 0x13, // at @ (shift+2)
|
||||
0x0023: 0x14, // numbersign # (shift+3)
|
||||
0x0024: 0x15, // dollar $ (shift+4)
|
||||
0x0025: 0x17, // percent % (shift+5)
|
||||
0x005e: 0x16, // asciicircum ^ (shift+6)
|
||||
0x0026: 0x1A, // ampersand & (shift+7)
|
||||
0x002a: 0x1C, // asterisk * (shift+8)
|
||||
0x0028: 0x19, // parenleft ( (shift+9)
|
||||
0x0029: 0x1D, // parenright ) (shift+0)
|
||||
|
||||
// Numpad
|
||||
0xffb0: 0x52, // KP_0
|
||||
0xffb1: 0x53, // KP_1
|
||||
0xffb2: 0x54, // KP_2
|
||||
0xffb3: 0x55, // KP_3
|
||||
0xffb4: 0x56, // KP_4
|
||||
0xffb5: 0x57, // KP_5
|
||||
0xffb6: 0x58, // KP_6
|
||||
0xffb7: 0x59, // KP_7
|
||||
0xffb8: 0x5B, // KP_8
|
||||
0xffb9: 0x5C, // KP_9
|
||||
0xffae: 0x41, // KP_Decimal
|
||||
0xffaa: 0x43, // KP_Multiply
|
||||
0xffab: 0x45, // KP_Add
|
||||
0xffad: 0x4E, // KP_Subtract
|
||||
0xffaf: 0x4B, // KP_Divide
|
||||
0xff8d: 0x4C, // KP_Enter
|
||||
0xffbd: 0x51, // KP_Equal
|
||||
}
|
||||
|
||||
var _ InputInjector = (*MacInputInjector)(nil)
|
||||
17
client/vnc/server/input_uinput_freebsd.go
Normal file
17
client/vnc/server/input_uinput_freebsd.go
Normal file
@@ -0,0 +1,17 @@
|
||||
//go:build freebsd
|
||||
|
||||
package server
|
||||
|
||||
import "fmt"
|
||||
|
||||
// UInputInjector is a freebsd placeholder; the linux uinput implementation
|
||||
// uses Linux-only ioctls (UI_DEV_CREATE etc.) and is not portable.
|
||||
type UInputInjector struct {
|
||||
StubInputInjector
|
||||
}
|
||||
|
||||
// NewUInputInjector always returns an error on freebsd so callers fall back
|
||||
// to a stub or platform-appropriate injector.
|
||||
func NewUInputInjector(_, _ int) (*UInputInjector, error) {
|
||||
return nil, fmt.Errorf("uinput not implemented on freebsd")
|
||||
}
|
||||
488
client/vnc/server/input_uinput_linux.go
Normal file
488
client/vnc/server/input_uinput_linux.go
Normal file
@@ -0,0 +1,488 @@
|
||||
//go:build linux
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
"unicode"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// /dev/uinput ioctl numbers. Computed from the kernel _IO/_IOW macros so
|
||||
// we don't depend on cgo. UINPUT_IOCTL_BASE = 'U' = 0x55.
|
||||
const (
|
||||
uiDevCreate = 0x5501
|
||||
uiDevDestroy = 0x5502
|
||||
// _IOW('U', 3, struct uinput_setup); uinput_setup is 92 bytes on amd64.
|
||||
uiDevSetup = (1 << 30) | (92 << 16) | (0x55 << 8) | 3
|
||||
uiSetEvBit = (1 << 30) | (4 << 16) | (0x55 << 8) | 100
|
||||
uiSetKeyBit = (1 << 30) | (4 << 16) | (0x55 << 8) | 101
|
||||
uiSetAbsBit = (1 << 30) | (4 << 16) | (0x55 << 8) | 103
|
||||
uinputAbsSize = 64 // legacy struct uses absmin/absmax/absfuzz/absflat[64].
|
||||
)
|
||||
|
||||
// Linux input event types and key codes (linux/input-event-codes.h).
|
||||
const (
|
||||
evSyn = 0x00
|
||||
evKey = 0x01
|
||||
evAbs = 0x03
|
||||
evRep = 0x14
|
||||
|
||||
synReport = 0
|
||||
|
||||
absX = 0x00
|
||||
absY = 0x01
|
||||
|
||||
btnLeft = 0x110
|
||||
btnRight = 0x111
|
||||
btnMiddle = 0x112
|
||||
btnSide = 0x113 // mouse-back (X1)
|
||||
btnExtra = 0x114 // mouse-forward (X2)
|
||||
)
|
||||
|
||||
// inputEvent matches struct input_event for x86_64 (timeval is 16 bytes).
|
||||
// Total size 24 bytes; Go's natural alignment matches the kernel layout.
|
||||
type inputEvent struct {
|
||||
TvSec int64
|
||||
TvUsec int64
|
||||
Type uint16
|
||||
Code uint16
|
||||
Value int32
|
||||
}
|
||||
|
||||
// UInputInjector synthesizes keyboard and mouse events via /dev/uinput.
|
||||
// Used as a fallback when X11 isn't running, e.g. at the kernel console
|
||||
// or pre-login screen on a server without X. Requires root or
|
||||
// CAP_SYS_ADMIN, which the netbird service has.
|
||||
type UInputInjector struct {
|
||||
mu sync.Mutex
|
||||
fd int
|
||||
closeOnce sync.Once
|
||||
keysymToKey map[uint32]uint16
|
||||
prevButtons uint16
|
||||
screenW int
|
||||
screenH int
|
||||
}
|
||||
|
||||
// NewUInputInjector opens /dev/uinput and registers a virtual keyboard +
|
||||
// absolute pointer device sized to (w, h). The dimensions are needed
|
||||
// because uinput's ABS axes don't autoscale; we always send absolute
|
||||
// coordinates and let the kernel route them to the right monitor.
|
||||
func NewUInputInjector(w, h int) (*UInputInjector, error) {
|
||||
if w <= 0 || h <= 0 {
|
||||
return nil, fmt.Errorf("invalid screen size: %dx%d", w, h)
|
||||
}
|
||||
fd, err := unix.Open("/dev/uinput", unix.O_WRONLY|unix.O_NONBLOCK, 0)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open /dev/uinput: %w", err)
|
||||
}
|
||||
|
||||
if err := setBit(fd, uiSetEvBit, evKey); err != nil {
|
||||
unix.Close(fd)
|
||||
return nil, err
|
||||
}
|
||||
if err := setBit(fd, uiSetEvBit, evAbs); err != nil {
|
||||
unix.Close(fd)
|
||||
return nil, err
|
||||
}
|
||||
if err := setBit(fd, uiSetEvBit, evSyn); err != nil {
|
||||
unix.Close(fd)
|
||||
return nil, err
|
||||
}
|
||||
// Advertise key auto-repeat so the kernel input core repeats held
|
||||
// keys at the configured rate (default ~250 ms delay, ~33 ms period).
|
||||
// Without this, holding Backspace etc. only deletes one character.
|
||||
if err := setBit(fd, uiSetEvBit, evRep); err != nil {
|
||||
unix.Close(fd)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keymap := buildUInputKeymap()
|
||||
for _, key := range keymap {
|
||||
if err := setBit(fd, uiSetKeyBit, uint32(key)); err != nil {
|
||||
unix.Close(fd)
|
||||
return nil, fmt.Errorf("UI_SET_KEYBIT %d: %w", key, err)
|
||||
}
|
||||
}
|
||||
for _, btn := range []uint16{btnLeft, btnRight, btnMiddle, btnSide, btnExtra} {
|
||||
if err := setBit(fd, uiSetKeyBit, uint32(btn)); err != nil {
|
||||
unix.Close(fd)
|
||||
return nil, fmt.Errorf("UI_SET_KEYBIT btn %d: %w", btn, err)
|
||||
}
|
||||
}
|
||||
if err := setBit(fd, uiSetAbsBit, absX); err != nil {
|
||||
unix.Close(fd)
|
||||
return nil, err
|
||||
}
|
||||
if err := setBit(fd, uiSetAbsBit, absY); err != nil {
|
||||
unix.Close(fd)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := writeUInputUserDev(fd, w, h); err != nil {
|
||||
unix.Close(fd)
|
||||
return nil, err
|
||||
}
|
||||
if _, _, e := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uiDevCreate, 0); e != 0 {
|
||||
unix.Close(fd)
|
||||
return nil, fmt.Errorf("UI_DEV_CREATE: %v", e)
|
||||
}
|
||||
// Give udev a moment to settle before sending events.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
inj := &UInputInjector{
|
||||
fd: fd,
|
||||
keysymToKey: keymapByKeysym(keymap),
|
||||
screenW: w,
|
||||
screenH: h,
|
||||
}
|
||||
log.Infof("uinput injector ready: %dx%d, %d keys", w, h, len(inj.keysymToKey))
|
||||
return inj, nil
|
||||
}
|
||||
|
||||
func setBit(fd int, op uintptr, code uint32) error {
|
||||
if _, _, e := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), op, uintptr(code)); e != 0 {
|
||||
return fmt.Errorf("ioctl 0x%x %d: %v", op, code, e)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeUInputUserDev uses the legacy uinput_user_dev path (write the
|
||||
// whole struct then UI_DEV_CREATE) which is universally supported on
|
||||
// older and current kernels alike. uinput_user_dev is name(80) + id(8) +
|
||||
// ff_effects_max(4) + absmax/absmin/absfuzz/absflat[64] = 92 + 4*64*4 =
|
||||
// 1116 bytes total.
|
||||
func writeUInputUserDev(fd, w, h int) error {
|
||||
const sz = 80 + 8 + 4 + uinputAbsSize*4*4
|
||||
buf := make([]byte, sz)
|
||||
copy(buf[0:80], []byte("netbird-vnc-uinput"))
|
||||
// id: BUS_VIRTUAL=0x06, vendor=0x0001, product=0x0001, version=1.
|
||||
binary.LittleEndian.PutUint16(buf[80:82], 0x06)
|
||||
binary.LittleEndian.PutUint16(buf[82:84], 0x0001)
|
||||
binary.LittleEndian.PutUint16(buf[84:86], 0x0001)
|
||||
binary.LittleEndian.PutUint16(buf[86:88], 0x0001)
|
||||
// ff_effects_max(4) at 88..92 stays zero.
|
||||
// absmax[64] at 92..348: set absX/absY.
|
||||
absmaxOff := 80 + 8 + 4
|
||||
absminOff := absmaxOff + uinputAbsSize*4
|
||||
binary.LittleEndian.PutUint32(buf[absmaxOff+absX*4:], uint32(w-1))
|
||||
binary.LittleEndian.PutUint32(buf[absmaxOff+absY*4:], uint32(h-1))
|
||||
binary.LittleEndian.PutUint32(buf[absminOff+absX*4:], 0)
|
||||
binary.LittleEndian.PutUint32(buf[absminOff+absY*4:], 0)
|
||||
if _, err := unix.Write(fd, buf); err != nil {
|
||||
return fmt.Errorf("write uinput_user_dev: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// emit writes a single input_event to the device. Caller-locked.
|
||||
func (u *UInputInjector) emit(typ, code uint16, value int32) error {
|
||||
ev := inputEvent{Type: typ, Code: code, Value: value}
|
||||
buf := (*[unsafe.Sizeof(inputEvent{})]byte)(unsafe.Pointer(&ev))[:]
|
||||
_, err := unix.Write(u.fd, buf)
|
||||
return err
|
||||
}
|
||||
|
||||
func (u *UInputInjector) sync() {
|
||||
_ = u.emit(evSyn, synReport, 0)
|
||||
}
|
||||
|
||||
// InjectKey synthesizes a press or release for the given X11 keysym.
|
||||
func (u *UInputInjector) InjectKey(keysym uint32, down bool) {
|
||||
u.mu.Lock()
|
||||
defer u.mu.Unlock()
|
||||
code, ok := u.keysymToKey[keysym]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
u.emitKeyCode(code, down)
|
||||
}
|
||||
|
||||
// InjectKeyScancode injects a press or release using the QEMU scancode.
|
||||
// uinput speaks Linux KEY_* codes natively, so we map QEMU scancode →
|
||||
// KEY_* via qemuToLinuxKey. On miss (scancode we don't have a mapping
|
||||
// for) we fall back to the keysym path, which is exactly the legacy
|
||||
// behaviour.
|
||||
func (u *UInputInjector) InjectKeyScancode(scancode, keysym uint32, down bool) {
|
||||
code := qemuScancodeToLinuxKey(scancode)
|
||||
if code == 0 {
|
||||
u.InjectKey(keysym, down)
|
||||
return
|
||||
}
|
||||
u.mu.Lock()
|
||||
defer u.mu.Unlock()
|
||||
u.emitKeyCode(uint16(code), down)
|
||||
}
|
||||
|
||||
// emitKeyCode emits one key down/up event plus a sync. Caller holds u.mu.
|
||||
func (u *UInputInjector) emitKeyCode(code uint16, down bool) {
|
||||
value := int32(0)
|
||||
if down {
|
||||
value = 1
|
||||
}
|
||||
if err := u.emit(evKey, code, value); err != nil {
|
||||
log.Tracef("uinput emit key: %v", err)
|
||||
return
|
||||
}
|
||||
u.sync()
|
||||
}
|
||||
|
||||
// InjectPointer moves the absolute pointer and presses/releases buttons
|
||||
// based on the RFB button mask delta against the previous mask.
|
||||
func (u *UInputInjector) InjectPointer(buttonMask uint16, x, y, serverW, serverH int) {
|
||||
u.mu.Lock()
|
||||
defer u.mu.Unlock()
|
||||
if serverW <= 1 || serverH <= 1 {
|
||||
return
|
||||
}
|
||||
absXVal := int32(x * (u.screenW - 1) / (serverW - 1))
|
||||
absYVal := int32(y * (u.screenH - 1) / (serverH - 1))
|
||||
_ = u.emit(evAbs, absX, absXVal)
|
||||
_ = u.emit(evAbs, absY, absYVal)
|
||||
|
||||
type btnMap struct {
|
||||
bit uint16
|
||||
key uint16
|
||||
}
|
||||
for _, b := range []btnMap{
|
||||
{0x01, btnLeft},
|
||||
{0x02, btnMiddle},
|
||||
{0x04, btnRight},
|
||||
{1 << 7, btnSide},
|
||||
{1 << 8, btnExtra},
|
||||
} {
|
||||
pressed := buttonMask&b.bit != 0
|
||||
was := u.prevButtons&b.bit != 0
|
||||
if pressed && !was {
|
||||
_ = u.emit(evKey, b.key, 1)
|
||||
} else if !pressed && was {
|
||||
_ = u.emit(evKey, b.key, 0)
|
||||
}
|
||||
}
|
||||
u.prevButtons = buttonMask
|
||||
u.sync()
|
||||
}
|
||||
|
||||
// SetClipboard is a no-op on the framebuffer console: there is no system
|
||||
// clipboard daemon. Use TypeText (Paste button) to deliver host text.
|
||||
func (u *UInputInjector) SetClipboard(_ string) {
|
||||
// no system clipboard daemon on framebuffer console
|
||||
}
|
||||
|
||||
// GetClipboard returns empty: no clipboard outside X11/Wayland.
|
||||
func (u *UInputInjector) GetClipboard() string { return "" }
|
||||
|
||||
// TypeText synthesizes the given UTF-8 text as keystrokes. Only ASCII
|
||||
// printable characters and newline are typed; other runes are skipped.
|
||||
// This drives the "paste" button: with no console clipboard available,
|
||||
// keystroke-by-keystroke entry is the only way to deliver a password to
|
||||
// a TTY login prompt.
|
||||
func (u *UInputInjector) TypeText(text string) {
|
||||
u.mu.Lock()
|
||||
defer u.mu.Unlock()
|
||||
const maxChars = 4096
|
||||
count := 0
|
||||
for _, r := range text {
|
||||
if count >= maxChars {
|
||||
break
|
||||
}
|
||||
count++
|
||||
code, shift, ok := keyForRune(r)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if shift {
|
||||
_ = u.emit(evKey, keyLeftShift, 1)
|
||||
}
|
||||
_ = u.emit(evKey, code, 1)
|
||||
_ = u.emit(evKey, code, 0)
|
||||
if shift {
|
||||
_ = u.emit(evKey, keyLeftShift, 0)
|
||||
}
|
||||
u.sync()
|
||||
}
|
||||
}
|
||||
|
||||
// Close destroys the virtual uinput device and closes the file descriptor.
|
||||
func (u *UInputInjector) Close() {
|
||||
u.closeOnce.Do(func() {
|
||||
u.mu.Lock()
|
||||
defer u.mu.Unlock()
|
||||
if u.fd >= 0 {
|
||||
_, _, _ = unix.Syscall(unix.SYS_IOCTL, uintptr(u.fd), uiDevDestroy, 0)
|
||||
_ = unix.Close(u.fd)
|
||||
u.fd = -1
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Linux KEY_* codes live in scancodes.go (shared with the QEMU scancode
|
||||
// path). Don't duplicate them here.
|
||||
|
||||
// buildUInputKeymap returns every linux KEY_ code we want the virtual
|
||||
// device to advertise during UI_SET_KEYBIT. Order doesn't matter.
|
||||
func buildUInputKeymap() []uint16 {
|
||||
out := make([]uint16, 0, 128)
|
||||
// Letters: KEY_A=30, KEY_B=48, etc; not a clean range. The kernel's
|
||||
// row-by-row layout is qwertyuiop / asdfghjkl / zxcvbnm.
|
||||
letters := []uint16{
|
||||
30, 48, 46, 32, 18, 33, 34, 35, 23, 36, 37, 38, 50, // a..m
|
||||
49, 24, 25, 16, 19, 31, 20, 22, 47, 17, 45, 21, 44, // n..z
|
||||
}
|
||||
out = append(out, letters...)
|
||||
// Top-row digits: KEY_1..KEY_0 = 2..11.
|
||||
for i := uint16(2); i <= 11; i++ {
|
||||
out = append(out, i)
|
||||
}
|
||||
// Function keys F1..F12 = 59..68 + 87, 88. We only register F1..F12
|
||||
// which the kernel header enumerates as a contiguous block.
|
||||
for i := uint16(59); i <= 68; i++ {
|
||||
out = append(out, i)
|
||||
}
|
||||
out = append(out, 87, 88)
|
||||
out = append(out, []uint16{
|
||||
keyEsc, keyMinus, keyEqual, keyBackspace, keyTab, keyEnter,
|
||||
keyLeftCtrl, keyRightCtrl, keyLeftShift, keyRightShift,
|
||||
keyLeftAlt, keyRightAlt, keyLeftMeta, keyRightMeta,
|
||||
keySpace, keyCapsLock,
|
||||
keyLeftBracket, keyRightBracket, keyBackslash,
|
||||
keySemicolon, keyApostrophe, keyGrave,
|
||||
keyComma, keyDot, keySlash,
|
||||
keyHome, keyEnd, keyPageUp, keyPageDown,
|
||||
keyUp, keyDown, keyLeft, keyRight,
|
||||
keyInsert, keyDelete,
|
||||
}...)
|
||||
return out
|
||||
}
|
||||
|
||||
// keymapByKeysym maps X11 keysyms (the values our session receives over
|
||||
// RFB) onto Linux KEY_ codes. Shifted ASCII keysyms (uppercase letters,
|
||||
// "!@#..." etc.) map to the same scan code as their unshifted twin: the
|
||||
// client also sends a separate Shift keysym (0xffe1), so the kernel
|
||||
// composes the final character from the held modifier + scan code.
|
||||
func keymapByKeysym(_ []uint16) map[uint32]uint16 {
|
||||
letters := map[rune]uint16{
|
||||
'a': 30, 'b': 48, 'c': 46, 'd': 32, 'e': 18, 'f': 33, 'g': 34,
|
||||
'h': 35, 'i': 23, 'j': 36, 'k': 37, 'l': 38, 'm': 50,
|
||||
'n': 49, 'o': 24, 'p': 25, 'q': 16, 'r': 19, 's': 31, 't': 20,
|
||||
'u': 22, 'v': 47, 'w': 17, 'x': 45, 'y': 21, 'z': 44,
|
||||
}
|
||||
m := map[uint32]uint16{
|
||||
// Digits.
|
||||
'0': 11, '1': 2, '2': 3, '3': 4, '4': 5, '5': 6, '6': 7,
|
||||
'7': 8, '8': 9, '9': 10,
|
||||
// Shifted digits (US layout).
|
||||
')': 11, '!': 2, '@': 3, '#': 4, '$': 5, '%': 6, '^': 7,
|
||||
'&': 8, '*': 9, '(': 10,
|
||||
// Punctuation (US layout) and shifted twins.
|
||||
' ': keySpace,
|
||||
'-': keyMinus, '_': keyMinus,
|
||||
'=': keyEqual, '+': keyEqual,
|
||||
'[': keyLeftBracket, '{': keyLeftBracket,
|
||||
']': keyRightBracket, '}': keyRightBracket,
|
||||
'\\': keyBackslash, '|': keyBackslash,
|
||||
';': keySemicolon, ':': keySemicolon,
|
||||
'\'': keyApostrophe, '"': keyApostrophe,
|
||||
'`': keyGrave, '~': keyGrave,
|
||||
',': keyComma, '<': keyComma,
|
||||
'.': keyDot, '>': keyDot,
|
||||
'/': keySlash, '?': keySlash,
|
||||
// Special keys (X11 keysyms).
|
||||
0xff08: keyBackspace, 0xff09: keyTab, 0xff0d: keyEnter,
|
||||
0xff1b: keyEsc, 0xffff: keyDelete,
|
||||
0xff50: keyHome, 0xff57: keyEnd,
|
||||
0xff51: keyLeft, 0xff52: keyUp, 0xff53: keyRight, 0xff54: keyDown,
|
||||
0xff55: keyPageUp, 0xff56: keyPageDown, 0xff63: keyInsert,
|
||||
0xffe1: keyLeftShift, 0xffe2: keyRightShift,
|
||||
0xffe3: keyLeftCtrl, 0xffe4: keyRightCtrl,
|
||||
0xffe9: keyLeftAlt, 0xffea: keyRightAlt,
|
||||
0xffeb: keyLeftMeta, 0xffec: keyRightMeta,
|
||||
}
|
||||
// Letters: register both lowercase and uppercase keysyms onto the same
|
||||
// KEY_ code. The client sends Shift separately for uppercase.
|
||||
for r, code := range letters {
|
||||
m[uint32(r)] = code
|
||||
m[uint32(r-'a'+'A')] = code
|
||||
}
|
||||
// Function keys F1..F12 (X11 keysyms 0xffbe..0xffc9 → KEY_F1..KEY_F12).
|
||||
xF := uint32(0xffbe)
|
||||
codes := []uint16{59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 87, 88}
|
||||
for i, c := range codes {
|
||||
m[xF+uint32(i)] = c
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// keyForRune maps a printable rune to (keycode, needsShift). Used by
|
||||
// TypeText to synthesize keystrokes for a paste payload.
|
||||
func keyForRune(r rune) (uint16, bool, bool) {
|
||||
if r >= 'a' && r <= 'z' {
|
||||
m := map[rune]uint16{
|
||||
'a': 30, 'b': 48, 'c': 46, 'd': 32, 'e': 18, 'f': 33, 'g': 34,
|
||||
'h': 35, 'i': 23, 'j': 36, 'k': 37, 'l': 38, 'm': 50,
|
||||
'n': 49, 'o': 24, 'p': 25, 'q': 16, 'r': 19, 's': 31, 't': 20,
|
||||
'u': 22, 'v': 47, 'w': 17, 'x': 45, 'y': 21, 'z': 44,
|
||||
}
|
||||
return m[r], false, true
|
||||
}
|
||||
if r >= 'A' && r <= 'Z' {
|
||||
c, _, ok := keyForRune(unicode.ToLower(r))
|
||||
return c, true, ok
|
||||
}
|
||||
if r >= '0' && r <= '9' {
|
||||
nums := []uint16{11, 2, 3, 4, 5, 6, 7, 8, 9, 10}
|
||||
idx := int(r - '0')
|
||||
if idx < 0 || idx >= len(nums) { //nolint:gosec // explicit bound disarms G602
|
||||
return 0, false, false
|
||||
}
|
||||
return nums[idx], false, true
|
||||
}
|
||||
if r == '\n' || r == '\r' {
|
||||
return keyEnter, false, true
|
||||
}
|
||||
if k, ok := punctUnshifted[r]; ok {
|
||||
return k, false, true
|
||||
}
|
||||
if k, ok := punctShifted[r]; ok {
|
||||
return k, true, true
|
||||
}
|
||||
return 0, false, false
|
||||
}
|
||||
|
||||
// punctUnshifted maps ASCII punctuation that needs no Shift to its uinput
|
||||
// KEY_* code. Split out of keyForRune's switch to keep the function's
|
||||
// cognitive complexity below Sonar's threshold.
|
||||
var punctUnshifted = map[rune]uint16{
|
||||
' ': keySpace,
|
||||
'\t': keyTab,
|
||||
'-': keyMinus,
|
||||
'=': keyEqual,
|
||||
'[': keyLeftBracket,
|
||||
']': keyRightBracket,
|
||||
'\\': keyBackslash,
|
||||
';': keySemicolon,
|
||||
'\'': keyApostrophe,
|
||||
'`': keyGrave,
|
||||
',': keyComma,
|
||||
'.': keyDot,
|
||||
'/': keySlash,
|
||||
}
|
||||
|
||||
// punctShifted maps ASCII punctuation that requires Shift to its base KEY_*
|
||||
// code; the caller adds the shift modifier itself.
|
||||
var punctShifted = map[rune]uint16{
|
||||
'!': 2, '@': 3, '#': 4, '$': 5, '%': 6, '^': 7, '&': 8, '*': 9,
|
||||
'(': 10, ')': 11,
|
||||
'_': keyMinus, '+': keyEqual,
|
||||
'{': keyLeftBracket, '}': keyRightBracket, '|': keyBackslash,
|
||||
':': keySemicolon, '"': keyApostrophe, '~': keyGrave,
|
||||
'<': keyComma, '>': keyDot, '?': keySlash,
|
||||
}
|
||||
|
||||
var _ InputInjector = (*UInputInjector)(nil)
|
||||
599
client/vnc/server/input_windows.go
Normal file
599
client/vnc/server/input_windows.go
Normal file
@@ -0,0 +1,599 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var (
|
||||
procOpenEventW = kernel32.NewProc("OpenEventW")
|
||||
procSendInput = user32.NewProc("SendInput")
|
||||
procVkKeyScanA = user32.NewProc("VkKeyScanA")
|
||||
)
|
||||
|
||||
const eventModifyState = 0x0002
|
||||
|
||||
const (
|
||||
inputMouse = 0
|
||||
inputKeyboard = 1
|
||||
|
||||
mouseeventfMove = 0x0001
|
||||
mouseeventfLeftDown = 0x0002
|
||||
mouseeventfLeftUp = 0x0004
|
||||
mouseeventfRightDown = 0x0008
|
||||
mouseeventfRightUp = 0x0010
|
||||
mouseeventfMiddleDown = 0x0020
|
||||
mouseeventfMiddleUp = 0x0040
|
||||
mouseeventfXDown = 0x0080
|
||||
mouseeventfXUp = 0x0100
|
||||
mouseeventfWheel = 0x0800
|
||||
mouseeventfAbsolute = 0x8000
|
||||
|
||||
// X-button identifiers carried in the dwData field of MOUSEEVENTF_X*
|
||||
// events. XBUTTON1 is mouse-back, XBUTTON2 is mouse-forward.
|
||||
xButton1 = 0x0001
|
||||
xButton2 = 0x0002
|
||||
|
||||
wheelDelta = 120
|
||||
|
||||
keyeventfExtendedKey = 0x0001
|
||||
keyeventfKeyUp = 0x0002
|
||||
keyeventfUnicode = 0x0004
|
||||
keyeventfScanCode = 0x0008
|
||||
)
|
||||
|
||||
// maxTypedClipboardChars caps the number of characters we will synthesize as
|
||||
// keystrokes when falling back on the Winlogon desktop. Passwords are short;
|
||||
// a huge clipboard getting typed into the login screen would be surprising.
|
||||
const maxTypedClipboardChars = 4096
|
||||
|
||||
type mouseInput struct {
|
||||
Dx int32
|
||||
Dy int32
|
||||
MouseData uint32
|
||||
DwFlags uint32
|
||||
Time uint32
|
||||
DwExtraInfo uintptr
|
||||
}
|
||||
|
||||
type keybdInput struct {
|
||||
WVk uint16
|
||||
WScan uint16
|
||||
DwFlags uint32
|
||||
Time uint32
|
||||
DwExtraInfo uintptr
|
||||
_ [8]byte
|
||||
}
|
||||
|
||||
type inputUnion [32]byte
|
||||
|
||||
type winInput struct {
|
||||
Type uint32
|
||||
_ [4]byte
|
||||
Data inputUnion
|
||||
}
|
||||
|
||||
func sendMouseInput(flags uint32, dx, dy int32, mouseData uint32) {
|
||||
mi := mouseInput{
|
||||
Dx: dx,
|
||||
Dy: dy,
|
||||
MouseData: mouseData,
|
||||
DwFlags: flags,
|
||||
}
|
||||
inp := winInput{Type: inputMouse}
|
||||
copy(inp.Data[:], (*[unsafe.Sizeof(mi)]byte)(unsafe.Pointer(&mi))[:])
|
||||
r, _, err := procSendInput.Call(1, uintptr(unsafe.Pointer(&inp)), unsafe.Sizeof(inp))
|
||||
if r == 0 {
|
||||
log.Tracef("SendInput(mouse flags=0x%x): %v", flags, err)
|
||||
}
|
||||
}
|
||||
|
||||
func sendKeyInput(vk uint16, scanCode uint16, flags uint32) {
|
||||
ki := keybdInput{
|
||||
WVk: vk,
|
||||
WScan: scanCode,
|
||||
DwFlags: flags,
|
||||
}
|
||||
inp := winInput{Type: inputKeyboard}
|
||||
copy(inp.Data[:], (*[unsafe.Sizeof(ki)]byte)(unsafe.Pointer(&ki))[:])
|
||||
r, _, err := procSendInput.Call(1, uintptr(unsafe.Pointer(&inp)), unsafe.Sizeof(inp))
|
||||
if r == 0 {
|
||||
log.Tracef("SendInput(key vk=0x%x): %v", vk, err)
|
||||
}
|
||||
}
|
||||
|
||||
const sasEventName = `Global\NetBirdVNC_SAS`
|
||||
|
||||
type inputCmd struct {
|
||||
isKey bool
|
||||
isScancode bool
|
||||
isClipboard bool
|
||||
isType bool
|
||||
keysym uint32
|
||||
scancode uint32
|
||||
down bool
|
||||
buttonMask uint16
|
||||
x, y int
|
||||
serverW int
|
||||
serverH int
|
||||
clipText string
|
||||
}
|
||||
|
||||
// WindowsInputInjector delivers input events from a dedicated OS thread that
|
||||
// calls switchToInputDesktop before each injection. SendInput targets the
|
||||
// calling thread's desktop, so the injection thread must be on the same
|
||||
// desktop the user sees.
|
||||
type WindowsInputInjector struct {
|
||||
ch chan inputCmd
|
||||
closed chan struct{}
|
||||
closeOnce sync.Once
|
||||
prevButtonMask uint16
|
||||
// lastQueuedButtonMask is the most recent buttonMask submitted to ch
|
||||
// by InjectPointer. Compared against the incoming sample to decide
|
||||
// whether the new event is move-only (lossy enqueue) or carries a
|
||||
// button/wheel transition (reliable enqueue).
|
||||
lastQueuedButtonMask uint16
|
||||
lastQueuedMaskValid bool
|
||||
queueMu sync.Mutex
|
||||
ctrlDown bool
|
||||
altDown bool
|
||||
}
|
||||
|
||||
// NewWindowsInputInjector creates a desktop-aware input injector.
|
||||
func NewWindowsInputInjector() *WindowsInputInjector {
|
||||
w := &WindowsInputInjector{
|
||||
ch: make(chan inputCmd, 64),
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
go w.loop()
|
||||
return w
|
||||
}
|
||||
|
||||
// Close stops the injector loop. Safe to call multiple times. Subsequent
|
||||
// Inject*/SetClipboard/TypeText calls become no-ops; we use a separate
|
||||
// signal channel rather than closing ch so late senders can't panic.
|
||||
func (w *WindowsInputInjector) Close() {
|
||||
w.closeOnce.Do(func() {
|
||||
close(w.closed)
|
||||
})
|
||||
}
|
||||
|
||||
// tryEnqueue posts a command unless the injector is closed or the channel is
|
||||
// full. Non-blocking so callers (RFB read loop) never stall.
|
||||
func (w *WindowsInputInjector) tryEnqueue(cmd inputCmd) {
|
||||
select {
|
||||
case <-w.closed:
|
||||
case w.ch <- cmd:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// enqueueReliable posts a command and blocks until it's accepted or the
|
||||
// injector closes. Used for edge-triggered events (button/wheel) where a
|
||||
// drop would desynchronize prevButtonMask in dispatch().
|
||||
func (w *WindowsInputInjector) enqueueReliable(cmd inputCmd) {
|
||||
select {
|
||||
case <-w.closed:
|
||||
return
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case w.ch <- cmd:
|
||||
case <-w.closed:
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WindowsInputInjector) loop() {
|
||||
runtime.LockOSThread()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-w.closed:
|
||||
return
|
||||
case cmd := <-w.ch:
|
||||
w.dispatch(cmd)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WindowsInputInjector) dispatch(cmd inputCmd) {
|
||||
// Switch to the current input desktop so SendInput and the clipboard
|
||||
// API target the desktop the user sees. The returned name tells us
|
||||
// whether we are on the secure Winlogon desktop.
|
||||
_, _ = switchToInputDesktop()
|
||||
|
||||
switch {
|
||||
case cmd.isClipboard:
|
||||
w.doSetClipboard(cmd.clipText)
|
||||
case cmd.isType:
|
||||
w.typeUnicodeText(cmd.clipText)
|
||||
case cmd.isScancode:
|
||||
w.doInjectKeyScancode(cmd.scancode, cmd.keysym, cmd.down)
|
||||
case cmd.isKey:
|
||||
w.doInjectKey(cmd.keysym, cmd.down)
|
||||
default:
|
||||
w.doInjectPointer(cmd.buttonMask, cmd.x, cmd.y, cmd.serverW, cmd.serverH)
|
||||
}
|
||||
}
|
||||
|
||||
// InjectKey queues a key event for injection on the input desktop thread.
|
||||
func (w *WindowsInputInjector) InjectKey(keysym uint32, down bool) {
|
||||
w.tryEnqueue(inputCmd{isKey: true, keysym: keysym, down: down})
|
||||
}
|
||||
|
||||
// InjectKeyScancode queues a raw-scancode key event. PC AT Set 1 maps
|
||||
// directly onto what SendInput's KEYEVENTF_SCANCODE flag wants, so the
|
||||
// only translation is splitting the optional 0xE0 prefix off into the
|
||||
// KEYEVENTF_EXTENDEDKEY flag. keysym is the client-provided fallback we
|
||||
// reach for if the scancode is zero.
|
||||
func (w *WindowsInputInjector) InjectKeyScancode(scancode uint32, keysym uint32, down bool) {
|
||||
if scancode == 0 {
|
||||
w.InjectKey(keysym, down)
|
||||
return
|
||||
}
|
||||
w.tryEnqueue(inputCmd{isScancode: true, scancode: scancode, keysym: keysym, down: down})
|
||||
}
|
||||
|
||||
// InjectPointer queues a pointer event for injection on the input desktop
|
||||
// thread. Move-only updates use lossy enqueue (next sample carries fresher
|
||||
// position anyway), but any sample whose buttonMask differs from the last
|
||||
// queued mask is enqueued reliably so wheel ticks and button transitions
|
||||
// can't be dropped under backpressure.
|
||||
func (w *WindowsInputInjector) InjectPointer(buttonMask uint16, x, y, serverW, serverH int) {
|
||||
cmd := inputCmd{buttonMask: buttonMask, x: x, y: y, serverW: serverW, serverH: serverH}
|
||||
w.queueMu.Lock()
|
||||
transition := !w.lastQueuedMaskValid || w.lastQueuedButtonMask != buttonMask
|
||||
w.lastQueuedButtonMask = buttonMask
|
||||
w.lastQueuedMaskValid = true
|
||||
w.queueMu.Unlock()
|
||||
if transition {
|
||||
w.enqueueReliable(cmd)
|
||||
return
|
||||
}
|
||||
w.tryEnqueue(cmd)
|
||||
}
|
||||
|
||||
// doInjectKeyScancode injects a key event using the QEMU scancode directly,
|
||||
// bypassing the keysym→VK lookup. Windows accepts PC AT Set 1 scancodes
|
||||
// natively via KEYEVENTF_SCANCODE, so the only work is splitting the
|
||||
// optional 0xE0 prefix off into the EXTENDEDKEY flag and tracking
|
||||
// modifier state for the SAS Ctrl+Alt+Del shortcut.
|
||||
func (w *WindowsInputInjector) doInjectKeyScancode(scancode, keysym uint32, down bool) {
|
||||
switch keysym {
|
||||
case 0xffe3, 0xffe4:
|
||||
w.ctrlDown = down
|
||||
case 0xffe9, 0xffea:
|
||||
w.altDown = down
|
||||
}
|
||||
if (keysym == 0xff9f || keysym == 0xffff) && w.ctrlDown && w.altDown && down {
|
||||
signalSAS()
|
||||
return
|
||||
}
|
||||
flags := uint32(keyeventfScanCode)
|
||||
if !down {
|
||||
flags |= keyeventfKeyUp
|
||||
}
|
||||
if qemuScancodeIsExtended(scancode) {
|
||||
flags |= keyeventfExtendedKey
|
||||
}
|
||||
sendKeyInput(0, qemuScancodeLowByte(scancode), flags)
|
||||
}
|
||||
|
||||
func (w *WindowsInputInjector) doInjectKey(keysym uint32, down bool) {
|
||||
switch keysym {
|
||||
case 0xffe3, 0xffe4:
|
||||
w.ctrlDown = down
|
||||
case 0xffe9, 0xffea:
|
||||
w.altDown = down
|
||||
}
|
||||
|
||||
if (keysym == 0xff9f || keysym == 0xffff) && w.ctrlDown && w.altDown && down {
|
||||
signalSAS()
|
||||
return
|
||||
}
|
||||
|
||||
vk, _, extended := keysym2VK(keysym)
|
||||
if vk == 0 {
|
||||
return
|
||||
}
|
||||
var flags uint32
|
||||
if !down {
|
||||
flags |= keyeventfKeyUp
|
||||
}
|
||||
if extended {
|
||||
flags |= keyeventfExtendedKey
|
||||
}
|
||||
sendKeyInput(vk, 0, flags)
|
||||
}
|
||||
|
||||
// signalSAS signals the SAS named event. A listener in Session 0
|
||||
// (startSASListener) calls SendSAS to trigger the Secure Attention Sequence.
|
||||
func signalSAS() {
|
||||
namePtr, err := windows.UTF16PtrFromString(sasEventName)
|
||||
if err != nil {
|
||||
log.Warnf("SAS UTF16: %v", err)
|
||||
return
|
||||
}
|
||||
h, _, lerr := procOpenEventW.Call(
|
||||
uintptr(eventModifyState),
|
||||
0,
|
||||
uintptr(unsafe.Pointer(namePtr)),
|
||||
)
|
||||
if h == 0 {
|
||||
log.Warnf("OpenEvent(%s): %v", sasEventName, lerr)
|
||||
return
|
||||
}
|
||||
ev := windows.Handle(h)
|
||||
defer func() { _ = windows.CloseHandle(ev) }()
|
||||
if err := windows.SetEvent(ev); err != nil {
|
||||
log.Warnf("SetEvent SAS: %v", err)
|
||||
} else {
|
||||
log.Info("SAS event signaled")
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WindowsInputInjector) doInjectPointer(buttonMask uint16, x, y, serverW, serverH int) {
|
||||
if serverW == 0 || serverH == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
absX := int32(x * 65535 / serverW)
|
||||
absY := int32(y * 65535 / serverH)
|
||||
|
||||
sendMouseInput(mouseeventfMove|mouseeventfAbsolute, absX, absY, 0)
|
||||
|
||||
changed := buttonMask ^ w.prevButtonMask
|
||||
w.prevButtonMask = buttonMask
|
||||
|
||||
type btnMap struct {
|
||||
bit uint16
|
||||
down uint32
|
||||
up uint32
|
||||
}
|
||||
buttons := [...]btnMap{
|
||||
{0x01, mouseeventfLeftDown, mouseeventfLeftUp},
|
||||
{0x02, mouseeventfMiddleDown, mouseeventfMiddleUp},
|
||||
{0x04, mouseeventfRightDown, mouseeventfRightUp},
|
||||
}
|
||||
for _, b := range buttons {
|
||||
if changed&b.bit == 0 {
|
||||
continue
|
||||
}
|
||||
var flags uint32
|
||||
if buttonMask&b.bit != 0 {
|
||||
flags = b.down
|
||||
} else {
|
||||
flags = b.up
|
||||
}
|
||||
sendMouseInput(flags|mouseeventfAbsolute, absX, absY, 0)
|
||||
}
|
||||
|
||||
negWheelDelta := ^uint32(wheelDelta - 1)
|
||||
if changed&0x08 != 0 && buttonMask&0x08 != 0 {
|
||||
sendMouseInput(mouseeventfWheel|mouseeventfAbsolute, absX, absY, wheelDelta)
|
||||
}
|
||||
if changed&0x10 != 0 && buttonMask&0x10 != 0 {
|
||||
sendMouseInput(mouseeventfWheel|mouseeventfAbsolute, absX, absY, negWheelDelta)
|
||||
}
|
||||
|
||||
// XBUTTON1/back at bit 7, XBUTTON2/forward at bit 8. SendInput
|
||||
// MOUSEEVENTF_X{DOWN,UP} carries the X button number in dwData.
|
||||
xbuttons := [...]struct {
|
||||
bit uint16
|
||||
data uint32
|
||||
}{
|
||||
{1 << 7, xButton1},
|
||||
{1 << 8, xButton2},
|
||||
}
|
||||
for _, b := range xbuttons {
|
||||
if changed&b.bit == 0 {
|
||||
continue
|
||||
}
|
||||
var flags uint32 = mouseeventfXUp
|
||||
if buttonMask&b.bit != 0 {
|
||||
flags = mouseeventfXDown
|
||||
}
|
||||
sendMouseInput(flags|mouseeventfAbsolute, absX, absY, b.data)
|
||||
}
|
||||
}
|
||||
|
||||
// keysym2VK converts an X11 keysym to a Windows virtual key code.
|
||||
func keysym2VK(keysym uint32) (vk uint16, scan uint16, extended bool) {
|
||||
if keysym >= 0x20 && keysym <= 0x7e {
|
||||
r, _, _ := procVkKeyScanA.Call(uintptr(keysym))
|
||||
vk = uint16(r & 0xff)
|
||||
return
|
||||
}
|
||||
|
||||
if keysym >= 0xffbe && keysym <= 0xffc9 {
|
||||
vk = uint16(0x70 + keysym - 0xffbe)
|
||||
return
|
||||
}
|
||||
|
||||
switch keysym {
|
||||
case 0xff08:
|
||||
vk = 0x08 // Backspace
|
||||
case 0xff09:
|
||||
vk = 0x09 // Tab
|
||||
case 0xff0d:
|
||||
vk = 0x0d // Return
|
||||
case 0xff1b:
|
||||
vk = 0x1b // Escape
|
||||
case 0xff63:
|
||||
vk, extended = 0x2d, true // Insert
|
||||
case 0xff9f, 0xffff:
|
||||
vk, extended = 0x2e, true // Delete
|
||||
case 0xff50:
|
||||
vk, extended = 0x24, true // Home
|
||||
case 0xff57:
|
||||
vk, extended = 0x23, true // End
|
||||
case 0xff55:
|
||||
vk, extended = 0x21, true // PageUp
|
||||
case 0xff56:
|
||||
vk, extended = 0x22, true // PageDown
|
||||
case 0xff51:
|
||||
vk, extended = 0x25, true // Left
|
||||
case 0xff52:
|
||||
vk, extended = 0x26, true // Up
|
||||
case 0xff53:
|
||||
vk, extended = 0x27, true // Right
|
||||
case 0xff54:
|
||||
vk, extended = 0x28, true // Down
|
||||
case 0xffe1, 0xffe2:
|
||||
vk = 0x10 // Shift
|
||||
case 0xffe3, 0xffe4:
|
||||
vk = 0x11 // Control
|
||||
case 0xffe9, 0xffea:
|
||||
vk = 0x12 // Alt
|
||||
case 0xffe5:
|
||||
vk = 0x14 // CapsLock
|
||||
case 0xffe7, 0xffeb:
|
||||
vk, extended = 0x5B, true // Meta_L / Super_L -> Left Windows
|
||||
case 0xffe8, 0xffec:
|
||||
vk, extended = 0x5C, true // Meta_R / Super_R -> Right Windows
|
||||
case 0xff61:
|
||||
vk = 0x2c // PrintScreen
|
||||
case 0xff13:
|
||||
vk = 0x13 // Pause
|
||||
case 0xff14:
|
||||
vk = 0x91 // ScrollLock
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
procOpenClipboard = user32.NewProc("OpenClipboard")
|
||||
procCloseClipboard = user32.NewProc("CloseClipboard")
|
||||
procEmptyClipboard = user32.NewProc("EmptyClipboard")
|
||||
procSetClipboardData = user32.NewProc("SetClipboardData")
|
||||
procGetClipboardData = user32.NewProc("GetClipboardData")
|
||||
procIsClipboardFormatAvailable = user32.NewProc("IsClipboardFormatAvailable")
|
||||
|
||||
procGlobalAlloc = kernel32.NewProc("GlobalAlloc")
|
||||
procGlobalLock = kernel32.NewProc("GlobalLock")
|
||||
procGlobalUnlock = kernel32.NewProc("GlobalUnlock")
|
||||
procGlobalFree = kernel32.NewProc("GlobalFree")
|
||||
)
|
||||
|
||||
const (
|
||||
cfUnicodeText = 13
|
||||
gmemMoveable = 0x0002
|
||||
)
|
||||
|
||||
// SetClipboard queues a request to update the Windows clipboard with the
|
||||
// given UTF-8 text. The work runs on the input thread so it follows the
|
||||
// current input desktop. Secure desktops (Winlogon, UAC) have isolated
|
||||
// clipboards we cannot reach, so the call is a no-op there; use TypeText
|
||||
// to enter text into a secure desktop instead.
|
||||
func (w *WindowsInputInjector) SetClipboard(text string) {
|
||||
w.tryEnqueue(inputCmd{isClipboard: true, clipText: text})
|
||||
}
|
||||
|
||||
// TypeText queues a request to synthesize the given text as Unicode
|
||||
// keystrokes on the current input desktop. Targets the secure desktop
|
||||
// when the user is on Winlogon/UAC, where the clipboard is unreachable.
|
||||
func (w *WindowsInputInjector) TypeText(text string) {
|
||||
w.tryEnqueue(inputCmd{isType: true, clipText: text})
|
||||
}
|
||||
|
||||
func (w *WindowsInputInjector) doSetClipboard(text string) {
|
||||
utf16, err := windows.UTF16FromString(text)
|
||||
if err != nil {
|
||||
log.Tracef("clipboard UTF16 encode: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
size := uintptr(len(utf16) * 2)
|
||||
hMem, _, _ := procGlobalAlloc.Call(gmemMoveable, size)
|
||||
if hMem == 0 {
|
||||
log.Tracef("GlobalAlloc for clipboard: allocation returned nil")
|
||||
return
|
||||
}
|
||||
|
||||
ptr, _, _ := procGlobalLock.Call(hMem)
|
||||
if ptr == 0 {
|
||||
log.Tracef("GlobalLock for clipboard: lock returned nil")
|
||||
_, _, _ = procGlobalFree.Call(hMem)
|
||||
return
|
||||
}
|
||||
copy(unsafe.Slice((*uint16)(unsafe.Pointer(ptr)), len(utf16)), utf16)
|
||||
_, _, _ = procGlobalUnlock.Call(hMem)
|
||||
|
||||
r, _, lerr := procOpenClipboard.Call(0)
|
||||
if r == 0 {
|
||||
log.Tracef("OpenClipboard: %v", lerr)
|
||||
_, _, _ = procGlobalFree.Call(hMem)
|
||||
return
|
||||
}
|
||||
defer logCleanupCall("CloseClipboard", procCloseClipboard)
|
||||
|
||||
_, _, _ = procEmptyClipboard.Call()
|
||||
r, _, lerr = procSetClipboardData.Call(cfUnicodeText, hMem)
|
||||
if r == 0 {
|
||||
log.Tracef("SetClipboardData: %v", lerr)
|
||||
// Ownership only transfers to the OS on success; on failure we
|
||||
// still own hMem and must free it.
|
||||
_, _, _ = procGlobalFree.Call(hMem)
|
||||
}
|
||||
}
|
||||
|
||||
// typeUnicodeText synthesizes the given text as Unicode keystrokes via
|
||||
// SendInput+KEYEVENTF_UNICODE. Used on the Winlogon secure desktop where the
|
||||
// clipboard is isolated: this lets a VNC client paste a password into the
|
||||
// login or credential prompt by sending ClientCutText.
|
||||
func (w *WindowsInputInjector) typeUnicodeText(text string) {
|
||||
utf16, err := windows.UTF16FromString(text)
|
||||
if err != nil {
|
||||
log.Tracef("clipboard UTF16 encode: %v", err)
|
||||
return
|
||||
}
|
||||
if len(utf16) > 0 && utf16[len(utf16)-1] == 0 {
|
||||
utf16 = utf16[:len(utf16)-1]
|
||||
}
|
||||
if len(utf16) > maxTypedClipboardChars {
|
||||
log.Warnf("clipboard paste on Winlogon truncated to %d chars", maxTypedClipboardChars)
|
||||
utf16 = utf16[:maxTypedClipboardChars]
|
||||
}
|
||||
for _, c := range utf16 {
|
||||
sendKeyInput(0, c, keyeventfUnicode)
|
||||
sendKeyInput(0, c, keyeventfUnicode|keyeventfKeyUp)
|
||||
}
|
||||
}
|
||||
|
||||
// GetClipboard reads the Windows clipboard as UTF-8 text.
|
||||
func (w *WindowsInputInjector) GetClipboard() string {
|
||||
r, _, _ := procIsClipboardFormatAvailable.Call(cfUnicodeText)
|
||||
if r == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
r, _, lerr := procOpenClipboard.Call(0)
|
||||
if r == 0 {
|
||||
log.Tracef("OpenClipboard for read: %v", lerr)
|
||||
return ""
|
||||
}
|
||||
defer logCleanupCall("CloseClipboard", procCloseClipboard)
|
||||
|
||||
hData, _, _ := procGetClipboardData.Call(cfUnicodeText)
|
||||
if hData == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
ptr, _, _ := procGlobalLock.Call(hData)
|
||||
if ptr == 0 {
|
||||
return ""
|
||||
}
|
||||
defer logCleanupCallArgs("GlobalUnlock", procGlobalUnlock, hData)
|
||||
|
||||
return windows.UTF16PtrToString((*uint16)(unsafe.Pointer(ptr)))
|
||||
}
|
||||
|
||||
var _ InputInjector = (*WindowsInputInjector)(nil)
|
||||
|
||||
var _ ScreenCapturer = (*DesktopCapturer)(nil)
|
||||
312
client/vnc/server/input_x11.go
Normal file
312
client/vnc/server/input_x11.go
Normal file
@@ -0,0 +1,312 @@
|
||||
//go:build unix && !darwin && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/jezek/xgb"
|
||||
"github.com/jezek/xgb/xproto"
|
||||
"github.com/jezek/xgb/xtest"
|
||||
)
|
||||
|
||||
// X11InputInjector injects keyboard and mouse events via the XTest extension.
|
||||
type X11InputInjector struct {
|
||||
conn *xgb.Conn
|
||||
root xproto.Window
|
||||
screen *xproto.ScreenInfo
|
||||
display string
|
||||
keysymMap map[uint32]byte
|
||||
lastButtons uint16
|
||||
clipboardTool string
|
||||
clipboardToolName string
|
||||
}
|
||||
|
||||
// NewX11InputInjector connects to the X11 display and initializes XTest.
|
||||
func NewX11InputInjector(display string) (*X11InputInjector, error) {
|
||||
detectX11Display()
|
||||
|
||||
if display == "" {
|
||||
display = os.Getenv(envDisplay)
|
||||
}
|
||||
if display == "" {
|
||||
return nil, fmt.Errorf("DISPLAY not set and no Xorg process found")
|
||||
}
|
||||
|
||||
conn, err := xgb.NewConnDisplay(display)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connect to X11 display %s: %w", display, err)
|
||||
}
|
||||
|
||||
if err := xtest.Init(conn); err != nil {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("init XTest extension: %w", err)
|
||||
}
|
||||
|
||||
setup := xproto.Setup(conn)
|
||||
if len(setup.Roots) == 0 {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("no X11 screens")
|
||||
}
|
||||
screen := setup.Roots[0]
|
||||
|
||||
inj := &X11InputInjector{
|
||||
conn: conn,
|
||||
root: screen.Root,
|
||||
screen: &screen,
|
||||
display: display,
|
||||
}
|
||||
inj.cacheKeyboardMapping()
|
||||
inj.resolveClipboardTool()
|
||||
|
||||
log.Infof("X11 input injector ready (display=%s)", display)
|
||||
return inj, nil
|
||||
}
|
||||
|
||||
// InjectKey simulates a key press or release. keysym is an X11 KeySym.
|
||||
func (x *X11InputInjector) InjectKey(keysym uint32, down bool) {
|
||||
keycode := x.keysymToKeycode(keysym)
|
||||
if keycode == 0 {
|
||||
return
|
||||
}
|
||||
x.fakeKeyEvent(keycode, down)
|
||||
}
|
||||
|
||||
// InjectKeyScancode injects using the QEMU scancode by translating to a
|
||||
// Linux KEY_ code and then to an X11 keycode (KEY_* + xkbKeycodeOffset).
|
||||
// On a server running a standard XKB keymap this is layout-independent:
|
||||
// the scancode names the physical key, the server's layout determines the
|
||||
// resulting character. Falls back to the keysym path when the scancode
|
||||
// has no Linux mapping.
|
||||
func (x *X11InputInjector) InjectKeyScancode(scancode, keysym uint32, down bool) {
|
||||
linuxKey := qemuScancodeToLinuxKey(scancode)
|
||||
if linuxKey == 0 {
|
||||
x.InjectKey(keysym, down)
|
||||
return
|
||||
}
|
||||
x.fakeKeyEvent(byte(linuxKey+xkbKeycodeOffset), down)
|
||||
}
|
||||
|
||||
// xkbKeycodeOffset is the per-server constant offset between Linux KEY_*
|
||||
// event codes and the X server's keycode space under XKB. The X protocol
|
||||
// reserves keycodes 0..7 for internal use, so any normal XKB keymap
|
||||
// starts at 8 (KEY_ESC=1 → X keycode 9, KEY_A=30 → X keycode 38, etc.).
|
||||
const xkbKeycodeOffset = 8
|
||||
|
||||
// fakeKeyEvent sends an XTest FakeInput for a press or release.
|
||||
func (x *X11InputInjector) fakeKeyEvent(keycode byte, down bool) {
|
||||
var eventType byte
|
||||
if down {
|
||||
eventType = xproto.KeyPress
|
||||
} else {
|
||||
eventType = xproto.KeyRelease
|
||||
}
|
||||
xtest.FakeInput(x.conn, eventType, keycode, 0, x.root, 0, 0, 0)
|
||||
}
|
||||
|
||||
// InjectPointer simulates mouse movement and button events.
|
||||
func (x *X11InputInjector) InjectPointer(buttonMask uint16, px, py, serverW, serverH int) {
|
||||
if serverW == 0 || serverH == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Scale to actual screen coordinates.
|
||||
screenW := int(x.screen.WidthInPixels)
|
||||
screenH := int(x.screen.HeightInPixels)
|
||||
absX := px * screenW / serverW
|
||||
absY := py * screenH / serverH
|
||||
|
||||
// Move pointer.
|
||||
xtest.FakeInput(x.conn, xproto.MotionNotify, 0, 0, x.root, int16(absX), int16(absY), 0)
|
||||
|
||||
// Handle button events. RFB button mask: bit0=left, bit1=middle, bit2=right,
|
||||
// bit3=scrollUp, bit4=scrollDown. X11 buttons: 1=left, 2=middle, 3=right,
|
||||
// 4=scrollUp, 5=scrollDown.
|
||||
type btnMap struct {
|
||||
rfbBit uint16
|
||||
x11Btn byte
|
||||
}
|
||||
// X11 button numbers: 1=left, 2=middle, 3=right, 4/5=scroll up/down,
|
||||
// 6/7=scroll left/right (skipped), 8=back, 9=forward.
|
||||
buttons := [...]btnMap{
|
||||
{0x01, 1},
|
||||
{0x02, 2},
|
||||
{0x04, 3},
|
||||
{0x08, 4},
|
||||
{0x10, 5},
|
||||
{1 << 7, 8},
|
||||
{1 << 8, 9},
|
||||
}
|
||||
|
||||
for _, b := range buttons {
|
||||
pressed := buttonMask&b.rfbBit != 0
|
||||
wasPressed := x.lastButtons&b.rfbBit != 0
|
||||
if b.x11Btn == 4 || b.x11Btn == 5 {
|
||||
// Scroll: send press+release on each scroll event.
|
||||
if pressed {
|
||||
xtest.FakeInput(x.conn, xproto.ButtonPress, b.x11Btn, 0, x.root, 0, 0, 0)
|
||||
xtest.FakeInput(x.conn, xproto.ButtonRelease, b.x11Btn, 0, x.root, 0, 0, 0)
|
||||
}
|
||||
} else {
|
||||
if pressed && !wasPressed {
|
||||
xtest.FakeInput(x.conn, xproto.ButtonPress, b.x11Btn, 0, x.root, 0, 0, 0)
|
||||
} else if !pressed && wasPressed {
|
||||
xtest.FakeInput(x.conn, xproto.ButtonRelease, b.x11Btn, 0, x.root, 0, 0, 0)
|
||||
}
|
||||
}
|
||||
}
|
||||
x.lastButtons = buttonMask
|
||||
}
|
||||
|
||||
// cacheKeyboardMapping fetches the X11 keyboard mapping once and stores it
|
||||
// as a keysym-to-keycode map, avoiding a round-trip per keystroke.
|
||||
func (x *X11InputInjector) cacheKeyboardMapping() {
|
||||
setup := xproto.Setup(x.conn)
|
||||
minKeycode := setup.MinKeycode
|
||||
maxKeycode := setup.MaxKeycode
|
||||
|
||||
reply, err := xproto.GetKeyboardMapping(x.conn, minKeycode,
|
||||
byte(maxKeycode-minKeycode+1)).Reply()
|
||||
if err != nil {
|
||||
log.Debugf("cache keyboard mapping: %v", err)
|
||||
x.keysymMap = make(map[uint32]byte)
|
||||
return
|
||||
}
|
||||
|
||||
m := make(map[uint32]byte, int(maxKeycode-minKeycode+1)*int(reply.KeysymsPerKeycode))
|
||||
keysymsPerKeycode := int(reply.KeysymsPerKeycode)
|
||||
for i := int(minKeycode); i <= int(maxKeycode); i++ {
|
||||
offset := (i - int(minKeycode)) * keysymsPerKeycode
|
||||
for j := 0; j < keysymsPerKeycode; j++ {
|
||||
ks := uint32(reply.Keysyms[offset+j])
|
||||
if ks != 0 {
|
||||
if _, exists := m[ks]; !exists {
|
||||
m[ks] = byte(i)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
x.keysymMap = m
|
||||
}
|
||||
|
||||
// keysymToKeycode looks up a cached keysym-to-keycode mapping.
|
||||
// Returns 0 if the keysym is not mapped.
|
||||
func (x *X11InputInjector) keysymToKeycode(keysym uint32) byte {
|
||||
return x.keysymMap[keysym]
|
||||
}
|
||||
|
||||
// SetClipboard sets the X11 clipboard using xclip or xsel.
|
||||
func (x *X11InputInjector) SetClipboard(text string) {
|
||||
if x.clipboardTool == "" {
|
||||
return
|
||||
}
|
||||
|
||||
var cmd *exec.Cmd
|
||||
if x.clipboardToolName == "xclip" {
|
||||
cmd = exec.Command(x.clipboardTool, "-selection", "clipboard")
|
||||
} else {
|
||||
cmd = exec.Command(x.clipboardTool, "--clipboard", "--input")
|
||||
}
|
||||
cmd.Env = x.clipboardEnv()
|
||||
cmd.Stdin = strings.NewReader(text)
|
||||
if err := cmd.Run(); err != nil {
|
||||
log.Debugf("set clipboard via %s: %v", x.clipboardToolName, err)
|
||||
}
|
||||
}
|
||||
|
||||
// TypeText synthesizes the given text as keystrokes via XTest. Used in
|
||||
// places where the focused application isn't clipboard-aware (e.g. a TTY
|
||||
// login in an X11 session, an SDDM/GDM password field that ignores
|
||||
// XSelection, or a kiosk app), so stuffing the X clipboard and relying on
|
||||
// Ctrl+V would not reach the input.
|
||||
//
|
||||
// Limitation: only ASCII printable characters are typed. Non-ASCII runes
|
||||
// are skipped: a paste workflow for them needs Wayland-aware text input
|
||||
// or layout introspection that this path does not implement.
|
||||
func (x *X11InputInjector) TypeText(text string) {
|
||||
const maxChars = 4096
|
||||
count := 0
|
||||
for _, r := range text {
|
||||
if count >= maxChars {
|
||||
break
|
||||
}
|
||||
count++
|
||||
keysym, shift, ok := keysymForASCIIRune(r)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
keycode := x.keysymToKeycode(keysym)
|
||||
if keycode == 0 {
|
||||
continue
|
||||
}
|
||||
var shiftCode byte
|
||||
if shift {
|
||||
shiftCode = x.keysymToKeycode(0xffe1) // Shift_L
|
||||
if shiftCode != 0 {
|
||||
xtest.FakeInput(x.conn, xproto.KeyPress, shiftCode, 0, x.root, 0, 0, 0)
|
||||
}
|
||||
}
|
||||
xtest.FakeInput(x.conn, xproto.KeyPress, keycode, 0, x.root, 0, 0, 0)
|
||||
xtest.FakeInput(x.conn, xproto.KeyRelease, keycode, 0, x.root, 0, 0, 0)
|
||||
if shift && shiftCode != 0 {
|
||||
xtest.FakeInput(x.conn, xproto.KeyRelease, shiftCode, 0, x.root, 0, 0, 0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (x *X11InputInjector) resolveClipboardTool() {
|
||||
for _, name := range []string{"xclip", "xsel"} {
|
||||
path, err := exec.LookPath(name)
|
||||
if err == nil {
|
||||
x.clipboardTool = path
|
||||
x.clipboardToolName = name
|
||||
log.Debugf("clipboard tool resolved to %s", path)
|
||||
return
|
||||
}
|
||||
}
|
||||
log.Debugf("no clipboard tool (xclip/xsel) found, clipboard sync disabled")
|
||||
}
|
||||
|
||||
// GetClipboard reads the X11 clipboard using xclip or xsel.
|
||||
func (x *X11InputInjector) GetClipboard() string {
|
||||
if x.clipboardTool == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
var cmd *exec.Cmd
|
||||
if x.clipboardToolName == "xclip" {
|
||||
cmd = exec.Command(x.clipboardTool, "-selection", "clipboard", "-o")
|
||||
} else {
|
||||
cmd = exec.Command(x.clipboardTool, "--clipboard", "--output")
|
||||
}
|
||||
cmd.Env = x.clipboardEnv()
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
// Exit status 1 just means there is no STRING selection set yet,
|
||||
// which is the steady state on a fresh Xvfb session, logging it
|
||||
// every clipboard poll (2s) floods the trace stream.
|
||||
return ""
|
||||
}
|
||||
return string(out)
|
||||
}
|
||||
|
||||
func (x *X11InputInjector) clipboardEnv() []string {
|
||||
env := []string{envDisplay + "=" + x.display}
|
||||
if auth := os.Getenv(envXAuthority); auth != "" {
|
||||
env = append(env, envXAuthority+"="+auth)
|
||||
}
|
||||
return env
|
||||
}
|
||||
|
||||
// Close releases X11 resources.
|
||||
func (x *X11InputInjector) Close() {
|
||||
x.conn.Close()
|
||||
}
|
||||
|
||||
var _ InputInjector = (*X11InputInjector)(nil)
|
||||
var _ ScreenCapturer = (*X11Poller)(nil)
|
||||
73
client/vnc/server/keysym_typetext.go
Normal file
73
client/vnc/server/keysym_typetext.go
Normal file
@@ -0,0 +1,73 @@
|
||||
//go:build !windows
|
||||
|
||||
package server
|
||||
|
||||
// keysymForASCIIRune maps an ASCII rune to (X11 keysym for the unshifted
|
||||
// version, needsShift). Used by TypeText implementations on each platform
|
||||
// so the caller can explicitly press Shift instead of relying on the
|
||||
// server-side modifier state. Returns ok=false for runes outside the
|
||||
// supported set; non-ASCII text is dropped by TypeText.
|
||||
func keysymForASCIIRune(r rune) (uint32, bool, bool) {
|
||||
if r >= 'a' && r <= 'z' {
|
||||
return uint32(r), false, true
|
||||
}
|
||||
if r >= 'A' && r <= 'Z' {
|
||||
return uint32(r - 'A' + 'a'), true, true
|
||||
}
|
||||
if r >= '0' && r <= '9' {
|
||||
return uint32(r), false, true
|
||||
}
|
||||
switch r {
|
||||
case ' ':
|
||||
return 0x20, false, true
|
||||
case '\n', '\r':
|
||||
return 0xff0d, false, true // Return
|
||||
case '\t':
|
||||
return 0xff09, false, true // Tab
|
||||
case '-', '=', '[', ']', '\\', ';', '\'', '`', ',', '.', '/':
|
||||
return uint32(r), false, true
|
||||
case '!':
|
||||
return '1', true, true
|
||||
case '@':
|
||||
return '2', true, true
|
||||
case '#':
|
||||
return '3', true, true
|
||||
case '$':
|
||||
return '4', true, true
|
||||
case '%':
|
||||
return '5', true, true
|
||||
case '^':
|
||||
return '6', true, true
|
||||
case '&':
|
||||
return '7', true, true
|
||||
case '*':
|
||||
return '8', true, true
|
||||
case '(':
|
||||
return '9', true, true
|
||||
case ')':
|
||||
return '0', true, true
|
||||
case '_':
|
||||
return '-', true, true
|
||||
case '+':
|
||||
return '=', true, true
|
||||
case '{':
|
||||
return '[', true, true
|
||||
case '}':
|
||||
return ']', true, true
|
||||
case '|':
|
||||
return '\\', true, true
|
||||
case ':':
|
||||
return ';', true, true
|
||||
case '"':
|
||||
return '\'', true, true
|
||||
case '~':
|
||||
return '`', true, true
|
||||
case '<':
|
||||
return ',', true, true
|
||||
case '>':
|
||||
return '.', true, true
|
||||
case '?':
|
||||
return '/', true, true
|
||||
}
|
||||
return 0, false, false
|
||||
}
|
||||
225
client/vnc/server/metrics_conn.go
Normal file
225
client/vnc/server/metrics_conn.go
Normal file
@@ -0,0 +1,225 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SessionTick is one sampling slice of a VNC session's wire activity.
|
||||
// BytesOut / Writes / FBUs are deltas observed during this tick;
|
||||
// Max* fields are the high-water marks observed during this tick (reset
|
||||
// at the start of the next). Period is the wall-clock duration covered
|
||||
// (typically sessionTickInterval, shorter for the final flush).
|
||||
type SessionTick struct {
|
||||
Period time.Duration
|
||||
BytesOut uint64
|
||||
Writes uint64
|
||||
FBUs uint64
|
||||
MaxFBUBytes uint64
|
||||
MaxFBURects uint64
|
||||
MaxWriteBytes uint64
|
||||
WriteNanos uint64
|
||||
}
|
||||
|
||||
// sessionTickInterval is how often metricsConn emits a SessionTick. One
|
||||
// second covers roughly one FBU round-trip at typical client request
|
||||
// cadences during steady-state activity.
|
||||
const sessionTickInterval = time.Second
|
||||
|
||||
// metricsConn wraps a net.Conn and tracks per-session byte / write / FBU
|
||||
// counters. Updates are atomic so the cost is a few atomic ops per Write
|
||||
// (well under 100 ns), negligible against the syscall itself, so the wrap
|
||||
// is always installed. A goroutine emits a SessionTick to the recorder
|
||||
// every sessionTickInterval (only when the tick has activity to report);
|
||||
// a final partial-tick flush runs on Close.
|
||||
type metricsConn struct {
|
||||
net.Conn
|
||||
|
||||
recorder func(SessionTick)
|
||||
|
||||
bytesOut atomic.Uint64
|
||||
writes atomic.Uint64
|
||||
writeNanos atomic.Uint64
|
||||
largestPkt atomic.Uint64
|
||||
fbus atomic.Uint64
|
||||
fbuBytes atomic.Uint64
|
||||
fbuRects atomic.Uint64
|
||||
maxFBUBytes atomic.Uint64
|
||||
maxFBURects atomic.Uint64
|
||||
|
||||
tickMu sync.Mutex
|
||||
tickStart time.Time
|
||||
tickPrevB uint64
|
||||
tickPrevW uint64
|
||||
tickPrevF uint64
|
||||
tickPrevNS uint64
|
||||
|
||||
// busyMu guards the sliding window used by BusyFraction.
|
||||
busyMu sync.Mutex
|
||||
busyLastTime time.Time
|
||||
busyLastNanos uint64
|
||||
busyFraction float64
|
||||
|
||||
closeOnce sync.Once
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
func newMetricsConn(c net.Conn, recorder func(SessionTick)) net.Conn {
|
||||
m := &metricsConn{
|
||||
Conn: c,
|
||||
recorder: recorder,
|
||||
tickStart: time.Now(),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
if recorder != nil {
|
||||
go m.tickLoop()
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// tickLoop emits a SessionTick every sessionTickInterval until done.
|
||||
// Empty ticks (no writes since the last tick) are skipped.
|
||||
func (m *metricsConn) tickLoop() {
|
||||
t := time.NewTicker(sessionTickInterval)
|
||||
defer t.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-m.done:
|
||||
return
|
||||
case <-t.C:
|
||||
m.flushTick(false)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// flushTick computes deltas since the last tick, resets the per-tick max
|
||||
// trackers, and emits a SessionTick to the recorder. final=true forces
|
||||
// emission even if no writes happened (used at session close to record
|
||||
// the trailing partial period).
|
||||
func (m *metricsConn) flushTick(final bool) {
|
||||
m.tickMu.Lock()
|
||||
defer m.tickMu.Unlock()
|
||||
|
||||
b := m.bytesOut.Load()
|
||||
w := m.writes.Load()
|
||||
f := m.fbus.Load()
|
||||
ns := m.writeNanos.Load()
|
||||
|
||||
db := b - m.tickPrevB
|
||||
dw := w - m.tickPrevW
|
||||
df := f - m.tickPrevF
|
||||
dns := ns - m.tickPrevNS
|
||||
m.tickPrevB, m.tickPrevW, m.tickPrevF, m.tickPrevNS = b, w, f, ns
|
||||
|
||||
maxFBU := m.maxFBUBytes.Swap(0)
|
||||
maxRects := m.maxFBURects.Swap(0)
|
||||
maxPkt := m.largestPkt.Swap(0)
|
||||
|
||||
period := time.Since(m.tickStart)
|
||||
m.tickStart = time.Now()
|
||||
|
||||
if dw == 0 && !final {
|
||||
return
|
||||
}
|
||||
m.recorder(SessionTick{
|
||||
Period: period,
|
||||
BytesOut: db,
|
||||
Writes: dw,
|
||||
FBUs: df,
|
||||
MaxFBUBytes: maxFBU,
|
||||
MaxFBURects: maxRects,
|
||||
MaxWriteBytes: maxPkt,
|
||||
WriteNanos: dns,
|
||||
})
|
||||
}
|
||||
|
||||
// BusyFraction reports the fraction of recent wall time that Write spent
|
||||
// blocked in the underlying socket, as an exponentially smoothed value in
|
||||
// [0, 1]. Approximates downstream backpressure: persistent values near 1
|
||||
// mean the socket cannot keep up with the encoder's output. Callers can
|
||||
// throttle JPEG quality or skip frames in response.
|
||||
func (m *metricsConn) BusyFraction() float64 {
|
||||
now := time.Now()
|
||||
ns := m.writeNanos.Load()
|
||||
|
||||
m.busyMu.Lock()
|
||||
defer m.busyMu.Unlock()
|
||||
if m.busyLastTime.IsZero() {
|
||||
m.busyLastTime = now
|
||||
m.busyLastNanos = ns
|
||||
return 0
|
||||
}
|
||||
period := now.Sub(m.busyLastTime)
|
||||
if period < 50*time.Millisecond {
|
||||
return m.busyFraction
|
||||
}
|
||||
delta := ns - m.busyLastNanos
|
||||
sample := float64(delta) / float64(period.Nanoseconds())
|
||||
if sample > 1 {
|
||||
sample = 1
|
||||
}
|
||||
const alpha = 0.4
|
||||
m.busyFraction = alpha*sample + (1-alpha)*m.busyFraction
|
||||
m.busyLastTime = now
|
||||
m.busyLastNanos = ns
|
||||
return m.busyFraction
|
||||
}
|
||||
|
||||
// isFBUHeader reports whether the given Write payload is the 4-byte
|
||||
// FramebufferUpdate header (message type 0, padding 0, rect-count high
|
||||
// byte). Rect bodies are written separately by sendDirtyAndMoves, so the
|
||||
// FBU/rect boundary lines up with Write boundaries.
|
||||
func isFBUHeader(p []byte) bool {
|
||||
return len(p) == 4 && p[0] == serverFramebufferUpdate
|
||||
}
|
||||
|
||||
func (m *metricsConn) Write(p []byte) (int, error) {
|
||||
if isFBUHeader(p) {
|
||||
if b := m.fbuBytes.Swap(0); b > 0 {
|
||||
if b > m.maxFBUBytes.Load() {
|
||||
m.maxFBUBytes.Store(b)
|
||||
}
|
||||
}
|
||||
if r := m.fbuRects.Swap(0); r > 0 {
|
||||
if r > m.maxFBURects.Load() {
|
||||
m.maxFBURects.Store(r)
|
||||
}
|
||||
}
|
||||
m.fbus.Add(1)
|
||||
}
|
||||
|
||||
t0 := time.Now()
|
||||
n, err := m.Conn.Write(p)
|
||||
m.writeNanos.Add(uint64(time.Since(t0).Nanoseconds()))
|
||||
m.bytesOut.Add(uint64(n))
|
||||
m.writes.Add(1)
|
||||
if !isFBUHeader(p) {
|
||||
m.fbuBytes.Add(uint64(n))
|
||||
m.fbuRects.Add(1)
|
||||
}
|
||||
if uint64(n) > m.largestPkt.Load() {
|
||||
m.largestPkt.Store(uint64(n))
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (m *metricsConn) Close() error {
|
||||
m.closeOnce.Do(func() {
|
||||
close(m.done)
|
||||
if m.recorder == nil {
|
||||
return
|
||||
}
|
||||
if b := m.fbuBytes.Swap(0); b > m.maxFBUBytes.Load() {
|
||||
m.maxFBUBytes.Store(b)
|
||||
}
|
||||
if r := m.fbuRects.Swap(0); r > m.maxFBURects.Load() {
|
||||
m.maxFBURects.Store(r)
|
||||
}
|
||||
m.flushTick(true)
|
||||
})
|
||||
return m.Conn.Close()
|
||||
}
|
||||
432
client/vnc/server/noise_auth_test.go
Normal file
432
client/vnc/server/noise_auth_test.go
Normal file
@@ -0,0 +1,432 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/curve25519"
|
||||
|
||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
|
||||
)
|
||||
|
||||
// noiseTestServer starts a VNC server with a freshly generated identity
|
||||
// key and returns the listener address, the server, and the server's
|
||||
// static public key for client-side handshake setup.
|
||||
func noiseTestServer(t *testing.T) (net.Addr, *Server, []byte) {
|
||||
t.Helper()
|
||||
|
||||
kp, err := noise.DH25519.GenerateKeypair(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
srv := New(&testCapturer{}, &StubInputInjector{}, kp.Private)
|
||||
srv.SetDisableAuth(false)
|
||||
|
||||
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
||||
network := netip.MustParsePrefix("127.0.0.0/8")
|
||||
require.NoError(t, srv.Start(t.Context(), addr, network))
|
||||
srv.localAddr = netip.MustParseAddr("10.99.99.1")
|
||||
t.Cleanup(func() { _ = srv.Stop() })
|
||||
|
||||
return srv.listener.Addr(), srv, kp.Public
|
||||
}
|
||||
|
||||
// registerSessionKey enrolls a fresh X25519 keypair under the given user
|
||||
// ID into the server's authorizer with the requested OS-user wildcard
|
||||
// mapping. Returns the keypair so the test can drive the handshake.
|
||||
func registerSessionKey(t *testing.T, srv *Server, userID string) noise.DHKey {
|
||||
t.Helper()
|
||||
|
||||
kp, err := noise.DH25519.GenerateKeypair(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
userHash, err := sshuserhash.HashUserID(userID)
|
||||
require.NoError(t, err)
|
||||
|
||||
srv.UpdateVNCAuth(&sshauth.Config{
|
||||
AuthorizedUsers: []sshuserhash.UserIDHash{userHash},
|
||||
MachineUsers: map[string][]uint32{sshauth.Wildcard: {0}},
|
||||
SessionPubKeys: []sshauth.SessionPubKey{
|
||||
{PubKey: kp.Public, UserIDHash: userHash},
|
||||
},
|
||||
})
|
||||
return kp
|
||||
}
|
||||
|
||||
// writeHeaderPrefix writes the mode + zero-length-username prefix that
|
||||
// precedes the optional Noise handshake in the NetBird VNC header.
|
||||
func writeHeaderPrefix(t *testing.T, conn net.Conn, mode byte) {
|
||||
t.Helper()
|
||||
prefix := []byte{mode, 0, 0}
|
||||
_, err := conn.Write(prefix)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// writeHeaderTail writes the sessionID/width/height fields that follow
|
||||
// either the Noise msg2 (auth path) or the prefix alone (no-auth path).
|
||||
func writeHeaderTail(t *testing.T, conn net.Conn) {
|
||||
t.Helper()
|
||||
tail := make([]byte, 8)
|
||||
_, err := conn.Write(tail)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// performInitiator drives the initiator side of Noise_IK against the
|
||||
// server's identity public key, returns the resulting state. The Noise
|
||||
// msg2 produced by the server is read and consumed.
|
||||
func performInitiator(t *testing.T, conn net.Conn, clientKey noise.DHKey, serverPub []byte) {
|
||||
t.Helper()
|
||||
|
||||
state, err := noise.NewHandshakeState(noise.Config{
|
||||
CipherSuite: vncNoiseSuite,
|
||||
Pattern: noise.HandshakeIK,
|
||||
Initiator: true,
|
||||
StaticKeypair: clientKey,
|
||||
PeerStatic: serverPub,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
msg1, _, _, err := state.WriteMessage(nil, nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, noiseInitiatorMsgLen, len(msg1))
|
||||
|
||||
_, err = conn.Write(append([]byte("NBV3"), msg1...))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, conn.SetReadDeadline(time.Now().Add(5*time.Second)))
|
||||
msg2 := make([]byte, noiseResponderMsgLen)
|
||||
_, err = io.ReadFull(conn, msg2)
|
||||
require.NoError(t, err)
|
||||
_, _, _, err = state.ReadMessage(nil, msg2)
|
||||
require.NoError(t, err, "server responder message must decrypt with the correct peer static")
|
||||
}
|
||||
|
||||
// readRFBFailure consumes the RFB version exchange and returns the
|
||||
// security-failure reason string. Fails the test if the server did not
|
||||
// send a failure (i.e. produced a non-zero security-types list).
|
||||
func readRFBFailure(t *testing.T, conn net.Conn) string {
|
||||
t.Helper()
|
||||
require.NoError(t, conn.SetReadDeadline(time.Now().Add(5*time.Second)))
|
||||
|
||||
var ver [12]byte
|
||||
_, err := io.ReadFull(conn, ver[:])
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "RFB 003.008\n", string(ver[:]))
|
||||
|
||||
_, err = conn.Write(ver[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
var n [1]byte
|
||||
_, err = io.ReadFull(conn, n[:])
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, byte(0), n[0], "expected security-failure (0 types)")
|
||||
|
||||
var rl [4]byte
|
||||
_, err = io.ReadFull(conn, rl[:])
|
||||
require.NoError(t, err)
|
||||
reason := make([]byte, binary.BigEndian.Uint32(rl[:]))
|
||||
_, err = io.ReadFull(conn, reason)
|
||||
require.NoError(t, err)
|
||||
return string(reason)
|
||||
}
|
||||
|
||||
// readRFBGreetingNoFailure asserts the server proceeded past auth: it
|
||||
// must offer at least one security type rather than a 0 failure.
|
||||
func readRFBGreetingNoFailure(t *testing.T, conn net.Conn) {
|
||||
t.Helper()
|
||||
require.NoError(t, conn.SetReadDeadline(time.Now().Add(5*time.Second)))
|
||||
|
||||
var ver [12]byte
|
||||
_, err := io.ReadFull(conn, ver[:])
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "RFB 003.008\n", string(ver[:]))
|
||||
|
||||
_, err = conn.Write(ver[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
var n [1]byte
|
||||
_, err = io.ReadFull(conn, n[:])
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, byte(0), n[0], "server must offer security types after a valid handshake")
|
||||
}
|
||||
|
||||
// TestNoise_RegisteredKey_AccessGranted exercises the happy path: a
|
||||
// session key enrolled in the authorizer completes a Noise_IK handshake
|
||||
// and the server proceeds to the RFB greeting.
|
||||
func TestNoise_RegisteredKey_AccessGranted(t *testing.T) {
|
||||
addr, srv, serverPub := noiseTestServer(t)
|
||||
clientKey := registerSessionKey(t, srv, "alice@example")
|
||||
|
||||
conn, err := net.Dial("tcp", addr.String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
writeHeaderPrefix(t, conn, ModeAttach)
|
||||
performInitiator(t, conn, clientKey, serverPub)
|
||||
writeHeaderTail(t, conn)
|
||||
|
||||
readRFBGreetingNoFailure(t, conn)
|
||||
}
|
||||
|
||||
// TestNoise_UnregisteredClientStatic_Rejected proves the authorizer is
|
||||
// consulted: a syntactically-valid handshake from a key the server has
|
||||
// never been told about must be rejected fail-closed.
|
||||
func TestNoise_UnregisteredClientStatic_Rejected(t *testing.T) {
|
||||
addr, _, serverPub := noiseTestServer(t)
|
||||
// Auth is enabled but the authorizer was not updated, so the lookup
|
||||
// path returns ErrSessionKeyNotKnown.
|
||||
attackerKey, err := noise.DH25519.GenerateKeypair(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
conn, err := net.Dial("tcp", addr.String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
writeHeaderPrefix(t, conn, ModeAttach)
|
||||
performInitiator(t, conn, attackerKey, serverPub)
|
||||
writeHeaderTail(t, conn)
|
||||
|
||||
reason := readRFBFailure(t, conn)
|
||||
assert.Contains(t, reason, RejectCodeAuthForbidden)
|
||||
assert.Contains(t, reason, "session pubkey not registered")
|
||||
}
|
||||
|
||||
// TestNoise_WrongServerStatic_HandshakeFails proves the server's
|
||||
// identity is bound into the handshake: an initiator using the wrong
|
||||
// peer static encrypts msg1 under keys the real server can't derive, so
|
||||
// the server fails the handshake and closes without RFB output.
|
||||
func TestNoise_WrongServerStatic_HandshakeFails(t *testing.T) {
|
||||
addr, srv, _ := noiseTestServer(t)
|
||||
clientKey := registerSessionKey(t, srv, "alice@example")
|
||||
|
||||
bogusServerKey, err := noise.DH25519.GenerateKeypair(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
conn, err := net.Dial("tcp", addr.String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
writeHeaderPrefix(t, conn, ModeAttach)
|
||||
|
||||
state, err := noise.NewHandshakeState(noise.Config{
|
||||
CipherSuite: vncNoiseSuite,
|
||||
Pattern: noise.HandshakeIK,
|
||||
Initiator: true,
|
||||
StaticKeypair: clientKey,
|
||||
PeerStatic: bogusServerKey.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
msg1, _, _, err := state.WriteMessage(nil, nil)
|
||||
require.NoError(t, err)
|
||||
_, err = conn.Write(append([]byte("NBV3"), msg1...))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, conn.SetReadDeadline(time.Now().Add(5*time.Second)))
|
||||
var b [1]byte
|
||||
_, err = io.ReadFull(conn, b[:])
|
||||
require.Error(t, err, "server must close without RFB greeting when msg1 is sealed for a different server identity")
|
||||
}
|
||||
|
||||
// TestNoise_MalformedMsg1_ClosesConnection covers the case where the
|
||||
// magic prefix is correct but the following 96 bytes are random: the
|
||||
// noise library fails ReadMessage and the server closes silently.
|
||||
func TestNoise_MalformedMsg1_ClosesConnection(t *testing.T) {
|
||||
addr, _, _ := noiseTestServer(t)
|
||||
|
||||
conn, err := net.Dial("tcp", addr.String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
writeHeaderPrefix(t, conn, ModeAttach)
|
||||
junk := make([]byte, noiseInitiatorMsgLen)
|
||||
for i := range junk {
|
||||
junk[i] = byte(i)
|
||||
}
|
||||
_, err = conn.Write(append([]byte("NBV3"), junk...))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, conn.SetReadDeadline(time.Now().Add(5*time.Second)))
|
||||
var b [1]byte
|
||||
_, err = io.ReadFull(conn, b[:])
|
||||
require.Error(t, err, "garbage msg1 must terminate the connection before any RFB output")
|
||||
}
|
||||
|
||||
// TestNoise_TruncatedMsg1_ClosesConnection sends fewer than the 96
|
||||
// bytes a Noise_IK msg1 must contain. The server's io.ReadFull short-
|
||||
// reads and closes; no RFB greeting must leak.
|
||||
func TestNoise_TruncatedMsg1_ClosesConnection(t *testing.T) {
|
||||
addr, _, _ := noiseTestServer(t)
|
||||
|
||||
conn, err := net.Dial("tcp", addr.String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
writeHeaderPrefix(t, conn, ModeAttach)
|
||||
_, err = conn.Write([]byte("NBV3"))
|
||||
require.NoError(t, err)
|
||||
_, err = conn.Write(make([]byte, 8))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, conn.(*net.TCPConn).CloseWrite())
|
||||
|
||||
require.NoError(t, conn.SetReadDeadline(time.Now().Add(2*time.Second)))
|
||||
buf := make([]byte, 64)
|
||||
n, err := conn.Read(buf)
|
||||
require.Equal(t, 0, n, "server must not emit RFB bytes after a truncated handshake")
|
||||
require.ErrorIs(t, err, io.EOF, "server must close the connection on truncated msg1")
|
||||
}
|
||||
|
||||
// TestNoise_AuthEnabled_NoHandshake_Rejected proves that with auth on,
|
||||
// a connection that skips the Noise prefix (older client / VNC client)
|
||||
// is rejected with AUTH_FORBIDDEN: identity proof missing.
|
||||
func TestNoise_AuthEnabled_NoHandshake_Rejected(t *testing.T) {
|
||||
addr, _, _ := noiseTestServer(t)
|
||||
|
||||
conn, err := net.Dial("tcp", addr.String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
writeHeaderPrefix(t, conn, ModeAttach)
|
||||
writeHeaderTail(t, conn)
|
||||
|
||||
reason := readRFBFailure(t, conn)
|
||||
assert.Contains(t, reason, RejectCodeAuthForbidden)
|
||||
assert.Contains(t, reason, "identity proof missing")
|
||||
}
|
||||
|
||||
// TestNoise_RevokedKey_RejectedAfterAuthUpdate verifies the authorizer
|
||||
// honors revocations: a key that worked before a UpdateVNCAuth call
|
||||
// must stop working as soon as the new config omits it.
|
||||
func TestNoise_RevokedKey_RejectedAfterAuthUpdate(t *testing.T) {
|
||||
addr, srv, serverPub := noiseTestServer(t)
|
||||
clientKey := registerSessionKey(t, srv, "alice@example")
|
||||
|
||||
// First connection succeeds.
|
||||
conn1, err := net.Dial("tcp", addr.String())
|
||||
require.NoError(t, err)
|
||||
defer conn1.Close()
|
||||
writeHeaderPrefix(t, conn1, ModeAttach)
|
||||
performInitiator(t, conn1, clientKey, serverPub)
|
||||
writeHeaderTail(t, conn1)
|
||||
readRFBGreetingNoFailure(t, conn1)
|
||||
|
||||
// Revoke by pushing a fresh config that drops the pubkey entry.
|
||||
srv.UpdateVNCAuth(&sshauth.Config{})
|
||||
|
||||
// Same client, same Noise key, should now be denied.
|
||||
conn2, err := net.Dial("tcp", addr.String())
|
||||
require.NoError(t, err)
|
||||
defer conn2.Close()
|
||||
writeHeaderPrefix(t, conn2, ModeAttach)
|
||||
performInitiator(t, conn2, clientKey, serverPub)
|
||||
writeHeaderTail(t, conn2)
|
||||
|
||||
reason := readRFBFailure(t, conn2)
|
||||
assert.Contains(t, reason, RejectCodeAuthForbidden)
|
||||
assert.Contains(t, reason, "session pubkey not registered")
|
||||
}
|
||||
|
||||
// TestNoise_NoIdentityKey_FailsClosed ensures a server constructed
|
||||
// without a static private key still rejects authenticated connections
|
||||
// fail-closed; it must not silently accept the client.
|
||||
func TestNoise_NoIdentityKey_FailsClosed(t *testing.T) {
|
||||
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
|
||||
srv.SetDisableAuth(false)
|
||||
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
||||
network := netip.MustParsePrefix("127.0.0.0/8")
|
||||
require.NoError(t, srv.Start(t.Context(), addr, network))
|
||||
srv.localAddr = netip.MustParseAddr("10.99.99.1")
|
||||
t.Cleanup(func() { _ = srv.Stop() })
|
||||
|
||||
clientKey, err := noise.DH25519.GenerateKeypair(nil)
|
||||
require.NoError(t, err)
|
||||
fakeServerKey, err := noise.DH25519.GenerateKeypair(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
conn, err := net.Dial("tcp", srv.listener.Addr().String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
writeHeaderPrefix(t, conn, ModeAttach)
|
||||
|
||||
state, err := noise.NewHandshakeState(noise.Config{
|
||||
CipherSuite: vncNoiseSuite,
|
||||
Pattern: noise.HandshakeIK,
|
||||
Initiator: true,
|
||||
StaticKeypair: clientKey,
|
||||
PeerStatic: fakeServerKey.Public,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
msg1, _, _, err := state.WriteMessage(nil, nil)
|
||||
require.NoError(t, err)
|
||||
_, err = conn.Write(append([]byte("NBV3"), msg1...))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, conn.SetReadDeadline(time.Now().Add(5*time.Second)))
|
||||
var b [1]byte
|
||||
_, err = io.ReadFull(conn, b[:])
|
||||
require.Error(t, err, "server without identity key must not write the RFB greeting")
|
||||
}
|
||||
|
||||
// TestNoise_DerivedIdentityPublicMatchesPrivate sanity-checks the
|
||||
// derivation done in New(): the identityPublic must be Curve25519.
|
||||
// Basepoint multiplied with identityKey.
|
||||
func TestNoise_DerivedIdentityPublicMatchesPrivate(t *testing.T) {
|
||||
priv := make([]byte, 32)
|
||||
for i := range priv {
|
||||
priv[i] = byte(i + 1)
|
||||
}
|
||||
srv := New(&testCapturer{}, &StubInputInjector{}, priv)
|
||||
|
||||
expected, err := curve25519.X25519(priv, curve25519.Basepoint)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expected, srv.identityPublic)
|
||||
}
|
||||
|
||||
// TestNoise_SessionMode_OSUserCheckRunsAfterHandshake verifies that a
|
||||
// successful Noise handshake doesn't bypass OS-user authorization: an
|
||||
// authenticated key whose user index isn't mapped to the requested OS
|
||||
// user must be rejected.
|
||||
func TestNoise_SessionMode_OSUserCheckRunsAfterHandshake(t *testing.T) {
|
||||
addr, srv, serverPub := noiseTestServer(t)
|
||||
|
||||
clientKey, err := noise.DH25519.GenerateKeypair(nil)
|
||||
require.NoError(t, err)
|
||||
userHash, err := sshuserhash.HashUserID("alice@example")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Map Alice only to "alice" OS user, not the wildcard.
|
||||
srv.UpdateVNCAuth(&sshauth.Config{
|
||||
AuthorizedUsers: []sshuserhash.UserIDHash{userHash},
|
||||
MachineUsers: map[string][]uint32{"alice": {0}},
|
||||
SessionPubKeys: []sshauth.SessionPubKey{
|
||||
{PubKey: clientKey.Public, UserIDHash: userHash},
|
||||
},
|
||||
})
|
||||
|
||||
// Request session for "bob" — Noise succeeds, OS-user check denies.
|
||||
conn, err := net.Dial("tcp", addr.String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
bob := []byte("bob")
|
||||
prefix := []byte{ModeSession, 0, byte(len(bob))}
|
||||
prefix = append(prefix, bob...)
|
||||
_, err = conn.Write(prefix)
|
||||
require.NoError(t, err)
|
||||
|
||||
performInitiator(t, conn, clientKey, serverPub)
|
||||
writeHeaderTail(t, conn)
|
||||
|
||||
reason := readRFBFailure(t, conn)
|
||||
assert.Contains(t, reason, RejectCodeAuthForbidden)
|
||||
assert.Contains(t, reason, "authorize OS user")
|
||||
}
|
||||
59
client/vnc/server/pseudo_encodings_test.go
Normal file
59
client/vnc/server/pseudo_encodings_test.go
Normal file
@@ -0,0 +1,59 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestEncodeDesktopSizeBody(t *testing.T) {
|
||||
got := encodeDesktopSizeBody(1920, 1080)
|
||||
if len(got) != 12 {
|
||||
t.Fatalf("DesktopSize body length: want 12, got %d", len(got))
|
||||
}
|
||||
if got[0] != 0 || got[1] != 0 || got[2] != 0 || got[3] != 0 {
|
||||
t.Fatalf("DesktopSize: x and y must be zero; got % x", got[0:4])
|
||||
}
|
||||
if got[4] != 0x07 || got[5] != 0x80 {
|
||||
t.Fatalf("DesktopSize: width should be 1920 (0x0780); got % x", got[4:6])
|
||||
}
|
||||
if got[6] != 0x04 || got[7] != 0x38 {
|
||||
t.Fatalf("DesktopSize: height should be 1080 (0x0438); got % x", got[6:8])
|
||||
}
|
||||
// Encoding = -223 → 0xFFFFFF21 in two's complement big-endian.
|
||||
if got[8] != 0xFF || got[9] != 0xFF || got[10] != 0xFF || got[11] != 0x21 {
|
||||
t.Fatalf("DesktopSize: encoding bytes wrong: % x", got[8:12])
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeDesktopNameBody(t *testing.T) {
|
||||
name := "vma@debian3"
|
||||
got := encodeDesktopNameBody(name)
|
||||
if len(got) != 12+4+len(name) {
|
||||
t.Fatalf("DesktopName body length: want %d, got %d", 12+4+len(name), len(got))
|
||||
}
|
||||
// Encoding = -307 → 0xFFFFFECD.
|
||||
if got[8] != 0xFF || got[9] != 0xFF || got[10] != 0xFE || got[11] != 0xCD {
|
||||
t.Fatalf("DesktopName: encoding bytes wrong: % x", got[8:12])
|
||||
}
|
||||
if got[12] != 0 || got[13] != 0 || got[14] != 0 || got[15] != byte(len(name)) {
|
||||
t.Fatalf("DesktopName: name length prefix wrong: % x", got[12:16])
|
||||
}
|
||||
if string(got[16:]) != name {
|
||||
t.Fatalf("DesktopName: name body wrong: %q", got[16:])
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeLastRectBody(t *testing.T) {
|
||||
got := encodeLastRectBody()
|
||||
if len(got) != 12 {
|
||||
t.Fatalf("LastRect body length: want 12, got %d", len(got))
|
||||
}
|
||||
for i := 0; i < 8; i++ {
|
||||
if got[i] != 0 {
|
||||
t.Fatalf("LastRect: header bytes 0..7 must be zero; got byte %d = 0x%02x", i, got[i])
|
||||
}
|
||||
}
|
||||
// Encoding = -224 → 0xFFFFFF20.
|
||||
if got[8] != 0xFF || got[9] != 0xFF || got[10] != 0xFF || got[11] != 0x20 {
|
||||
t.Fatalf("LastRect: encoding bytes wrong: % x", got[8:12])
|
||||
}
|
||||
}
|
||||
806
client/vnc/server/rfb.go
Normal file
806
client/vnc/server/rfb.go
Normal file
@@ -0,0 +1,806 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/zlib"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"image"
|
||||
"image/jpeg"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// rect describes a rectangle on the framebuffer in pixels.
|
||||
type rect struct {
|
||||
x, y, w, h int
|
||||
}
|
||||
|
||||
const (
|
||||
rfbProtocolVersion = "RFB 003.008\n"
|
||||
|
||||
secNone = 1
|
||||
|
||||
// Client message types.
|
||||
clientSetPixelFormat = 0
|
||||
clientSetEncodings = 2
|
||||
clientFramebufferUpdateRequest = 3
|
||||
clientKeyEvent = 4
|
||||
clientPointerEvent = 5
|
||||
clientCutText = 6
|
||||
// clientQEMUMessage is the QEMU vendor message wrapper. The subtype
|
||||
// byte that follows selects the actual operation; we only handle the
|
||||
// Extended Key Event (subtype 0) which carries a hardware scancode in
|
||||
// addition to the X11 keysym. Layout-independent key entry.
|
||||
clientQEMUMessage = 255
|
||||
|
||||
// QEMU Extended Key Event subtype carried inside clientQEMUMessage.
|
||||
qemuSubtypeExtendedKeyEvent = 0
|
||||
|
||||
// clientNetbirdTypeText is a NetBird-specific message that asks the
|
||||
// server to synthesize the given text as keystrokes regardless of the
|
||||
// active desktop. Lets a client push host clipboard content into a
|
||||
// Windows secure desktop (Winlogon, UAC), where the OS clipboard is
|
||||
// isolated. Format mirrors clientCutText: 1-byte message type + 3-byte
|
||||
// padding + 4-byte length + text bytes. The opcode is in the
|
||||
// vendor-specific range (>=128).
|
||||
clientNetbirdTypeText = 250
|
||||
|
||||
// clientNetbirdShowRemoteCursor toggles "show remote cursor" mode.
|
||||
// When enabled the encoder composites the server cursor sprite into
|
||||
// the captured framebuffer and suppresses the Cursor pseudo-encoding
|
||||
// so the client sees a single pointer at the remote position.
|
||||
// Wire format: 1-byte msgType + 1-byte enable flag + 6 padding bytes
|
||||
// reserved for future arguments (so the message is fixed-size).
|
||||
clientNetbirdShowRemoteCursor = 251
|
||||
|
||||
// Server message types.
|
||||
serverFramebufferUpdate = 0
|
||||
serverCutText = 3
|
||||
|
||||
// Encoding types.
|
||||
encRaw = 0
|
||||
encCopyRect = 1
|
||||
encHextile = 5
|
||||
encZlib = 6
|
||||
encTight = 7
|
||||
|
||||
// Pseudo-encodings carried over wire as rects with a negative
|
||||
// encoding value. The client advertises supported optional protocol
|
||||
// extensions by listing these in SetEncodings.
|
||||
pseudoEncCursor = -239
|
||||
pseudoEncDesktopSize = -223
|
||||
pseudoEncLastRect = -224
|
||||
pseudoEncQEMUExtendedKeyEvent = -258
|
||||
pseudoEncDesktopName = -307
|
||||
pseudoEncExtendedDesktopSize = -308
|
||||
pseudoEncExtendedMouseButtons = -316
|
||||
|
||||
// Quality/Compression level pseudo-encodings. The client picks one
|
||||
// value from each range to tune JPEG quality and zlib effort. 0 is
|
||||
// lowest quality / fastest, 9 is highest quality / best compression.
|
||||
pseudoEncQualityLevelMin = -32
|
||||
pseudoEncQualityLevelMax = -23
|
||||
pseudoEncCompressLevelMin = -256
|
||||
pseudoEncCompressLevelMax = -247
|
||||
|
||||
// Hextile sub-encoding bits used by the SolidFill fast path.
|
||||
hextileBackgroundSpecified = 0x02
|
||||
hextileSubSize = 16
|
||||
|
||||
// Tight compression-control byte top nibble. Stream-reset bits 0-3
|
||||
// (one per zlib stream) are unused while we run a single stream.
|
||||
tightFillSubenc = 0x80
|
||||
tightJPEGSubenc = 0x90
|
||||
tightBasicFilter = 0x40 // Bit 6 set = explicit filter byte follows.
|
||||
tightFilterCopy = 0x00 // No-op filter, raw pixel stream.
|
||||
|
||||
// JPEG quality used by the Tight encoder. 70 is a reasonable speed/
|
||||
// quality knee; bandwidth roughly halves vs raw RGB while staying
|
||||
// visually clean for typical desktop content. Large rects (e.g. a
|
||||
// fullscreen video region) drop to a lower quality so the encoder
|
||||
// keeps up at 30+ fps; the visual hit is small for moving content.
|
||||
tightJPEGQuality = 70
|
||||
tightJPEGQualityMedium = 55
|
||||
tightJPEGQualityLarge = 40
|
||||
tightJPEGMediumPixels = 800 * 600 // ≈ SVGA, applies medium tier
|
||||
tightJPEGLargePixels = 1280 * 720 // ≈ 720p, applies large tier
|
||||
// Minimum rect area before we consider JPEG. Below this, header
|
||||
// overhead dominates and Basic+zlib wins.
|
||||
tightJPEGMinArea = 4096 // 64×64 ≈ 1 tile
|
||||
// Distinct-colour cap below which we still prefer Basic+zlib (text,
|
||||
// UI). Sampled, not exhaustive: cheap to compute, good enough.
|
||||
tightJPEGMinColors = 64
|
||||
)
|
||||
|
||||
// serverPixelFormat is the pixel format the server advertises and requires:
|
||||
// 32bpp RGBA, little-endian, true-colour, 8 bits per channel at standard
|
||||
// shifts (R=16, G=8, B=0). handleSetPixelFormat rejects any client that
|
||||
// negotiates a different format. Browser-side decoders are little-endian
|
||||
// natively, so advertising little-endian skips a byte-swap on every pixel.
|
||||
var serverPixelFormat = [16]byte{
|
||||
32, // bits-per-pixel
|
||||
24, // depth
|
||||
0, // big-endian-flag
|
||||
1, // true-colour-flag
|
||||
0, 255, // red-max
|
||||
0, 255, // green-max
|
||||
0, 255, // blue-max
|
||||
16, // red-shift
|
||||
8, // green-shift
|
||||
0, // blue-shift
|
||||
0, 0, 0, // padding
|
||||
}
|
||||
|
||||
// clientPixelFormat holds the negotiated pixel format. Only RGB channel
|
||||
// shifts are tracked: every other field is constrained by the server to
|
||||
// the values in serverPixelFormat (32bpp / little-endian / truecolour /
|
||||
// 8-bit channels) and rejected at SetPixelFormat time if the client tries
|
||||
// to negotiate otherwise.
|
||||
type clientPixelFormat struct {
|
||||
rShift uint8
|
||||
gShift uint8
|
||||
bShift uint8
|
||||
}
|
||||
|
||||
func defaultClientPixelFormat() clientPixelFormat {
|
||||
return clientPixelFormat{
|
||||
rShift: serverPixelFormat[10],
|
||||
gShift: serverPixelFormat[11],
|
||||
bShift: serverPixelFormat[12],
|
||||
}
|
||||
}
|
||||
|
||||
// parsePixelFormat returns the negotiated client pixel format, or an error
|
||||
// if the client tried to negotiate an unsupported format. The server only
|
||||
// supports 32bpp truecolour little-endian with 8-bit channels; arbitrary
|
||||
// shifts within that constraint are allowed because they are cheap to honour.
|
||||
func parsePixelFormat(pf []byte) (clientPixelFormat, error) {
|
||||
bpp := pf[0]
|
||||
bigEndian := pf[2]
|
||||
trueColour := pf[3]
|
||||
rMax := binary.BigEndian.Uint16(pf[4:6])
|
||||
gMax := binary.BigEndian.Uint16(pf[6:8])
|
||||
bMax := binary.BigEndian.Uint16(pf[8:10])
|
||||
if bpp != 32 || bigEndian != 0 || trueColour != 1 ||
|
||||
rMax != 255 || gMax != 255 || bMax != 255 {
|
||||
return clientPixelFormat{}, fmt.Errorf(
|
||||
"unsupported pixel format (bpp=%d be=%d tc=%d rgb-max=%d/%d/%d): "+
|
||||
"server only supports 32bpp truecolour little-endian 8-bit channels",
|
||||
bpp, bigEndian, trueColour, rMax, gMax, bMax)
|
||||
}
|
||||
return clientPixelFormat{
|
||||
rShift: pf[10],
|
||||
gShift: pf[11],
|
||||
bShift: pf[12],
|
||||
}, nil
|
||||
}
|
||||
|
||||
// encodeCopyRectBody emits the per-rect payload for a CopyRect rectangle:
|
||||
// the 12-byte rect header (dst position + size + encoding=1) plus a 4-byte
|
||||
// source position. Used inside multi-rect FramebufferUpdate messages, so
|
||||
// the 4-byte FU header is the caller's responsibility.
|
||||
func encodeCopyRectBody(srcX, srcY, dstX, dstY, w, h int) []byte {
|
||||
buf := make([]byte, 12+4)
|
||||
binary.BigEndian.PutUint16(buf[0:2], uint16(dstX))
|
||||
binary.BigEndian.PutUint16(buf[2:4], uint16(dstY))
|
||||
binary.BigEndian.PutUint16(buf[4:6], uint16(w))
|
||||
binary.BigEndian.PutUint16(buf[6:8], uint16(h))
|
||||
binary.BigEndian.PutUint32(buf[8:12], uint32(encCopyRect))
|
||||
binary.BigEndian.PutUint16(buf[12:14], uint16(srcX))
|
||||
binary.BigEndian.PutUint16(buf[14:16], uint16(srcY))
|
||||
return buf
|
||||
}
|
||||
|
||||
// encodeDesktopSizeBody emits a DesktopSize pseudo-encoded rectangle. The
|
||||
// "rect" carries no pixel data: x and y are zero, w and h are the new
|
||||
// framebuffer dimensions, and encoding=-223 signals to the client that the
|
||||
// framebuffer was resized. Clients reallocate their backing buffer and
|
||||
// expect a full update at the new size to follow.
|
||||
func encodeDesktopSizeBody(w, h int) []byte {
|
||||
buf := make([]byte, 12)
|
||||
binary.BigEndian.PutUint16(buf[0:2], 0)
|
||||
binary.BigEndian.PutUint16(buf[2:4], 0)
|
||||
binary.BigEndian.PutUint16(buf[4:6], uint16(w))
|
||||
binary.BigEndian.PutUint16(buf[6:8], uint16(h))
|
||||
enc := int32(pseudoEncDesktopSize)
|
||||
binary.BigEndian.PutUint32(buf[8:12], uint32(enc))
|
||||
return buf
|
||||
}
|
||||
|
||||
// encodeDesktopNameBody emits a DesktopName pseudo-encoded rectangle. The
|
||||
// rect header is all zeros and encoding=-307; the body is a 4-byte
|
||||
// big-endian length followed by the UTF-8 name. Clients update their
|
||||
// window title or label without reconnecting.
|
||||
func encodeDesktopNameBody(name string) []byte {
|
||||
nameBytes := []byte(name)
|
||||
buf := make([]byte, 12+4+len(nameBytes))
|
||||
binary.BigEndian.PutUint16(buf[0:2], 0)
|
||||
binary.BigEndian.PutUint16(buf[2:4], 0)
|
||||
binary.BigEndian.PutUint16(buf[4:6], 0)
|
||||
binary.BigEndian.PutUint16(buf[6:8], 0)
|
||||
enc := int32(pseudoEncDesktopName)
|
||||
binary.BigEndian.PutUint32(buf[8:12], uint32(enc))
|
||||
binary.BigEndian.PutUint32(buf[12:16], uint32(len(nameBytes)))
|
||||
copy(buf[16:], nameBytes)
|
||||
return buf
|
||||
}
|
||||
|
||||
// encodeLastRectBody emits a LastRect sentinel. When the server sets
|
||||
// numRects=0xFFFF in the FramebufferUpdate header, the client reads rects
|
||||
// until it sees one with this encoding. Lets us stream rects from a
|
||||
// goroutine without committing to a count up front.
|
||||
func encodeLastRectBody() []byte {
|
||||
buf := make([]byte, 12)
|
||||
// x, y, w, h all zero; encoding = -224.
|
||||
enc := int32(pseudoEncLastRect)
|
||||
binary.BigEndian.PutUint32(buf[8:12], uint32(enc))
|
||||
return buf
|
||||
}
|
||||
|
||||
// encodeRawRect encodes a framebuffer region as a raw RFB rectangle.
|
||||
// The returned buffer includes the FramebufferUpdate header (1 rectangle).
|
||||
func encodeRawRect(img *image.RGBA, pf clientPixelFormat, x, y, w, h int) []byte {
|
||||
buf := make([]byte, 4+12+w*h*4)
|
||||
|
||||
// FramebufferUpdate header.
|
||||
buf[0] = serverFramebufferUpdate
|
||||
buf[1] = 0 // padding
|
||||
binary.BigEndian.PutUint16(buf[2:4], 1)
|
||||
|
||||
// Rectangle header.
|
||||
binary.BigEndian.PutUint16(buf[4:6], uint16(x))
|
||||
binary.BigEndian.PutUint16(buf[6:8], uint16(y))
|
||||
binary.BigEndian.PutUint16(buf[8:10], uint16(w))
|
||||
binary.BigEndian.PutUint16(buf[10:12], uint16(h))
|
||||
binary.BigEndian.PutUint32(buf[12:16], uint32(encRaw))
|
||||
|
||||
writePixels(buf[16:], img, pf, rect{x, y, w, h})
|
||||
return buf
|
||||
}
|
||||
|
||||
// encodeZlibRect encodes a framebuffer region using the standalone Zlib
|
||||
// encoding. The zlib stream is continuous for the entire VNC session: the
|
||||
// client keeps a single inflate context and reuses it across rects. The
|
||||
// returned buffer includes the 4-byte FramebufferUpdate header.
|
||||
func encodeZlibRect(img *image.RGBA, pf clientPixelFormat, x, y, w, h int, z *zlibState) []byte {
|
||||
zw, zbuf := z.w, z.buf
|
||||
zbuf.Reset()
|
||||
|
||||
rowBytes := w * 4
|
||||
total := rowBytes * h
|
||||
if cap(z.scratch) < total {
|
||||
z.scratch = make([]byte, total)
|
||||
}
|
||||
scratch := z.scratch[:total]
|
||||
writePixels(scratch, img, pf, rect{x, y, w, h})
|
||||
for row := 0; row < h; row++ {
|
||||
if _, err := zw.Write(scratch[row*rowBytes : (row+1)*rowBytes]); err != nil {
|
||||
log.Debugf("zlib write row %d: %v", row, err)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if err := zw.Flush(); err != nil {
|
||||
log.Debugf("zlib flush: %v", err)
|
||||
return nil
|
||||
}
|
||||
compressed := zbuf.Bytes()
|
||||
|
||||
buf := make([]byte, 4+12+4+len(compressed))
|
||||
buf[0] = serverFramebufferUpdate
|
||||
binary.BigEndian.PutUint16(buf[2:4], 1)
|
||||
binary.BigEndian.PutUint16(buf[4:6], uint16(x))
|
||||
binary.BigEndian.PutUint16(buf[6:8], uint16(y))
|
||||
binary.BigEndian.PutUint16(buf[8:10], uint16(w))
|
||||
binary.BigEndian.PutUint16(buf[10:12], uint16(h))
|
||||
binary.BigEndian.PutUint32(buf[12:16], uint32(encZlib))
|
||||
binary.BigEndian.PutUint32(buf[16:20], uint32(len(compressed)))
|
||||
copy(buf[20:], compressed)
|
||||
return buf
|
||||
}
|
||||
|
||||
// encodeHextileSolidRect emits a Hextile-encoded rectangle whose every
|
||||
// pixel is the same colour. The first sub-tile carries the background
|
||||
// pixel; remaining sub-tiles inherit it via a zero sub-encoding byte,
|
||||
// collapsing a uniform 64×64 tile down to ~20 bytes. The returned buffer
|
||||
// starts with the 12-byte rect header; callers prepend a FramebufferUpdate
|
||||
// header.
|
||||
func encodeHextileSolidRect(r, g, b byte, pf clientPixelFormat, rc rect) []byte {
|
||||
cols := (rc.w + hextileSubSize - 1) / hextileSubSize
|
||||
rows := (rc.h + hextileSubSize - 1) / hextileSubSize
|
||||
subs := cols * rows
|
||||
// One sub-encoding byte plus a 32bpp pixel for the first sub-tile, then
|
||||
// one zero byte per remaining sub-tile to inherit the background.
|
||||
bodySize := 1 + 4 + (subs - 1)
|
||||
buf := make([]byte, 12+bodySize)
|
||||
|
||||
binary.BigEndian.PutUint16(buf[0:2], uint16(rc.x))
|
||||
binary.BigEndian.PutUint16(buf[2:4], uint16(rc.y))
|
||||
binary.BigEndian.PutUint16(buf[4:6], uint16(rc.w))
|
||||
binary.BigEndian.PutUint16(buf[6:8], uint16(rc.h))
|
||||
binary.BigEndian.PutUint32(buf[8:12], uint32(encHextile))
|
||||
|
||||
buf[12] = hextileBackgroundSpecified
|
||||
pixel := (uint32(r) << pf.rShift) | (uint32(g) << pf.gShift) | (uint32(b) << pf.bShift)
|
||||
binary.LittleEndian.PutUint32(buf[13:17], pixel)
|
||||
return buf
|
||||
}
|
||||
|
||||
// writePixels writes a rectangle of img into dst as 32bpp little-endian
|
||||
// pixels at the negotiated RGB shifts. The pixel format is constrained at
|
||||
// SetPixelFormat time so we can assume 4 bytes per pixel, 8-bit channels,
|
||||
// and little-endian byte order; arbitrary shifts (R/G/B order) are honoured.
|
||||
func writePixels(dst []byte, img *image.RGBA, pf clientPixelFormat, r rect) {
|
||||
stride := img.Stride
|
||||
rShift, gShift, bShift := pf.rShift, pf.gShift, pf.bShift
|
||||
off := 0
|
||||
for row := r.y; row < r.y+r.h; row++ {
|
||||
p := row*stride + r.x*4
|
||||
for col := 0; col < r.w; col++ {
|
||||
pixel := (uint32(img.Pix[p]) << rShift) |
|
||||
(uint32(img.Pix[p+1]) << gShift) |
|
||||
(uint32(img.Pix[p+2]) << bShift)
|
||||
binary.LittleEndian.PutUint32(dst[off:off+4], pixel)
|
||||
p += 4
|
||||
off += 4
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// diffTiles compares two RGBA images and returns a tile-ordered list of
|
||||
// dirty tiles, one entry per tile. Tile order is top-to-bottom, left-to-
|
||||
// right within each row. The caller decides whether to coalesce or hand
|
||||
// the list off to the CopyRect detector first.
|
||||
func diffTiles(prev, cur *image.RGBA, w, h, tileSize int) [][4]int {
|
||||
if prev == nil {
|
||||
return [][4]int{{0, 0, w, h}}
|
||||
}
|
||||
var rects [][4]int
|
||||
for ty := 0; ty < h; ty += tileSize {
|
||||
th := min(tileSize, h-ty)
|
||||
for tx := 0; tx < w; tx += tileSize {
|
||||
tw := min(tileSize, w-tx)
|
||||
if tileChanged(prev, cur, tx, ty, tw, th) {
|
||||
rects = append(rects, [4]int{tx, ty, tw, th})
|
||||
}
|
||||
}
|
||||
}
|
||||
return rects
|
||||
}
|
||||
|
||||
// diffRects is the legacy convenience: diff then coalesce. Used by paths
|
||||
// that don't go through the CopyRect detector and by tests that exercise
|
||||
// the diff-plus-coalesce pipeline as one unit.
|
||||
func diffRects(prev, cur *image.RGBA, w, h, tileSize int) [][4]int {
|
||||
return coalesceRects(diffTiles(prev, cur, w, h, tileSize))
|
||||
}
|
||||
|
||||
// coalesceRects merges adjacent dirty tiles into larger rectangles to cut
|
||||
// per-rect framing overhead. Input must be tile-ordered (top-to-bottom rows,
|
||||
// left-to-right within each row), as produced by diffRects. Two passes:
|
||||
// 1. Horizontal: within a row, merge tiles whose x-extents touch.
|
||||
// 2. Vertical: merge a row's run with the run directly above it when they
|
||||
// share the same [x, x+w] extent and are vertically adjacent.
|
||||
//
|
||||
// Larger merged rects still encode correctly: Hextile-solid and Zlib paths
|
||||
// both work on arbitrary sizes, and uniform-tile detection still fires when
|
||||
// the merged region happens to be a single colour.
|
||||
func coalesceRects(in [][4]int) [][4]int {
|
||||
if len(in) < 2 {
|
||||
return in
|
||||
}
|
||||
c := newRectCoalescer(len(in))
|
||||
c.curY = in[0][1]
|
||||
for _, r := range in {
|
||||
c.consume(r)
|
||||
}
|
||||
c.flushCurrentRow()
|
||||
return c.out
|
||||
}
|
||||
|
||||
// rectCoalescer is the working state for coalesceRects, lifted out so the
|
||||
// algorithm can be split across small methods without long parameter lists
|
||||
// and to keep each method's cognitive complexity below Sonar's threshold.
|
||||
type rectCoalescer struct {
|
||||
out [][4]int
|
||||
prevRowStart, prevRowEnd int
|
||||
curRowStart int
|
||||
curY int
|
||||
}
|
||||
|
||||
func newRectCoalescer(capacity int) *rectCoalescer {
|
||||
return &rectCoalescer{out: make([][4]int, 0, capacity)}
|
||||
}
|
||||
|
||||
// consume processes one rect from the (row-ordered) input.
|
||||
func (c *rectCoalescer) consume(r [4]int) {
|
||||
if r[1] != c.curY {
|
||||
c.flushCurrentRow()
|
||||
c.prevRowEnd = len(c.out)
|
||||
c.curRowStart = len(c.out)
|
||||
c.curY = r[1]
|
||||
}
|
||||
if c.tryHorizontalMerge(r) {
|
||||
return
|
||||
}
|
||||
c.out = append(c.out, r)
|
||||
}
|
||||
|
||||
// tryHorizontalMerge extends the last run in the current row when r is
|
||||
// vertically aligned and horizontally adjacent to it.
|
||||
func (c *rectCoalescer) tryHorizontalMerge(r [4]int) bool {
|
||||
if len(c.out) <= c.curRowStart {
|
||||
return false
|
||||
}
|
||||
last := &c.out[len(c.out)-1]
|
||||
if last[1] == r[1] && last[3] == r[3] && last[0]+last[2] == r[0] {
|
||||
last[2] += r[2]
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// flushCurrentRow merges each run in the current row with any run from the
|
||||
// previous row that has identical x extent and is vertically adjacent.
|
||||
func (c *rectCoalescer) flushCurrentRow() {
|
||||
i := c.curRowStart
|
||||
for i < len(c.out) {
|
||||
if c.mergeWithPrevRow(i) {
|
||||
continue
|
||||
}
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
// mergeWithPrevRow tries to extend a previous-row run downward to absorb
|
||||
// out[i]. Returns true and removes out[i] from the slice on success.
|
||||
func (c *rectCoalescer) mergeWithPrevRow(i int) bool {
|
||||
for j := c.prevRowStart; j < c.prevRowEnd; j++ {
|
||||
if c.out[j][0] == c.out[i][0] &&
|
||||
c.out[j][2] == c.out[i][2] &&
|
||||
c.out[j][1]+c.out[j][3] == c.out[i][1] {
|
||||
c.out[j][3] += c.out[i][3]
|
||||
copy(c.out[i:], c.out[i+1:])
|
||||
c.out = c.out[:len(c.out)-1]
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func tileChanged(prev, cur *image.RGBA, x, y, w, h int) bool {
|
||||
stride := prev.Stride
|
||||
for row := y; row < y+h; row++ {
|
||||
off := row*stride + x*4
|
||||
end := off + w*4
|
||||
prevRow := prev.Pix[off:end]
|
||||
curRow := cur.Pix[off:end]
|
||||
if !bytes.Equal(prevRow, curRow) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// tileIsUniform reports whether every pixel in the given rectangle of img is
|
||||
// the same RGBA value, and returns that pixel packed as 0xRRGGBBAA when so.
|
||||
// Uses uint32 comparisons across rows; returns early on the first mismatch.
|
||||
func tileIsUniform(img *image.RGBA, x, y, w, h int) (uint32, bool) {
|
||||
if w <= 0 || h <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
stride := img.Stride
|
||||
base := y*stride + x*4
|
||||
first := *(*uint32)(unsafe.Pointer(&img.Pix[base]))
|
||||
rowBytes := w * 4
|
||||
for row := 0; row < h; row++ {
|
||||
p := base + row*stride
|
||||
for col := 0; col < rowBytes; col += 4 {
|
||||
if *(*uint32)(unsafe.Pointer(&img.Pix[p+col])) != first {
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
}
|
||||
return first, true
|
||||
}
|
||||
|
||||
// tightState holds the per-session JPEG scratch buffer and reused encoders
|
||||
// so per-rect encoding stays alloc-free in the steady state.
|
||||
type tightState struct {
|
||||
jpegBuf *bytes.Buffer
|
||||
zlib *zlibState
|
||||
scratch []byte // RGB-packed pixel scratch for JPEG and Basic paths.
|
||||
// colorSeen is reused by sampledColorCount per rect; cleared via the Go
|
||||
// runtime's map-clear fast path to avoid a fresh allocation each call.
|
||||
colorSeen map[uint32]struct{}
|
||||
// jpegQualityOverride forces a fixed JPEG quality on every rect when
|
||||
// non-zero (set from the client's QualityLevel pseudo-encoding). Zero
|
||||
// falls back to the area-based tiers in tightQualityFor.
|
||||
jpegQualityOverride int
|
||||
// qualityLevel and compressLevel are the 0..9 levels currently applied,
|
||||
// or -1 if the client did not express a preference. Used to decide
|
||||
// whether a SetEncodings refresh needs to recreate the tight state.
|
||||
qualityLevel int
|
||||
compressLevel int
|
||||
// pendingZlibReset becomes true when this tightState replaces an
|
||||
// in-use one (e.g. CompressLevel change mid-session). The next Basic
|
||||
// rect we emit ORs the stream-0 reset bit into its sub-encoding byte
|
||||
// so the client's inflater drops its now-stale dictionary; cleared
|
||||
// after one emission.
|
||||
pendingZlibReset bool
|
||||
}
|
||||
|
||||
func newTightState() *tightState {
|
||||
return newTightStateWithLevels(-1, -1)
|
||||
}
|
||||
|
||||
// newTightStateWithLevels builds a tightState whose zlib stream and JPEG
|
||||
// quality reflect the client's QualityLevel / CompressLevel pseudo-encodings.
|
||||
// Pass -1 for either level to keep our defaults (BestSpeed zlib and the
|
||||
// area-tiered JPEG quality in tightQualityFor).
|
||||
func newTightStateWithLevels(qualityLevel, compressLevel int) *tightState {
|
||||
return &tightState{
|
||||
jpegBuf: &bytes.Buffer{},
|
||||
zlib: newZlibStateLevel(zlibLevelFor(compressLevel)),
|
||||
colorSeen: make(map[uint32]struct{}, 64),
|
||||
jpegQualityOverride: jpegQualityForLevel(qualityLevel),
|
||||
qualityLevel: qualityLevel,
|
||||
compressLevel: compressLevel,
|
||||
}
|
||||
}
|
||||
|
||||
// jpegQualityForLevel maps a 0..9 client preference to a JPEG quality value.
|
||||
// Returns 0 when no preference is set (-1), letting the encoder fall back
|
||||
// to the area-based tiers. The encoder lowers this dynamically when the
|
||||
// socket is backpressured, so this routine emits the unclamped, client-
|
||||
// requested value.
|
||||
func jpegQualityForLevel(level int) int {
|
||||
if level < 0 {
|
||||
return 0
|
||||
}
|
||||
if level > 9 {
|
||||
level = 9
|
||||
}
|
||||
return 30 + level*7
|
||||
}
|
||||
|
||||
// zlibLevelFor maps a 0..9 client preference to a zlib compression level.
|
||||
// Level 0 ("no compression") would emit larger output than input on most
|
||||
// rects, so we floor to BestSpeed (1). -1 (no preference) also picks
|
||||
// BestSpeed: matches the historical default before the pseudo-encoding
|
||||
// was honoured.
|
||||
func zlibLevelFor(level int) int {
|
||||
if level < 1 {
|
||||
return zlib.BestSpeed
|
||||
}
|
||||
if level > zlib.BestCompression {
|
||||
return zlib.BestCompression
|
||||
}
|
||||
return level
|
||||
}
|
||||
|
||||
// tightMaxLength is the maximum payload size representable in the Tight
|
||||
// compact length prefix (RFB §7.7.6: 22 bits, three 7+7+8 bit groups).
|
||||
// Exceeding this would silently truncate the high byte; callers must fall
|
||||
// back to a different encoding when an attempt would overflow.
|
||||
const tightMaxLength = (1 << 22) - 1
|
||||
|
||||
// encodeTightRect emits a single Tight-encoded rect. Picks Fill for uniform
|
||||
// content, JPEG for photo-like rects above a size and color-count threshold,
|
||||
// and Basic+zlib otherwise. When Tight's 22-bit length cap would be exceeded
|
||||
// (huge full-frame rects under bad compression), falls back to Raw. Returns
|
||||
// the rect header + body (no FramebufferUpdate header).
|
||||
func encodeTightRect(img *image.RGBA, pf clientPixelFormat, x, y, w, h int, t *tightState) []byte {
|
||||
if pixel, uniform := tileIsUniform(img, x, y, w, h); uniform {
|
||||
return encodeTightFill(x, y, w, h, byte(pixel), byte(pixel>>8), byte(pixel>>16))
|
||||
}
|
||||
if w*h >= tightJPEGMinArea && sampledColorCountInto(t.colorSeen, img, x, y, w, h, tightJPEGMinColors) >= tightJPEGMinColors {
|
||||
if buf, ok := encodeTightJPEG(img, x, y, w, h, t); ok {
|
||||
return buf
|
||||
}
|
||||
}
|
||||
if buf, ok := encodeTightBasic(img, x, y, w, h, t); ok {
|
||||
return buf
|
||||
}
|
||||
// Fall back to Raw rect body (skip the 4-byte FU header that encodeRawRect
|
||||
// prepends, since callers compose their own FU header).
|
||||
return encodeRawRect(img, pf, x, y, w, h)[4:]
|
||||
}
|
||||
|
||||
func writeTightRectHeader(buf []byte, x, y, w, h int) {
|
||||
binary.BigEndian.PutUint16(buf[0:2], uint16(x))
|
||||
binary.BigEndian.PutUint16(buf[2:4], uint16(y))
|
||||
binary.BigEndian.PutUint16(buf[4:6], uint16(w))
|
||||
binary.BigEndian.PutUint16(buf[6:8], uint16(h))
|
||||
binary.BigEndian.PutUint32(buf[8:12], uint32(encTight))
|
||||
}
|
||||
|
||||
// appendTightLength encodes a Tight compact length prefix (1, 2, or 3 bytes
|
||||
// LE-ish, top bit of each byte signals continuation). Lengths exceeding
|
||||
// tightMaxLength would silently truncate the high byte; callers must clamp
|
||||
// or fall back before reaching here.
|
||||
func appendTightLength(buf []byte, n int) []byte {
|
||||
if n < 0 || n > tightMaxLength {
|
||||
panic(fmt.Sprintf("tight length out of range: %d", n))
|
||||
}
|
||||
b0 := byte(n & 0x7f)
|
||||
if n <= 0x7f {
|
||||
return append(buf, b0)
|
||||
}
|
||||
b0 |= 0x80
|
||||
b1 := byte((n >> 7) & 0x7f)
|
||||
if n <= 0x3fff {
|
||||
return append(buf, b0, b1)
|
||||
}
|
||||
b1 |= 0x80
|
||||
// High group is 8 bits per spec, but our cap guarantees the top 2 bits
|
||||
// are zero; mask defensively.
|
||||
b2 := byte((n >> 14) & 0xff)
|
||||
return append(buf, b0, b1, b2)
|
||||
}
|
||||
|
||||
// encodeTightFill emits a uniform rect: 12-byte rect header + 1-byte
|
||||
// subenc (0x80) + 3-byte RGB pixel. Tight Fill always uses 24-bit RGB
|
||||
// regardless of the negotiated pixel format.
|
||||
func encodeTightFill(x, y, w, h int, r, g, b byte) []byte {
|
||||
buf := make([]byte, 12+1+3)
|
||||
writeTightRectHeader(buf, x, y, w, h)
|
||||
buf[12] = tightFillSubenc
|
||||
buf[13] = r
|
||||
buf[14] = g
|
||||
buf[15] = b
|
||||
return buf
|
||||
}
|
||||
|
||||
// encodeTightJPEG compresses the rect as a baseline JPEG. Returns ok=false
|
||||
// if the encoder errors so the caller can fall back to Basic.
|
||||
func encodeTightJPEG(img *image.RGBA, x, y, w, h int, t *tightState) ([]byte, bool) {
|
||||
t.jpegBuf.Reset()
|
||||
sub := img.SubImage(image.Rect(img.Rect.Min.X+x, img.Rect.Min.Y+y, img.Rect.Min.X+x+w, img.Rect.Min.Y+y+h))
|
||||
q := t.jpegQualityOverride
|
||||
if q == 0 {
|
||||
q = tightQualityFor(w * h)
|
||||
}
|
||||
if err := jpeg.Encode(t.jpegBuf, sub, &jpeg.Options{Quality: q}); err != nil {
|
||||
return nil, false
|
||||
}
|
||||
jpegBytes := t.jpegBuf.Bytes()
|
||||
if len(jpegBytes) > tightMaxLength {
|
||||
return nil, false
|
||||
}
|
||||
buf := make([]byte, 0, 12+1+3+len(jpegBytes))
|
||||
buf = buf[:12]
|
||||
writeTightRectHeader(buf, x, y, w, h)
|
||||
buf = append(buf, tightJPEGSubenc)
|
||||
buf = appendTightLength(buf, len(jpegBytes))
|
||||
buf = append(buf, jpegBytes...)
|
||||
return buf, true
|
||||
}
|
||||
|
||||
// encodeTightBasic emits Basic+zlib with the no-op (CopyFilter) filter.
|
||||
// Pixels are sent as 24-bit RGB ("TPIXEL" format) which most clients
|
||||
// negotiate when the server advertises 32bpp true colour. Streams under
|
||||
// 12 bytes ship uncompressed per RFB Tight spec. Returns ok=false when the
|
||||
// compressed payload would exceed Tight's 22-bit length cap or when zlib
|
||||
// errors, signalling the caller to fall back to Raw.
|
||||
func encodeTightBasic(img *image.RGBA, x, y, w, h int, t *tightState) ([]byte, bool) {
|
||||
pixelStream := w * h * 3
|
||||
if cap(t.scratch) < pixelStream {
|
||||
t.scratch = make([]byte, pixelStream)
|
||||
}
|
||||
scratch := t.scratch[:pixelStream]
|
||||
stride := img.Stride
|
||||
off := 0
|
||||
for row := y; row < y+h; row++ {
|
||||
p := row*stride + x*4
|
||||
for col := 0; col < w; col++ {
|
||||
scratch[off+0] = img.Pix[p]
|
||||
scratch[off+1] = img.Pix[p+1]
|
||||
scratch[off+2] = img.Pix[p+2]
|
||||
p += 4
|
||||
off += 3
|
||||
}
|
||||
}
|
||||
|
||||
// Sub-encoding byte: stream 0, basic encoding (top nibble = 0x40 =
|
||||
// explicit filter follows). The low nibble carries per-stream reset
|
||||
// flags; bit 0 here tells the client to reset its stream-0 inflater
|
||||
// when our deflater was just recreated.
|
||||
subenc := byte(tightBasicFilter)
|
||||
if t.pendingZlibReset {
|
||||
subenc |= 0x01
|
||||
t.pendingZlibReset = false
|
||||
}
|
||||
filter := byte(tightFilterCopy)
|
||||
|
||||
if pixelStream < 12 {
|
||||
buf := make([]byte, 0, 12+2+pixelStream)
|
||||
buf = buf[:12]
|
||||
writeTightRectHeader(buf, x, y, w, h)
|
||||
buf = append(buf, subenc, filter)
|
||||
buf = append(buf, scratch...)
|
||||
return buf, true
|
||||
}
|
||||
|
||||
z := t.zlib
|
||||
z.buf.Reset()
|
||||
if _, err := z.w.Write(scratch); err != nil {
|
||||
log.Debugf("tight zlib write: %v", err)
|
||||
return nil, false
|
||||
}
|
||||
if err := z.w.Flush(); err != nil {
|
||||
log.Debugf("tight zlib flush: %v", err)
|
||||
return nil, false
|
||||
}
|
||||
compressed := z.buf.Bytes()
|
||||
if len(compressed) > tightMaxLength {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
buf := make([]byte, 0, 12+2+5+len(compressed))
|
||||
buf = buf[:12]
|
||||
writeTightRectHeader(buf, x, y, w, h)
|
||||
buf = append(buf, subenc, filter)
|
||||
buf = appendTightLength(buf, len(compressed))
|
||||
buf = append(buf, compressed...)
|
||||
return buf, true
|
||||
}
|
||||
|
||||
func tightQualityFor(pixels int) int {
|
||||
switch {
|
||||
case pixels >= tightJPEGLargePixels:
|
||||
return tightJPEGQualityLarge
|
||||
case pixels >= tightJPEGMediumPixels:
|
||||
return tightJPEGQualityMedium
|
||||
default:
|
||||
return tightJPEGQuality
|
||||
}
|
||||
}
|
||||
|
||||
// sampledColorCountInto estimates distinct-colour count by checking up to
|
||||
// maxColors samples. The caller-provided `seen` map is cleared and reused so
|
||||
// per-rect Tight encoding stays alloc-free. Cheap O(maxColors) per call.
|
||||
func sampledColorCountInto(seen map[uint32]struct{}, img *image.RGBA, x, y, w, h, maxColors int) int {
|
||||
clear(seen)
|
||||
stride := img.Stride
|
||||
step := max((w*h)/(maxColors*4), 1)
|
||||
var idx int
|
||||
for row := 0; row < h; row++ {
|
||||
p := (y+row)*stride + x*4
|
||||
for col := 0; col < w; col++ {
|
||||
if idx%step == 0 {
|
||||
px := *(*uint32)(unsafe.Pointer(&img.Pix[p+col*4]))
|
||||
seen[px&0x00ffffff] = struct{}{}
|
||||
if len(seen) > maxColors {
|
||||
return len(seen)
|
||||
}
|
||||
}
|
||||
idx++
|
||||
}
|
||||
}
|
||||
return len(seen)
|
||||
}
|
||||
|
||||
// zlibState holds the persistent zlib writer and its output buffer, reused
|
||||
// across rects so steady-state Tight encoding stays alloc-free.
|
||||
type zlibState struct {
|
||||
buf *bytes.Buffer
|
||||
w *zlib.Writer
|
||||
// scratch stages the packed pixel stream for a rect before it is fed
|
||||
// to the deflater. Grown to the largest rect seen in the session and
|
||||
// reused to keep the steady-state encode allocation-free.
|
||||
scratch []byte
|
||||
}
|
||||
|
||||
func newZlibStateLevel(level int) *zlibState {
|
||||
buf := &bytes.Buffer{}
|
||||
w, _ := zlib.NewWriterLevel(buf, level)
|
||||
return &zlibState{buf: buf, w: w}
|
||||
}
|
||||
|
||||
func (z *zlibState) Close() error {
|
||||
return z.w.Close()
|
||||
}
|
||||
364
client/vnc/server/rfb_bench_test.go
Normal file
364
client/vnc/server/rfb_bench_test.go
Normal file
@@ -0,0 +1,364 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"image"
|
||||
"math/rand"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Representative frame sizes.
|
||||
var benchRects = []struct {
|
||||
name string
|
||||
w, h int
|
||||
}{
|
||||
{"1080p_full", 1920, 1080},
|
||||
{"720p_full", 1280, 720},
|
||||
{"256x256_tile", 256, 256},
|
||||
{"64x64_tile", 64, 64},
|
||||
}
|
||||
|
||||
func makeBenchImage(w, h int, seed int64) *image.RGBA {
|
||||
img := image.NewRGBA(image.Rect(0, 0, w, h))
|
||||
r := rand.New(rand.NewSource(seed))
|
||||
_, _ = r.Read(img.Pix)
|
||||
// Force alpha byte so the fast path and slow path produce identical output.
|
||||
for i := 3; i < len(img.Pix); i += 4 {
|
||||
img.Pix[i] = 0xff
|
||||
}
|
||||
return img
|
||||
}
|
||||
|
||||
func makeBenchImagePartial(w, h, changedRows int) (*image.RGBA, *image.RGBA) {
|
||||
prev := makeBenchImage(w, h, 1)
|
||||
cur := image.NewRGBA(prev.Rect)
|
||||
copy(cur.Pix, prev.Pix)
|
||||
if changedRows > h {
|
||||
changedRows = h
|
||||
}
|
||||
// Dirty the first `changedRows` rows.
|
||||
r := rand.New(rand.NewSource(2))
|
||||
_, _ = r.Read(cur.Pix[:changedRows*cur.Stride])
|
||||
for i := 3; i < len(cur.Pix); i += 4 {
|
||||
cur.Pix[i] = 0xff
|
||||
}
|
||||
return prev, cur
|
||||
}
|
||||
|
||||
func BenchmarkEncodeRawRect(b *testing.B) {
|
||||
pf := defaultClientPixelFormat()
|
||||
for _, r := range benchRects {
|
||||
img := makeBenchImage(r.w, r.h, 1)
|
||||
b.Run(r.name, func(b *testing.B) {
|
||||
b.SetBytes(int64(r.w * r.h * 4))
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = encodeRawRect(img, pf, 0, 0, r.w, r.h)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkEncodeTightRect(b *testing.B) {
|
||||
pf := defaultClientPixelFormat()
|
||||
for _, r := range benchRects {
|
||||
img := makeBenchImage(r.w, r.h, 1)
|
||||
t := newTightState()
|
||||
b.Run(r.name, func(b *testing.B) {
|
||||
b.SetBytes(int64(r.w * r.h * 4))
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = encodeTightRect(img, pf, 0, 0, r.w, r.h, t)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkWritePixels isolates the per-pixel pack loop from the allocation
|
||||
// and FramebufferUpdate-header overhead.
|
||||
func BenchmarkWritePixels(b *testing.B) {
|
||||
pf := defaultClientPixelFormat()
|
||||
for _, r := range benchRects {
|
||||
img := makeBenchImage(r.w, r.h, 1)
|
||||
dst := make([]byte, r.w*r.h*4)
|
||||
b.Run(r.name, func(b *testing.B) {
|
||||
b.SetBytes(int64(r.w * r.h * 4))
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
writePixels(dst, img, pf, rect{0, 0, r.w, r.h})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSwizzleBGRAtoRGBA(b *testing.B) {
|
||||
for _, r := range benchRects {
|
||||
size := r.w * r.h * 4
|
||||
src := make([]byte, size)
|
||||
dst := make([]byte, size)
|
||||
rng := rand.New(rand.NewSource(1))
|
||||
_, _ = rng.Read(src)
|
||||
b.Run(r.name, func(b *testing.B) {
|
||||
b.SetBytes(int64(size))
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
swizzleBGRAtoRGBA(dst, src)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkSwizzleBGRAtoRGBANaive is the naive byte-by-byte implementation
|
||||
// that the Linux SHM capturer used before the uint32 rewrite, kept here so
|
||||
// we can compare the cost directly.
|
||||
func BenchmarkSwizzleBGRAtoRGBANaive(b *testing.B) {
|
||||
for _, r := range benchRects {
|
||||
size := r.w * r.h * 4
|
||||
src := make([]byte, size)
|
||||
dst := make([]byte, size)
|
||||
rng := rand.New(rand.NewSource(1))
|
||||
_, _ = rng.Read(src)
|
||||
b.Run(r.name, func(b *testing.B) {
|
||||
b.SetBytes(int64(size))
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
for j := 0; j < size; j += 4 {
|
||||
dst[j+0] = src[j+2]
|
||||
dst[j+1] = src[j+1]
|
||||
dst[j+2] = src[j+0]
|
||||
dst[j+3] = 0xff
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkEncodeUniformTile_TightFill measures the fast path for a uniform
|
||||
// 64×64 tile via Tight's Fill subencoding (16 wire bytes regardless of size).
|
||||
func BenchmarkEncodeUniformTile_TightFill(b *testing.B) {
|
||||
pf := defaultClientPixelFormat()
|
||||
img := image.NewRGBA(image.Rect(0, 0, 64, 64))
|
||||
for i := 0; i < len(img.Pix); i += 4 {
|
||||
img.Pix[i+0] = 0x33
|
||||
img.Pix[i+1] = 0x66
|
||||
img.Pix[i+2] = 0x99
|
||||
img.Pix[i+3] = 0xff
|
||||
}
|
||||
t := newTightState()
|
||||
b.ReportAllocs()
|
||||
var bytesOut int
|
||||
for i := 0; i < b.N; i++ {
|
||||
out := encodeTightRect(img, pf, 0, 0, 64, 64, t)
|
||||
bytesOut = len(out)
|
||||
}
|
||||
b.ReportMetric(float64(bytesOut), "wire_bytes")
|
||||
}
|
||||
|
||||
func BenchmarkTileIsUniform(b *testing.B) {
|
||||
img := image.NewRGBA(image.Rect(0, 0, 64, 64))
|
||||
for i := 0; i < len(img.Pix); i += 4 {
|
||||
img.Pix[i+3] = 0xff
|
||||
}
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = tileIsUniform(img, 0, 0, 64, 64)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkEncodeManyTilesVsFullFrame exercises the bandwidth + CPU
|
||||
// trade-off that motivates the full-frame promotion path: encoding a burst
|
||||
// of N dirty 64×64 tiles as separate Tight rects vs emitting one big Tight
|
||||
// rect for the whole frame.
|
||||
func BenchmarkEncodeManyTilesVsFullFrame(b *testing.B) {
|
||||
pf := defaultClientPixelFormat()
|
||||
const w, h = 1920, 1080
|
||||
img := makeBenchImage(w, h, 1)
|
||||
|
||||
// Build the list of every tile in the frame (worst case: entire screen dirty).
|
||||
var tiles [][4]int
|
||||
for ty := 0; ty < h; ty += tileSize {
|
||||
th := tileSize
|
||||
if ty+th > h {
|
||||
th = h - ty
|
||||
}
|
||||
for tx := 0; tx < w; tx += tileSize {
|
||||
tw := tileSize
|
||||
if tx+tw > w {
|
||||
tw = w - tx
|
||||
}
|
||||
tiles = append(tiles, [4]int{tx, ty, tw, th})
|
||||
}
|
||||
}
|
||||
nTiles := len(tiles)
|
||||
|
||||
b.Run("per_tile_tight", func(b *testing.B) {
|
||||
t := newTightState()
|
||||
b.SetBytes(int64(w * h * 4))
|
||||
b.ReportAllocs()
|
||||
var totalOut int
|
||||
for i := 0; i < b.N; i++ {
|
||||
totalOut = 0
|
||||
for _, r := range tiles {
|
||||
out := encodeTightRect(img, pf, r[0], r[1], r[2], r[3], t)
|
||||
totalOut += len(out)
|
||||
}
|
||||
}
|
||||
b.ReportMetric(float64(totalOut), "wire_bytes")
|
||||
b.ReportMetric(float64(nTiles), "tiles")
|
||||
})
|
||||
|
||||
b.Run("full_frame_tight", func(b *testing.B) {
|
||||
t := newTightState()
|
||||
b.SetBytes(int64(w * h * 4))
|
||||
b.ReportAllocs()
|
||||
var totalOut int
|
||||
for i := 0; i < b.N; i++ {
|
||||
out := encodeTightRect(img, pf, 0, 0, w, h, t)
|
||||
totalOut = len(out)
|
||||
}
|
||||
b.ReportMetric(float64(totalOut), "wire_bytes")
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkShouldPromoteToFullFrame verifies the threshold check itself is
|
||||
// cheap. It runs on every frame, so regressions here hit all workloads.
|
||||
func BenchmarkShouldPromoteToFullFrame(b *testing.B) {
|
||||
const w, h = 1920, 1080
|
||||
s := &session{serverW: w, serverH: h}
|
||||
// Build a worst-case rect list (every tile dirty, 510 entries).
|
||||
var rects [][4]int
|
||||
for ty := 0; ty < h; ty += tileSize {
|
||||
th := tileSize
|
||||
if ty+th > h {
|
||||
th = h - ty
|
||||
}
|
||||
for tx := 0; tx < w; tx += tileSize {
|
||||
tw := tileSize
|
||||
if tx+tw > w {
|
||||
tw = w - tx
|
||||
}
|
||||
rects = append(rects, [4]int{tx, ty, tw, th})
|
||||
}
|
||||
}
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = s.shouldPromoteToFullFrame(rects)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkEncodeCoalescedVsPerTile compares per-tile encoding vs the
|
||||
// coalesced rect list emitted by diffRects, on a horizontal-band dirty
|
||||
// pattern (e.g. a scrolling status bar) where coalescing pays off.
|
||||
func BenchmarkEncodeCoalescedVsPerTile(b *testing.B) {
|
||||
pf := defaultClientPixelFormat()
|
||||
const w, h = 1920, 1080
|
||||
img := makeBenchImage(w, h, 1)
|
||||
|
||||
// Dirty band: rows 200..264 (one tile-row), full width.
|
||||
var perTile [][4]int
|
||||
for tx := 0; tx < w; tx += tileSize {
|
||||
tw := tileSize
|
||||
if tx+tw > w {
|
||||
tw = w - tx
|
||||
}
|
||||
perTile = append(perTile, [4]int{tx, 200, tw, tileSize})
|
||||
}
|
||||
coalesced := coalesceRects(append([][4]int(nil), perTile...))
|
||||
|
||||
b.Run("per_tile", func(b *testing.B) {
|
||||
t := newTightState()
|
||||
b.ReportAllocs()
|
||||
var bytesOut int
|
||||
for i := 0; i < b.N; i++ {
|
||||
bytesOut = 0
|
||||
for _, r := range perTile {
|
||||
out := encodeTightRect(img, pf, r[0], r[1], r[2], r[3], t)
|
||||
bytesOut += len(out)
|
||||
}
|
||||
}
|
||||
b.ReportMetric(float64(bytesOut), "wire_bytes")
|
||||
b.ReportMetric(float64(len(perTile)), "rects")
|
||||
})
|
||||
|
||||
b.Run("coalesced", func(b *testing.B) {
|
||||
t := newTightState()
|
||||
b.ReportAllocs()
|
||||
var bytesOut int
|
||||
for i := 0; i < b.N; i++ {
|
||||
bytesOut = 0
|
||||
for _, r := range coalesced {
|
||||
out := encodeTightRect(img, pf, r[0], r[1], r[2], r[3], t)
|
||||
bytesOut += len(out)
|
||||
}
|
||||
}
|
||||
b.ReportMetric(float64(bytesOut), "wire_bytes")
|
||||
b.ReportMetric(float64(len(coalesced)), "rects")
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkCoalesceRects(b *testing.B) {
|
||||
const w, h = 1920, 1080
|
||||
// Worst case: every tile dirty.
|
||||
var allTiles [][4]int
|
||||
for ty := 0; ty < h; ty += tileSize {
|
||||
th := tileSize
|
||||
if ty+th > h {
|
||||
th = h - ty
|
||||
}
|
||||
for tx := 0; tx < w; tx += tileSize {
|
||||
tw := tileSize
|
||||
if tx+tw > w {
|
||||
tw = w - tx
|
||||
}
|
||||
allTiles = append(allTiles, [4]int{tx, ty, tw, th})
|
||||
}
|
||||
}
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
in := make([][4]int, len(allTiles))
|
||||
copy(in, allTiles)
|
||||
_ = coalesceRects(in)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkEncodeTight_Photo measures Tight on random/photographic content.
|
||||
// The internal sampledColorCount gate routes large many-colour rects to JPEG
|
||||
// at quality 70.
|
||||
func BenchmarkEncodeTight_Photo(b *testing.B) {
|
||||
pf := defaultClientPixelFormat()
|
||||
for _, r := range []struct {
|
||||
name string
|
||||
w, h int
|
||||
}{
|
||||
{"256x256", 256, 256},
|
||||
{"512x512", 512, 512},
|
||||
{"1080p", 1920, 1080},
|
||||
} {
|
||||
img := makeBenchImage(r.w, r.h, 1)
|
||||
b.Run(r.name+"/tight", func(b *testing.B) {
|
||||
t := newTightState()
|
||||
b.SetBytes(int64(r.w * r.h * 4))
|
||||
b.ReportAllocs()
|
||||
var bytesOut int
|
||||
for i := 0; i < b.N; i++ {
|
||||
out := encodeTightRect(img, pf, 0, 0, r.w, r.h, t)
|
||||
bytesOut = len(out)
|
||||
}
|
||||
b.ReportMetric(float64(bytesOut), "wire_bytes")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDiffRects(b *testing.B) {
|
||||
for _, r := range benchRects {
|
||||
prev, cur := makeBenchImagePartial(r.w, r.h, 100)
|
||||
b.Run(r.name, func(b *testing.B) {
|
||||
b.SetBytes(int64(r.w * r.h * 4))
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = diffRects(prev, cur, r.w, r.h, tileSize)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
274
client/vnc/server/scancodes.go
Normal file
274
client/vnc/server/scancodes.go
Normal file
@@ -0,0 +1,274 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
// QEMU Extended Key Event carries hardware scancodes encoded as PC AT Set 1.
|
||||
// Single-byte codes cover the standard keys; the "extended" prefix 0xE0 is
|
||||
// merged into the high byte (so 0xE048 is the extended-Up arrow). This file
|
||||
// translates those scancodes into the per-platform identifiers each input
|
||||
// backend wants:
|
||||
//
|
||||
// - Linux uinput wants Linux KEY_* codes (defined in
|
||||
// linux/input-event-codes.h). uinput is what we use for virtual Xvfb
|
||||
// sessions on Linux.
|
||||
// - X11 XTest wants XKB keycodes, which on a standard layout equal
|
||||
// Linux KEY_* + 8 (the per-server offset between the Linux event code
|
||||
// and the X server's keycode space).
|
||||
// - Windows SendInput accepts the PC AT scancode directly via
|
||||
// KEYEVENTF_SCANCODE, so no mapping table is needed there; the
|
||||
// extended-key bit is set when the QEMU scancode high byte is 0xE0.
|
||||
// - macOS CGEventCreateKeyboardEvent takes a "virtual keycode" from
|
||||
// Apple's HID set, which is unrelated to PC AT and needs its own
|
||||
// table (see qemuToMacVK in input_darwin.go).
|
||||
//
|
||||
// Linux KEY_* codes. Only the ones we reference, since the full
|
||||
// linux/input-event-codes.h list isn't useful here. Naming mirrors the
|
||||
// existing constants in input_uinput_linux.go (mixed case, no underscores).
|
||||
const (
|
||||
keyEsc = 1
|
||||
key1 = 2
|
||||
key2 = 3
|
||||
key3 = 4
|
||||
key4 = 5
|
||||
key5 = 6
|
||||
key6 = 7
|
||||
key7 = 8
|
||||
key8 = 9
|
||||
key9 = 10
|
||||
key0 = 11
|
||||
keyMinus = 12
|
||||
keyEqual = 13
|
||||
keyBackspace = 14
|
||||
keyTab = 15
|
||||
keyQ = 16
|
||||
keyW = 17
|
||||
keyE = 18
|
||||
keyR = 19
|
||||
keyT = 20
|
||||
keyY = 21
|
||||
keyU = 22
|
||||
keyI = 23
|
||||
keyO = 24
|
||||
keyP = 25
|
||||
keyLeftBracket = 26
|
||||
keyRightBracket = 27
|
||||
keyEnter = 28
|
||||
keyLeftCtrl = 29
|
||||
keyA = 30
|
||||
keyS = 31
|
||||
keyD = 32
|
||||
keyF = 33
|
||||
keyG = 34
|
||||
keyH = 35
|
||||
keyJ = 36
|
||||
keyK = 37
|
||||
keyL = 38
|
||||
keySemicolon = 39
|
||||
keyApostrophe = 40
|
||||
keyGrave = 41
|
||||
keyLeftShift = 42
|
||||
keyBackslash = 43
|
||||
keyZ = 44
|
||||
keyX = 45
|
||||
keyC = 46
|
||||
keyV = 47
|
||||
keyB = 48
|
||||
keyN = 49
|
||||
keyM = 50
|
||||
keyComma = 51
|
||||
keyDot = 52
|
||||
keySlash = 53
|
||||
keyRightShift = 54
|
||||
keyKPAsterisk = 55
|
||||
keyLeftAlt = 56
|
||||
keySpace = 57
|
||||
keyCapsLock = 58
|
||||
keyF1 = 59
|
||||
keyF2 = 60
|
||||
keyF3 = 61
|
||||
keyF4 = 62
|
||||
keyF5 = 63
|
||||
keyF6 = 64
|
||||
keyF7 = 65
|
||||
keyF8 = 66
|
||||
keyF9 = 67
|
||||
keyF10 = 68
|
||||
keyNumLock = 69
|
||||
keyScrollLock = 70
|
||||
keyKP7 = 71
|
||||
keyKP8 = 72
|
||||
keyKP9 = 73
|
||||
keyKPMinus = 74
|
||||
keyKP4 = 75
|
||||
keyKP5 = 76
|
||||
keyKP6 = 77
|
||||
keyKPPlus = 78
|
||||
keyKP1 = 79
|
||||
keyKP2 = 80
|
||||
keyKP3 = 81
|
||||
keyKP0 = 82
|
||||
keyKPDot = 83
|
||||
key102nd = 86
|
||||
keyF11 = 87
|
||||
keyF12 = 88
|
||||
keyKPEnter = 96
|
||||
keyRightCtrl = 97
|
||||
keyKPSlash = 98
|
||||
keySysRq = 99
|
||||
keyRightAlt = 100
|
||||
keyHome = 102
|
||||
keyUp = 103
|
||||
keyPageUp = 104
|
||||
keyLeft = 105
|
||||
keyRight = 106
|
||||
keyEnd = 107
|
||||
keyDown = 108
|
||||
keyPageDown = 109
|
||||
keyInsert = 110
|
||||
keyDelete = 111
|
||||
keyMute = 113
|
||||
keyVolumeDown = 114
|
||||
keyVolumeUp = 115
|
||||
keyLeftMeta = 125
|
||||
keyRightMeta = 126
|
||||
keyCompose = 127
|
||||
)
|
||||
|
||||
// qemuToLinuxKey maps the PC AT Set 1 scancode QEMU sends to a Linux KEY_*
|
||||
// code. The high byte 0xE0 marks "extended" scancodes (arrows, the right-
|
||||
// side modifier keys, keypad enter/divide, browser keys, etc.).
|
||||
//
|
||||
// Keep this table dense so a reviewer sees the whole keyboard at a glance,
|
||||
// and so adding a new key is a single line.
|
||||
var qemuToLinuxKey = map[uint32]int{
|
||||
// Single-byte (non-extended) scancodes.
|
||||
0x01: keyEsc,
|
||||
0x02: key1,
|
||||
0x03: key2,
|
||||
0x04: key3,
|
||||
0x05: key4,
|
||||
0x06: key5,
|
||||
0x07: key6,
|
||||
0x08: key7,
|
||||
0x09: key8,
|
||||
0x0A: key9,
|
||||
0x0B: key0,
|
||||
0x0C: keyMinus,
|
||||
0x0D: keyEqual,
|
||||
0x0E: keyBackspace,
|
||||
0x0F: keyTab,
|
||||
0x10: keyQ,
|
||||
0x11: keyW,
|
||||
0x12: keyE,
|
||||
0x13: keyR,
|
||||
0x14: keyT,
|
||||
0x15: keyY,
|
||||
0x16: keyU,
|
||||
0x17: keyI,
|
||||
0x18: keyO,
|
||||
0x19: keyP,
|
||||
0x1A: keyLeftBracket,
|
||||
0x1B: keyRightBracket,
|
||||
0x1C: keyEnter,
|
||||
0x1D: keyLeftCtrl,
|
||||
0x1E: keyA,
|
||||
0x1F: keyS,
|
||||
0x20: keyD,
|
||||
0x21: keyF,
|
||||
0x22: keyG,
|
||||
0x23: keyH,
|
||||
0x24: keyJ,
|
||||
0x25: keyK,
|
||||
0x26: keyL,
|
||||
0x27: keySemicolon,
|
||||
0x28: keyApostrophe,
|
||||
0x29: keyGrave,
|
||||
0x2A: keyLeftShift,
|
||||
0x2B: keyBackslash,
|
||||
0x2C: keyZ,
|
||||
0x2D: keyX,
|
||||
0x2E: keyC,
|
||||
0x2F: keyV,
|
||||
0x30: keyB,
|
||||
0x31: keyN,
|
||||
0x32: keyM,
|
||||
0x33: keyComma,
|
||||
0x34: keyDot,
|
||||
0x35: keySlash,
|
||||
0x36: keyRightShift,
|
||||
0x37: keyKPAsterisk,
|
||||
0x38: keyLeftAlt,
|
||||
0x39: keySpace,
|
||||
0x3A: keyCapsLock,
|
||||
0x3B: keyF1,
|
||||
0x3C: keyF2,
|
||||
0x3D: keyF3,
|
||||
0x3E: keyF4,
|
||||
0x3F: keyF5,
|
||||
0x40: keyF6,
|
||||
0x41: keyF7,
|
||||
0x42: keyF8,
|
||||
0x43: keyF9,
|
||||
0x44: keyF10,
|
||||
0x45: keyNumLock,
|
||||
0x46: keyScrollLock,
|
||||
0x47: keyKP7,
|
||||
0x48: keyKP8,
|
||||
0x49: keyKP9,
|
||||
0x4A: keyKPMinus,
|
||||
0x4B: keyKP4,
|
||||
0x4C: keyKP5,
|
||||
0x4D: keyKP6,
|
||||
0x4E: keyKPPlus,
|
||||
0x4F: keyKP1,
|
||||
0x50: keyKP2,
|
||||
0x51: keyKP3,
|
||||
0x52: keyKP0,
|
||||
0x53: keyKPDot,
|
||||
0x56: key102nd,
|
||||
0x57: keyF11,
|
||||
0x58: keyF12,
|
||||
|
||||
// Extended (0xE0-prefixed) scancodes.
|
||||
0xE01C: keyKPEnter,
|
||||
0xE01D: keyRightCtrl,
|
||||
0xE020: keyMute,
|
||||
0xE02E: keyVolumeDown,
|
||||
0xE030: keyVolumeUp,
|
||||
0xE035: keyKPSlash,
|
||||
0xE037: keySysRq, // PrintScreen
|
||||
0xE038: keyRightAlt,
|
||||
0xE047: keyHome,
|
||||
0xE048: keyUp,
|
||||
0xE049: keyPageUp,
|
||||
0xE04B: keyLeft,
|
||||
0xE04D: keyRight,
|
||||
0xE04F: keyEnd,
|
||||
0xE050: keyDown,
|
||||
0xE051: keyPageDown,
|
||||
0xE052: keyInsert,
|
||||
0xE053: keyDelete,
|
||||
0xE05B: keyLeftMeta,
|
||||
0xE05C: keyRightMeta,
|
||||
0xE05D: keyCompose,
|
||||
}
|
||||
|
||||
// qemuScancodeToLinuxKey is the lookup the uinput and X11 paths use.
|
||||
// Returns 0 (which Linux treats as KEY_RESERVED) when the scancode has no
|
||||
// mapping, signalling "fall back to the keysym path".
|
||||
func qemuScancodeToLinuxKey(scancode uint32) int {
|
||||
return qemuToLinuxKey[scancode]
|
||||
}
|
||||
|
||||
// qemuScancodeIsExtended reports whether a QEMU scancode is in the
|
||||
// 0xE0-prefixed extended range. Used by Windows SendInput to set the
|
||||
// KEYEVENTF_EXTENDEDKEY flag.
|
||||
func qemuScancodeIsExtended(scancode uint32) bool {
|
||||
return scancode&0xFF00 == 0xE000
|
||||
}
|
||||
|
||||
// qemuScancodeLowByte returns the byte SendInput's wScan field actually
|
||||
// stores: the low byte of the scancode regardless of any extended prefix.
|
||||
func qemuScancodeLowByte(scancode uint32) uint16 {
|
||||
return uint16(scancode & 0xFF)
|
||||
}
|
||||
238
client/vnc/server/scancodes_darwin.go
Normal file
238
client/vnc/server/scancodes_darwin.go
Normal file
@@ -0,0 +1,238 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package server
|
||||
|
||||
// Apple keyboard virtual-key codes used with CGEventCreateKeyboardEvent.
|
||||
// These are the kVK_ANSI_* / kVK_* values from Apple's
|
||||
// HIToolbox/Events.h; reproduced here so we don't need to drag in the
|
||||
// HIToolbox framework just for the constants.
|
||||
const (
|
||||
macKeyA uint16 = 0x00
|
||||
macKeyS uint16 = 0x01
|
||||
macKeyD uint16 = 0x02
|
||||
macKeyF uint16 = 0x03
|
||||
macKeyH uint16 = 0x04
|
||||
macKeyG uint16 = 0x05
|
||||
macKeyZ uint16 = 0x06
|
||||
macKeyX uint16 = 0x07
|
||||
macKeyC uint16 = 0x08
|
||||
macKeyV uint16 = 0x09
|
||||
macKeyNonUSBackslash uint16 = 0x0A // ISO_Section / 102nd
|
||||
macKeyB uint16 = 0x0B
|
||||
macKeyQ uint16 = 0x0C
|
||||
macKeyW uint16 = 0x0D
|
||||
macKeyE uint16 = 0x0E
|
||||
macKeyR uint16 = 0x0F
|
||||
macKeyY uint16 = 0x10
|
||||
macKeyT uint16 = 0x11
|
||||
macKey1 uint16 = 0x12
|
||||
macKey2 uint16 = 0x13
|
||||
macKey3 uint16 = 0x14
|
||||
macKey4 uint16 = 0x15
|
||||
macKey6 uint16 = 0x16
|
||||
macKey5 uint16 = 0x17
|
||||
macKeyEqual uint16 = 0x18
|
||||
macKey9 uint16 = 0x19
|
||||
macKey7 uint16 = 0x1A
|
||||
macKeyMinus uint16 = 0x1B
|
||||
macKey8 uint16 = 0x1C
|
||||
macKey0 uint16 = 0x1D
|
||||
macKeyRightBracket uint16 = 0x1E
|
||||
macKeyO uint16 = 0x1F
|
||||
macKeyU uint16 = 0x20
|
||||
macKeyLeftBracket uint16 = 0x21
|
||||
macKeyI uint16 = 0x22
|
||||
macKeyP uint16 = 0x23
|
||||
macKeyReturn uint16 = 0x24
|
||||
macKeyL uint16 = 0x25
|
||||
macKeyJ uint16 = 0x26
|
||||
macKeyApostrophe uint16 = 0x27
|
||||
macKeyK uint16 = 0x28
|
||||
macKeySemicolon uint16 = 0x29
|
||||
macKeyBackslash uint16 = 0x2A
|
||||
macKeyComma uint16 = 0x2B
|
||||
macKeySlash uint16 = 0x2C
|
||||
macKeyN uint16 = 0x2D
|
||||
macKeyM uint16 = 0x2E
|
||||
macKeyPeriod uint16 = 0x2F
|
||||
macKeyTab uint16 = 0x30
|
||||
macKeySpace uint16 = 0x31
|
||||
macKeyGrave uint16 = 0x32
|
||||
macKeyDelete uint16 = 0x33 // Backspace
|
||||
macKeyEscape uint16 = 0x35
|
||||
macKeyCommand uint16 = 0x37
|
||||
macKeyShift uint16 = 0x38
|
||||
macKeyCapsLock uint16 = 0x39
|
||||
macKeyOption uint16 = 0x3A // Alt
|
||||
macKeyControl uint16 = 0x3B
|
||||
macKeyRightShift uint16 = 0x3C
|
||||
macKeyRightOption uint16 = 0x3D
|
||||
macKeyRightControl uint16 = 0x3E
|
||||
macKeyFunction uint16 = 0x3F
|
||||
macKeyF17 uint16 = 0x40
|
||||
macKeyKPDecimal uint16 = 0x41
|
||||
macKeyKPMultiply uint16 = 0x43
|
||||
macKeyKPPlus uint16 = 0x45
|
||||
macKeyKPClear uint16 = 0x47 // numlock
|
||||
macKeyVolumeUp uint16 = 0x48
|
||||
macKeyVolumeDown uint16 = 0x49
|
||||
macKeyMute uint16 = 0x4A
|
||||
macKeyKPDivide uint16 = 0x4B
|
||||
macKeyKPEnter uint16 = 0x4C
|
||||
macKeyKPMinus uint16 = 0x4E
|
||||
macKeyF18 uint16 = 0x4F
|
||||
macKeyF19 uint16 = 0x50
|
||||
macKeyKPEqual uint16 = 0x51
|
||||
macKeyKP0 uint16 = 0x52
|
||||
macKeyKP1 uint16 = 0x53
|
||||
macKeyKP2 uint16 = 0x54
|
||||
macKeyKP3 uint16 = 0x55
|
||||
macKeyKP4 uint16 = 0x56
|
||||
macKeyKP5 uint16 = 0x57
|
||||
macKeyKP6 uint16 = 0x58
|
||||
macKeyKP7 uint16 = 0x59
|
||||
macKeyF20 uint16 = 0x5A
|
||||
macKeyKP8 uint16 = 0x5B
|
||||
macKeyKP9 uint16 = 0x5C
|
||||
macKeyF5 uint16 = 0x60
|
||||
macKeyF6 uint16 = 0x61
|
||||
macKeyF7 uint16 = 0x62
|
||||
macKeyF3 uint16 = 0x63
|
||||
macKeyF8 uint16 = 0x64
|
||||
macKeyF9 uint16 = 0x65
|
||||
macKeyF11 uint16 = 0x67
|
||||
macKeyF13 uint16 = 0x69 // PrintScreen on most layouts
|
||||
macKeyF16 uint16 = 0x6A
|
||||
macKeyF14 uint16 = 0x6B
|
||||
macKeyF10 uint16 = 0x6D
|
||||
macKeyF12 uint16 = 0x6F
|
||||
macKeyF15 uint16 = 0x71
|
||||
macKeyHelp uint16 = 0x72 // Insert on PC keyboards
|
||||
macKeyHome uint16 = 0x73
|
||||
macKeyPageUp uint16 = 0x74
|
||||
macKeyForwardDelete uint16 = 0x75
|
||||
macKeyF4 uint16 = 0x76
|
||||
macKeyEnd uint16 = 0x77
|
||||
macKeyF2 uint16 = 0x78
|
||||
macKeyPageDown uint16 = 0x79
|
||||
macKeyF1 uint16 = 0x7A
|
||||
macKeyLeft uint16 = 0x7B
|
||||
macKeyRight uint16 = 0x7C
|
||||
macKeyDown uint16 = 0x7D
|
||||
macKeyUp uint16 = 0x7E
|
||||
)
|
||||
|
||||
// qemuToMacVK maps PC AT Set 1 scancodes (as QEMU emits them, with the
|
||||
// 0xE0 prefix merged into the high byte) onto Apple virtual-key codes.
|
||||
// Layout-independent: the scancode names the physical key, the user's
|
||||
// active keyboard layout on the Mac decides what the key produces.
|
||||
var qemuToMacVK = map[uint32]uint16{
|
||||
// Single-byte (non-extended).
|
||||
0x01: macKeyEscape,
|
||||
0x02: macKey1,
|
||||
0x03: macKey2,
|
||||
0x04: macKey3,
|
||||
0x05: macKey4,
|
||||
0x06: macKey5,
|
||||
0x07: macKey6,
|
||||
0x08: macKey7,
|
||||
0x09: macKey8,
|
||||
0x0A: macKey9,
|
||||
0x0B: macKey0,
|
||||
0x0C: macKeyMinus,
|
||||
0x0D: macKeyEqual,
|
||||
0x0E: macKeyDelete, // PC Backspace -> mac "Delete"
|
||||
0x0F: macKeyTab,
|
||||
0x10: macKeyQ,
|
||||
0x11: macKeyW,
|
||||
0x12: macKeyE,
|
||||
0x13: macKeyR,
|
||||
0x14: macKeyT,
|
||||
0x15: macKeyY,
|
||||
0x16: macKeyU,
|
||||
0x17: macKeyI,
|
||||
0x18: macKeyO,
|
||||
0x19: macKeyP,
|
||||
0x1A: macKeyLeftBracket,
|
||||
0x1B: macKeyRightBracket,
|
||||
0x1C: macKeyReturn,
|
||||
0x1D: macKeyControl,
|
||||
0x1E: macKeyA,
|
||||
0x1F: macKeyS,
|
||||
0x20: macKeyD,
|
||||
0x21: macKeyF,
|
||||
0x22: macKeyG,
|
||||
0x23: macKeyH,
|
||||
0x24: macKeyJ,
|
||||
0x25: macKeyK,
|
||||
0x26: macKeyL,
|
||||
0x27: macKeySemicolon,
|
||||
0x28: macKeyApostrophe,
|
||||
0x29: macKeyGrave,
|
||||
0x2A: macKeyShift,
|
||||
0x2B: macKeyBackslash,
|
||||
0x2C: macKeyZ,
|
||||
0x2D: macKeyX,
|
||||
0x2E: macKeyC,
|
||||
0x2F: macKeyV,
|
||||
0x30: macKeyB,
|
||||
0x31: macKeyN,
|
||||
0x32: macKeyM,
|
||||
0x33: macKeyComma,
|
||||
0x34: macKeyPeriod,
|
||||
0x35: macKeySlash,
|
||||
0x36: macKeyRightShift,
|
||||
0x37: macKeyKPMultiply,
|
||||
0x38: macKeyOption, // Left Alt -> Option
|
||||
0x39: macKeySpace,
|
||||
0x3A: macKeyCapsLock,
|
||||
0x3B: macKeyF1,
|
||||
0x3C: macKeyF2,
|
||||
0x3D: macKeyF3,
|
||||
0x3E: macKeyF4,
|
||||
0x3F: macKeyF5,
|
||||
0x40: macKeyF6,
|
||||
0x41: macKeyF7,
|
||||
0x42: macKeyF8,
|
||||
0x43: macKeyF9,
|
||||
0x44: macKeyF10,
|
||||
0x45: macKeyKPClear, // PC NumLock -> mac Clear
|
||||
0x47: macKeyKP7,
|
||||
0x48: macKeyKP8,
|
||||
0x49: macKeyKP9,
|
||||
0x4A: macKeyKPMinus,
|
||||
0x4B: macKeyKP4,
|
||||
0x4C: macKeyKP5,
|
||||
0x4D: macKeyKP6,
|
||||
0x4E: macKeyKPPlus,
|
||||
0x4F: macKeyKP1,
|
||||
0x50: macKeyKP2,
|
||||
0x51: macKeyKP3,
|
||||
0x52: macKeyKP0,
|
||||
0x53: macKeyKPDecimal,
|
||||
0x56: macKeyNonUSBackslash,
|
||||
0x57: macKeyF11,
|
||||
0x58: macKeyF12,
|
||||
|
||||
// Extended (0xE0 prefix).
|
||||
0xE01C: macKeyKPEnter,
|
||||
0xE01D: macKeyRightControl,
|
||||
0xE020: macKeyMute,
|
||||
0xE02E: macKeyVolumeDown,
|
||||
0xE030: macKeyVolumeUp,
|
||||
0xE035: macKeyKPDivide,
|
||||
0xE037: macKeyF13, // PrintScreen
|
||||
0xE038: macKeyRightOption,
|
||||
0xE047: macKeyHome,
|
||||
0xE048: macKeyUp,
|
||||
0xE049: macKeyPageUp,
|
||||
0xE04B: macKeyLeft,
|
||||
0xE04D: macKeyRight,
|
||||
0xE04F: macKeyEnd,
|
||||
0xE050: macKeyDown,
|
||||
0xE051: macKeyPageDown,
|
||||
0xE052: macKeyHelp, // PC Insert -> mac Help
|
||||
0xE053: macKeyForwardDelete,
|
||||
0xE05B: macKeyCommand, // Left Windows -> Command
|
||||
0xE05C: macKeyCommand, // Right Windows -> Command (no separate code)
|
||||
}
|
||||
100
client/vnc/server/scancodes_test.go
Normal file
100
client/vnc/server/scancodes_test.go
Normal file
@@ -0,0 +1,100 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestQemuScancodeToLinuxKey_KnownLetters(t *testing.T) {
|
||||
// Spot-check a few familiar letter keys against the Linux KEY_*
|
||||
// values they're supposed to land on.
|
||||
tests := []struct {
|
||||
name string
|
||||
scancode uint32
|
||||
want int
|
||||
}{
|
||||
{"A", 0x1E, keyA},
|
||||
{"S", 0x1F, keyS},
|
||||
{"D", 0x20, keyD},
|
||||
{"Q", 0x10, keyQ},
|
||||
{"Z", 0x2C, keyZ},
|
||||
{"1", 0x02, key1},
|
||||
{"Esc", 0x01, keyEsc},
|
||||
{"Tab", 0x0F, keyTab},
|
||||
{"Space", 0x39, keySpace},
|
||||
{"LeftShift", 0x2A, keyLeftShift},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
got := qemuScancodeToLinuxKey(tc.scancode)
|
||||
if got != tc.want {
|
||||
t.Errorf("%s: scancode 0x%X => %d, want %d", tc.name, tc.scancode, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestQemuScancodeToLinuxKey_Extended(t *testing.T) {
|
||||
// Extended (0xE0-prefixed) scancodes for arrow + navigation cluster.
|
||||
tests := []struct {
|
||||
name string
|
||||
scancode uint32
|
||||
want int
|
||||
}{
|
||||
{"Up", 0xE048, keyUp},
|
||||
{"Down", 0xE050, keyDown},
|
||||
{"Left", 0xE04B, keyLeft},
|
||||
{"Right", 0xE04D, keyRight},
|
||||
{"Home", 0xE047, keyHome},
|
||||
{"End", 0xE04F, keyEnd},
|
||||
{"PageUp", 0xE049, keyPageUp},
|
||||
{"PageDown", 0xE051, keyPageDown},
|
||||
{"Insert", 0xE052, keyInsert},
|
||||
{"Delete", 0xE053, keyDelete},
|
||||
{"RightCtrl", 0xE01D, keyRightCtrl},
|
||||
{"RightAlt", 0xE038, keyRightAlt},
|
||||
{"KPEnter", 0xE01C, keyKPEnter},
|
||||
{"KPSlash", 0xE035, keyKPSlash},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
got := qemuScancodeToLinuxKey(tc.scancode)
|
||||
if got != tc.want {
|
||||
t.Errorf("%s: scancode 0x%X => %d, want %d", tc.name, tc.scancode, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestQemuScancodeToLinuxKey_Miss(t *testing.T) {
|
||||
// 0xE0FF is in the extended range but not a real key. Must return 0
|
||||
// so the caller can fall back to the keysym path.
|
||||
if got := qemuScancodeToLinuxKey(0xE0FF); got != 0 {
|
||||
t.Errorf("unknown scancode should miss: got %d, want 0", got)
|
||||
}
|
||||
if got := qemuScancodeToLinuxKey(0xFF); got != 0 {
|
||||
t.Errorf("unknown non-extended scancode should miss: got %d, want 0", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQemuScancodeIsExtended(t *testing.T) {
|
||||
cases := []struct {
|
||||
scancode uint32
|
||||
want bool
|
||||
}{
|
||||
{0x1E, false},
|
||||
{0xE048, true},
|
||||
{0xE000, true},
|
||||
{0xFF, false},
|
||||
{0xE0FF, true},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
if got := qemuScancodeIsExtended(tc.scancode); got != tc.want {
|
||||
t.Errorf("isExtended(0x%X) = %v, want %v", tc.scancode, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestQemuScancodeLowByte(t *testing.T) {
|
||||
if got := qemuScancodeLowByte(0xE048); got != 0x48 {
|
||||
t.Errorf("lowByte(0xE048) = 0x%X, want 0x48", got)
|
||||
}
|
||||
if got := qemuScancodeLowByte(0x1E); got != 0x1E {
|
||||
t.Errorf("lowByte(0x1E) = 0x%X, want 0x1E", got)
|
||||
}
|
||||
}
|
||||
1054
client/vnc/server/server.go
Normal file
1054
client/vnc/server/server.go
Normal file
File diff suppressed because it is too large
Load Diff
119
client/vnc/server/server_darwin.go
Normal file
119
client/vnc/server/server_darwin.go
Normal file
@@ -0,0 +1,119 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func (s *Server) platformInit() {
|
||||
// no-op on macOS
|
||||
}
|
||||
|
||||
func (s *Server) platformShutdown() {
|
||||
// no-op on macOS
|
||||
}
|
||||
|
||||
func (s *Server) platformSessionManager() virtualSessionManager {
|
||||
return nil
|
||||
}
|
||||
|
||||
// serviceAcceptLoop runs in a LaunchDaemon and proxies each VNC
|
||||
// connection to a per-user agent. The agent is spawned lazily on the
|
||||
// first connection (and respawned after a console-user change) via
|
||||
// launchctl asuser, which is the only mechanism that lands a child
|
||||
// inside the user's Aqua session, where WindowServer and TCC grants
|
||||
// for screen capture work.
|
||||
func (s *Server) serviceAcceptLoop() {
|
||||
mgr := newDarwinAgentManager(s.ctx)
|
||||
defer mgr.stop()
|
||||
|
||||
log.Infof("service mode, proxying connections to per-user agent on 127.0.0.1:%d", agentPort)
|
||||
|
||||
for {
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
s.log.Debugf("accept VNC connection: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if !s.tryAcquireConnSlot() {
|
||||
s.log.Warnf("rejecting VNC connection from %s: %d concurrent connections in flight", conn.RemoteAddr(), maxConcurrentVNCConns)
|
||||
_ = conn.Close()
|
||||
continue
|
||||
}
|
||||
enableTCPKeepAlive(conn, s.log)
|
||||
conn = newMetricsConn(conn, s.sessionRecorder)
|
||||
s.trackConn(conn)
|
||||
go func(c net.Conn) {
|
||||
defer s.releaseConnSlot()
|
||||
defer s.untrackConn(c)
|
||||
s.handleServiceConnectionDarwin(c, mgr)
|
||||
}(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleServiceConnectionDarwin(conn net.Conn, mgr *darwinAgentManager) {
|
||||
connLog := s.log.WithField("remote", conn.RemoteAddr().String())
|
||||
|
||||
if !s.isAllowedSource(conn.RemoteAddr()) {
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
var headerBuf bytes.Buffer
|
||||
tee := io.TeeReader(conn, &headerBuf)
|
||||
teeConn := &darwinPrefixConn{Reader: tee, Conn: conn}
|
||||
|
||||
header, err := s.readConnectionHeader(teeConn)
|
||||
if err != nil {
|
||||
connLog.Debugf("read connection header: %v", err)
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
if !s.disableAuth {
|
||||
if _, err := s.authenticateSession(header); err != nil {
|
||||
rejectConnection(conn, codeMessage(RejectCodeAuthForbidden, err.Error()))
|
||||
connLog.Warnf("auth rejected: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
s.registerConnAuth(conn, header)
|
||||
|
||||
token, err := mgr.ensure(s.ctx)
|
||||
if err != nil {
|
||||
code := RejectCodeCapturerError
|
||||
if errors.Is(err, errNoConsoleUser) {
|
||||
code = RejectCodeNoConsoleUser
|
||||
}
|
||||
rejectConnection(conn, codeMessage(code, err.Error()))
|
||||
connLog.Warnf("spawn per-user agent: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
replayConn := &darwinPrefixConn{
|
||||
Reader: io.MultiReader(&headerBuf, conn),
|
||||
Conn: conn,
|
||||
}
|
||||
proxyToAgent(s.ctx, replayConn, agentPort, token)
|
||||
}
|
||||
|
||||
// darwinPrefixConn replays the already-consumed connection-header bytes
|
||||
// in front of the proxy stream, mirroring the Windows prefixConn shape.
|
||||
type darwinPrefixConn struct {
|
||||
io.Reader
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (p *darwinPrefixConn) Read(b []byte) (int, error) { return p.Reader.Read(b) }
|
||||
318
client/vnc/server/server_test.go
Normal file
318
client/vnc/server/server_test.go
Normal file
@@ -0,0 +1,318 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"image"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// testCapturer returns a 100x100 image for test sessions.
|
||||
type testCapturer struct{}
|
||||
|
||||
func (t *testCapturer) Width() int { return 100 }
|
||||
func (t *testCapturer) Height() int { return 100 }
|
||||
func (t *testCapturer) Capture() (*image.RGBA, error) {
|
||||
return image.NewRGBA(image.Rect(0, 0, 100, 100)), nil
|
||||
}
|
||||
|
||||
func startTestServer(t *testing.T, disableAuth bool) (net.Addr, *Server) {
|
||||
t.Helper()
|
||||
|
||||
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
|
||||
srv.SetDisableAuth(disableAuth)
|
||||
|
||||
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
||||
network := netip.MustParsePrefix("127.0.0.0/8")
|
||||
require.NoError(t, srv.Start(t.Context(), addr, network))
|
||||
// Override local address so source validation doesn't reject 127.0.0.1 as "own IP".
|
||||
srv.localAddr = netip.MustParseAddr("10.99.99.1")
|
||||
t.Cleanup(func() { _ = srv.Stop() })
|
||||
|
||||
return srv.listener.Addr(), srv
|
||||
}
|
||||
|
||||
func TestAuthEnabled_NoSessionAuth_RejectsConnection(t *testing.T) {
|
||||
addr, _ := startTestServer(t, false)
|
||||
|
||||
conn, err := net.Dial("tcp", addr.String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
// Header with no Noise handshake. Auth-required servers must reject
|
||||
// because no client static was authenticated.
|
||||
header := make([]byte, 11) // mode + usernameLen + sessionID + w + h
|
||||
header[0] = ModeAttach
|
||||
_, err = conn.Write(header)
|
||||
require.NoError(t, err)
|
||||
|
||||
var version [12]byte
|
||||
_, err = io.ReadFull(conn, version[:])
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "RFB 003.008\n", string(version[:]))
|
||||
|
||||
_, err = conn.Write(version[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
var numTypes [1]byte
|
||||
_, err = io.ReadFull(conn, numTypes[:])
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, byte(0), numTypes[0], "should have 0 security types (failure)")
|
||||
|
||||
var reasonLen [4]byte
|
||||
_, err = io.ReadFull(conn, reasonLen[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
reason := make([]byte, binary.BigEndian.Uint32(reasonLen[:]))
|
||||
_, err = io.ReadFull(conn, reason)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, string(reason), "identity proof missing", "rejection reason should mention missing identity proof")
|
||||
}
|
||||
|
||||
func TestAuthDisabled_AllowsConnection(t *testing.T) {
|
||||
addr, _ := startTestServer(t, true)
|
||||
|
||||
conn, err := net.Dial("tcp", addr.String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
header := make([]byte, 11) // mode + usernameLen + sessionID + w + h
|
||||
header[0] = ModeAttach
|
||||
_, err = conn.Write(header)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Server should send RFB version.
|
||||
var version [12]byte
|
||||
_, err = io.ReadFull(conn, version[:])
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "RFB 003.008\n", string(version[:]))
|
||||
|
||||
// Write client version.
|
||||
_, err = conn.Write(version[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should get security types (not 0 = failure).
|
||||
var numTypes [1]byte
|
||||
_, err = io.ReadFull(conn, numTypes[:])
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, byte(0), numTypes[0], "should have at least one security type (auth disabled)")
|
||||
}
|
||||
|
||||
// TestAuth_NoUnauthBytesPastHeader proves the server does not send any RFB
|
||||
// content to a connection that fails source validation. Specifically, the
|
||||
// server must close immediately and the client must see EOF before any RFB
|
||||
// version greeting is written.
|
||||
func TestAuth_NoUnauthBytesPastHeader(t *testing.T) {
|
||||
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
|
||||
srv.SetDisableAuth(true)
|
||||
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
||||
// Tight overlay that excludes 127.0.0.0/8 and a non-loopback local IP, so
|
||||
// the loopback short-circuit in isAllowedSource doesn't apply.
|
||||
require.NoError(t, srv.Start(t.Context(), addr, netip.MustParsePrefix("10.99.0.0/16")))
|
||||
srv.localAddr = netip.MustParseAddr("10.99.99.1")
|
||||
t.Cleanup(func() { _ = srv.Stop() })
|
||||
|
||||
conn, err := net.Dial("tcp", srv.listener.Addr().String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
require.NoError(t, conn.SetDeadline(time.Now().Add(5*time.Second)))
|
||||
|
||||
// Reading even one byte must EOF: the source IP (127.0.0.1) is outside
|
||||
// the configured overlay, so handleConnection closes before writing.
|
||||
var b [1]byte
|
||||
_, err = io.ReadFull(conn, b[:])
|
||||
require.Error(t, err, "non-overlay client must see EOF, not an RFB greeting")
|
||||
}
|
||||
|
||||
func TestIsAllowedSource(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
localAddr netip.Addr
|
||||
network netip.Prefix
|
||||
remote net.Addr
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "non-tcp address rejected",
|
||||
localAddr: netip.MustParseAddr("10.99.99.1"),
|
||||
network: netip.MustParsePrefix("10.99.0.0/16"),
|
||||
remote: &net.UDPAddr{IP: net.ParseIP("10.99.99.2"), Port: 1234},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "own IP rejected",
|
||||
localAddr: netip.MustParseAddr("10.99.99.1"),
|
||||
network: netip.MustParsePrefix("10.99.0.0/16"),
|
||||
remote: &net.TCPAddr{IP: net.ParseIP("10.99.99.1"), Port: 5900},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "non-overlay IP rejected",
|
||||
localAddr: netip.MustParseAddr("10.99.99.1"),
|
||||
network: netip.MustParsePrefix("10.99.0.0/16"),
|
||||
remote: &net.TCPAddr{IP: net.ParseIP("192.168.1.1"), Port: 5900},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "overlay IP allowed",
|
||||
localAddr: netip.MustParseAddr("10.99.99.1"),
|
||||
network: netip.MustParsePrefix("10.99.0.0/16"),
|
||||
remote: &net.TCPAddr{IP: net.ParseIP("10.99.99.2"), Port: 5900},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "v4-mapped v6 in overlay allowed (unmapped)",
|
||||
localAddr: netip.MustParseAddr("10.99.99.1"),
|
||||
network: netip.MustParsePrefix("10.99.0.0/16"),
|
||||
remote: &net.TCPAddr{IP: net.ParseIP("::ffff:10.99.99.2"), Port: 5900},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "loopback allowed only when local is loopback",
|
||||
localAddr: netip.MustParseAddr("127.0.0.1"),
|
||||
network: netip.MustParsePrefix("127.0.0.0/8"),
|
||||
remote: &net.TCPAddr{IP: net.ParseIP("127.0.0.5"), Port: 5900},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "invalid network rejected (fail-closed)",
|
||||
localAddr: netip.MustParseAddr("10.99.99.1"),
|
||||
network: netip.Prefix{},
|
||||
remote: &net.TCPAddr{IP: net.ParseIP("10.99.99.2"), Port: 5900},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
|
||||
srv.localAddr = tc.localAddr
|
||||
srv.network = tc.network
|
||||
assert.Equal(t, tc.want, srv.isAllowedSource(tc.remote))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStart_InvalidNetworkRejected(t *testing.T) {
|
||||
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
|
||||
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
||||
err := srv.Start(t.Context(), addr, netip.Prefix{})
|
||||
require.Error(t, err, "Start must refuse an invalid overlay prefix")
|
||||
assert.Contains(t, err.Error(), "invalid overlay network prefix")
|
||||
}
|
||||
|
||||
func TestAgentToken_MismatchClosesConnection(t *testing.T) {
|
||||
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
|
||||
srv.SetDisableAuth(true)
|
||||
srv.SetAgentToken("deadbeefcafebabe")
|
||||
|
||||
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
||||
network := netip.MustParsePrefix("127.0.0.0/8")
|
||||
require.NoError(t, srv.Start(t.Context(), addr, network))
|
||||
srv.localAddr = netip.MustParseAddr("10.99.99.1")
|
||||
t.Cleanup(func() { _ = srv.Stop() })
|
||||
|
||||
conn, err := net.Dial("tcp", srv.listener.Addr().String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
require.NoError(t, conn.SetDeadline(time.Now().Add(10*time.Second)))
|
||||
|
||||
// Send a wrong token of the right length (8 bytes hex-decoded).
|
||||
if _, err := conn.Write([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}); err != nil {
|
||||
// Server may already have closed; either way the read below must EOF.
|
||||
_ = err
|
||||
}
|
||||
|
||||
// Server must close without sending the RFB greeting.
|
||||
var version [12]byte
|
||||
_, err = io.ReadFull(conn, version[:])
|
||||
require.Error(t, err, "server must close the connection on bad agent token")
|
||||
}
|
||||
|
||||
func TestAgentToken_MatchAllowsHandshake(t *testing.T) {
|
||||
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
|
||||
srv.SetDisableAuth(true)
|
||||
const tokenHex = "deadbeefcafebabe"
|
||||
srv.SetAgentToken(tokenHex)
|
||||
token, err := hex.DecodeString(tokenHex)
|
||||
require.NoError(t, err)
|
||||
|
||||
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
||||
network := netip.MustParsePrefix("127.0.0.0/8")
|
||||
require.NoError(t, srv.Start(t.Context(), addr, network))
|
||||
srv.localAddr = netip.MustParseAddr("10.99.99.1")
|
||||
t.Cleanup(func() { _ = srv.Stop() })
|
||||
|
||||
conn, err := net.Dial("tcp", srv.listener.Addr().String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
require.NoError(t, conn.SetDeadline(time.Now().Add(10*time.Second)))
|
||||
|
||||
_, err = conn.Write(token)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Send session header so handleConnection can proceed past readConnectionHeader.
|
||||
header := make([]byte, 11) // ModeAttach + usernameLen=0 + sessionID=0 + width=0 + height=0
|
||||
header[0] = ModeAttach
|
||||
_, err = conn.Write(header)
|
||||
require.NoError(t, err)
|
||||
|
||||
// With a matching token the server proceeds to the RFB greeting.
|
||||
var version [12]byte
|
||||
_, err = io.ReadFull(conn, version[:])
|
||||
require.NoError(t, err, "server must keep the connection open after a valid agent token")
|
||||
assert.Equal(t, "RFB 003.008\n", string(version[:]))
|
||||
}
|
||||
|
||||
func TestSessionMode_RejectedWhenNoVMGR(t *testing.T) {
|
||||
// Default platformSessionManager() on non-Linux returns nil, so ModeSession
|
||||
// must be rejected with the UNSUPPORTED reason rather than crashing.
|
||||
srv := New(&testCapturer{}, &StubInputInjector{}, nil)
|
||||
srv.SetDisableAuth(true)
|
||||
|
||||
addr := netip.MustParseAddrPort("127.0.0.1:0")
|
||||
network := netip.MustParsePrefix("127.0.0.0/8")
|
||||
require.NoError(t, srv.Start(t.Context(), addr, network))
|
||||
srv.localAddr = netip.MustParseAddr("10.99.99.1")
|
||||
// Force vmgr to nil regardless of platform so the test is deterministic.
|
||||
srv.vmgr = nil
|
||||
t.Cleanup(func() { _ = srv.Stop() })
|
||||
|
||||
conn, err := net.Dial("tcp", srv.listener.Addr().String())
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
require.NoError(t, conn.SetDeadline(time.Now().Add(10*time.Second)))
|
||||
|
||||
// ModeSession with no username, so we exit on the vmgr==nil branch
|
||||
// before username validation runs.
|
||||
header := []byte{ModeSession, 0, 0, 0, 0}
|
||||
_, err = conn.Write(header)
|
||||
require.NoError(t, err)
|
||||
|
||||
var version [12]byte
|
||||
_, err = io.ReadFull(conn, version[:])
|
||||
require.NoError(t, err)
|
||||
_, err = conn.Write(version[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
var numTypes [1]byte
|
||||
_, err = io.ReadFull(conn, numTypes[:])
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, byte(0), numTypes[0])
|
||||
|
||||
var reasonLen [4]byte
|
||||
_, err = io.ReadFull(conn, reasonLen[:])
|
||||
require.NoError(t, err)
|
||||
reason := make([]byte, binary.BigEndian.Uint32(reasonLen[:]))
|
||||
_, err = io.ReadFull(conn, reason)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, string(reason), RejectCodeUnsupportedOS)
|
||||
}
|
||||
322
client/vnc/server/server_windows.go
Normal file
322
client/vnc/server/server_windows.go
Normal file
@@ -0,0 +1,322 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
)
|
||||
|
||||
var (
|
||||
sasDLL = windows.NewLazySystemDLL("sas.dll")
|
||||
procSendSAS = sasDLL.NewProc("SendSAS")
|
||||
|
||||
procConvertStringSecurityDescriptorToSecurityDescriptor = advapi32.NewProc("ConvertStringSecurityDescriptorToSecurityDescriptorW")
|
||||
)
|
||||
|
||||
// sasSecurityAttributes builds a SECURITY_ATTRIBUTES that grants
|
||||
// EVENT_MODIFY_STATE only to the SYSTEM account, preventing unprivileged
|
||||
// local processes from triggering the Secure Attention Sequence.
|
||||
func sasSecurityAttributes() (*windows.SecurityAttributes, error) {
|
||||
// SDDL: grant full access to SYSTEM (creates/waits) and EVENT_MODIFY_STATE
|
||||
// to the interactive user (IU) so the VNC agent in the console session can
|
||||
// signal it. Other local users and network users are denied.
|
||||
sddl, err := windows.UTF16PtrFromString("D:(A;;GA;;;SY)(A;;0x0002;;;IU)")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var sd uintptr
|
||||
r, _, lerr := procConvertStringSecurityDescriptorToSecurityDescriptor.Call(
|
||||
uintptr(unsafe.Pointer(sddl)),
|
||||
1, // SDDL_REVISION_1
|
||||
uintptr(unsafe.Pointer(&sd)),
|
||||
0,
|
||||
)
|
||||
if r == 0 {
|
||||
return nil, lerr
|
||||
}
|
||||
return &windows.SecurityAttributes{
|
||||
Length: uint32(unsafe.Sizeof(windows.SecurityAttributes{})),
|
||||
SecurityDescriptor: (*windows.SECURITY_DESCRIPTOR)(unsafe.Pointer(sd)),
|
||||
InheritHandle: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// sasOriginalState tracks the SoftwareSASGeneration value present before we
|
||||
// changed it, so disableSoftwareSAS can restore the machine to its prior
|
||||
// state on shutdown instead of leaving the policy enabled.
|
||||
type sasOriginalState struct {
|
||||
had bool // true if the value existed before we wrote
|
||||
value uint32 // its prior DWORD value, if had == true
|
||||
}
|
||||
|
||||
var savedSASState sasOriginalState
|
||||
|
||||
// enableSoftwareSAS sets the SoftwareSASGeneration registry key to allow
|
||||
// services to trigger the Secure Attention Sequence via SendSAS. Without this,
|
||||
// SendSAS silently does nothing on most Windows editions. The original value
|
||||
// is snapshotted so disableSoftwareSAS can put the system back as it was.
|
||||
func enableSoftwareSAS() {
|
||||
key, _, err := registry.CreateKey(
|
||||
registry.LOCAL_MACHINE,
|
||||
`SOFTWARE\Microsoft\Windows\CurrentVersion\Policies\System`,
|
||||
registry.SET_VALUE|registry.QUERY_VALUE,
|
||||
)
|
||||
if err != nil {
|
||||
log.Warnf("open SoftwareSASGeneration registry key: %v", err)
|
||||
return
|
||||
}
|
||||
defer key.Close()
|
||||
|
||||
if prev, _, err := key.GetIntegerValue("SoftwareSASGeneration"); err == nil {
|
||||
savedSASState = sasOriginalState{had: true, value: uint32(prev)}
|
||||
} else {
|
||||
savedSASState = sasOriginalState{had: false}
|
||||
}
|
||||
|
||||
if err := key.SetDWordValue("SoftwareSASGeneration", 1); err != nil {
|
||||
log.Warnf("set SoftwareSASGeneration: %v", err)
|
||||
return
|
||||
}
|
||||
log.Debug("SoftwareSASGeneration registry key set to 1 (services allowed)")
|
||||
}
|
||||
|
||||
// disableSoftwareSAS restores the SoftwareSASGeneration value to its
|
||||
// pre-enable state. Idempotent; safe to call when enableSoftwareSAS never ran.
|
||||
func disableSoftwareSAS() {
|
||||
key, err := registry.OpenKey(
|
||||
registry.LOCAL_MACHINE,
|
||||
`SOFTWARE\Microsoft\Windows\CurrentVersion\Policies\System`,
|
||||
registry.SET_VALUE,
|
||||
)
|
||||
if err != nil {
|
||||
log.Debugf("open SoftwareSASGeneration for restore: %v", err)
|
||||
return
|
||||
}
|
||||
defer key.Close()
|
||||
|
||||
if savedSASState.had {
|
||||
if err := key.SetDWordValue("SoftwareSASGeneration", savedSASState.value); err != nil {
|
||||
log.Warnf("restore SoftwareSASGeneration to %d: %v", savedSASState.value, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err := key.DeleteValue("SoftwareSASGeneration"); err != nil {
|
||||
log.Debugf("delete SoftwareSASGeneration: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// startSASListener creates a named event with a restricted DACL and waits for
|
||||
// the VNC input injector to signal it. When signaled, it calls SendSAS(FALSE)
|
||||
// from Session 0 to trigger the Secure Attention Sequence (Ctrl+Alt+Del).
|
||||
// Only SYSTEM processes can open the event.
|
||||
//
|
||||
// sas.dll / SendSAS is part of the Desktop Experience feature: present on
|
||||
// client SKUs (Win10/11) and Server SKUs with Desktop Experience installed,
|
||||
// missing on Server Core. We probe for the symbol at startup; if absent we
|
||||
// don't register the listener and the agent will silently drop SAS keysyms,
|
||||
// rather than panicking the entire service every time the user clicks
|
||||
// Ctrl+Alt+Del.
|
||||
func startSASListener(ctx context.Context) {
|
||||
ev, ok := createSASEvent()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
log.Info("SAS listener ready (Session 0)")
|
||||
go runSASListenerLoop(ctx, ev)
|
||||
}
|
||||
|
||||
// createSASEvent prepares the named event handle on which the SAS listener
|
||||
// waits for client signals. Returns ok=false (with the failure already
|
||||
// logged) when the platform doesn't support SAS or the event cannot be
|
||||
// created; the caller must not spawn the listener goroutine in that case.
|
||||
func createSASEvent() (windows.Handle, bool) {
|
||||
if err := procSendSAS.Find(); err != nil {
|
||||
log.Warnf("SAS unavailable on this Windows SKU (sas.dll/SendSAS not present): %v", err)
|
||||
return 0, false
|
||||
}
|
||||
enableSoftwareSAS()
|
||||
namePtr, err := windows.UTF16PtrFromString(sasEventName)
|
||||
if err != nil {
|
||||
log.Warnf("SAS listener UTF16: %v", err)
|
||||
return 0, false
|
||||
}
|
||||
sa, err := sasSecurityAttributes()
|
||||
if err != nil {
|
||||
log.Warnf("build SAS security descriptor: %v", err)
|
||||
return 0, false
|
||||
}
|
||||
ev, err := windows.CreateEvent(sa, 0, 0, namePtr)
|
||||
if err != nil {
|
||||
log.Warnf("SAS CreateEvent: %v", err)
|
||||
return 0, false
|
||||
}
|
||||
return ev, true
|
||||
}
|
||||
|
||||
// runSASListenerLoop blocks on ev and invokes SendSAS each time it is
|
||||
// signalled, until ctx is cancelled. Recovers from panics inside SendSAS so
|
||||
// a future ABI surprise doesn't tear down the service.
|
||||
func runSASListenerLoop(ctx context.Context, ev windows.Handle) {
|
||||
defer func() { _ = windows.CloseHandle(ev) }()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Warnf("SAS listener recovered from panic: %v", r)
|
||||
}
|
||||
}()
|
||||
const pollMillis = 500
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
ret, _ := windows.WaitForSingleObject(ev, pollMillis)
|
||||
if ret != windows.WAIT_OBJECT_0 {
|
||||
continue
|
||||
}
|
||||
r, _, sasErr := procSendSAS.Call(0) // FALSE = not from service desktop
|
||||
if r == 0 {
|
||||
log.Warnf("SendSAS: %v", sasErr)
|
||||
continue
|
||||
}
|
||||
log.Info("SendSAS called from Session 0")
|
||||
}
|
||||
}
|
||||
|
||||
// enablePrivilege enables a named privilege on the current process token.
|
||||
func enablePrivilege(name string) error {
|
||||
var token windows.Token
|
||||
if err := windows.OpenProcessToken(windows.CurrentProcess(),
|
||||
windows.TOKEN_ADJUST_PRIVILEGES|windows.TOKEN_QUERY, &token); err != nil {
|
||||
return err
|
||||
}
|
||||
defer token.Close()
|
||||
|
||||
var luid windows.LUID
|
||||
namePtr, err := windows.UTF16PtrFromString(name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("UTF16 privilege name: %w", err)
|
||||
}
|
||||
if err := windows.LookupPrivilegeValue(nil, namePtr, &luid); err != nil {
|
||||
return err
|
||||
}
|
||||
tp := windows.Tokenprivileges{PrivilegeCount: 1}
|
||||
tp.Privileges[0].Luid = luid
|
||||
tp.Privileges[0].Attributes = windows.SE_PRIVILEGE_ENABLED
|
||||
return windows.AdjustTokenPrivileges(token, false, &tp, 0, nil, nil)
|
||||
}
|
||||
|
||||
func (s *Server) platformSessionManager() virtualSessionManager {
|
||||
return nil
|
||||
}
|
||||
|
||||
// platformShutdown restores any machine state mutated by platformInit.
|
||||
func (s *Server) platformShutdown() {
|
||||
disableSoftwareSAS()
|
||||
}
|
||||
|
||||
// platformInit starts the SAS listener and enables privileges needed for
|
||||
// Session 0 operations (agent spawning, SendSAS).
|
||||
func (s *Server) platformInit() {
|
||||
for _, priv := range []string{"SeTcbPrivilege", "SeAssignPrimaryTokenPrivilege"} {
|
||||
if err := enablePrivilege(priv); err != nil {
|
||||
log.Debugf("enable %s: %v", priv, err)
|
||||
}
|
||||
}
|
||||
startSASListener(s.ctx)
|
||||
}
|
||||
|
||||
// serviceAcceptLoop runs in Session 0. It validates the source IP and
|
||||
// hands accepted connections to handleServiceConnection, which runs the
|
||||
// Noise_IK handshake before proxying to the user-session agent.
|
||||
func (s *Server) serviceAcceptLoop() {
|
||||
|
||||
sm := newSessionManager(agentPort)
|
||||
go sm.run()
|
||||
|
||||
log.Infof("service mode, proxying connections to agent on 127.0.0.1:%d", agentPort)
|
||||
|
||||
for {
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
sm.Stop()
|
||||
return
|
||||
default:
|
||||
}
|
||||
s.log.Debugf("accept VNC connection: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if !s.tryAcquireConnSlot() {
|
||||
s.log.Warnf("rejecting VNC connection from %s: %d concurrent connections in flight", conn.RemoteAddr(), maxConcurrentVNCConns)
|
||||
_ = conn.Close()
|
||||
continue
|
||||
}
|
||||
enableTCPKeepAlive(conn, s.log)
|
||||
conn = newMetricsConn(conn, s.sessionRecorder)
|
||||
s.trackConn(conn)
|
||||
go func(c net.Conn) {
|
||||
defer s.releaseConnSlot()
|
||||
defer s.untrackConn(c)
|
||||
s.handleServiceConnection(c, sm)
|
||||
}(conn)
|
||||
}
|
||||
}
|
||||
|
||||
// handleServiceConnection runs the connection-header handshake (including
|
||||
// Noise_IK), then proxies the connection (with header bytes replayed) to
|
||||
// the agent listening on loopback.
|
||||
func (s *Server) handleServiceConnection(conn net.Conn, sm *sessionManager) {
|
||||
connLog := s.log.WithField("remote", conn.RemoteAddr().String())
|
||||
|
||||
if !s.isAllowedSource(conn.RemoteAddr()) {
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
var headerBuf bytes.Buffer
|
||||
tee := io.TeeReader(conn, &headerBuf)
|
||||
teeConn := &prefixConn{Reader: tee, Conn: conn}
|
||||
|
||||
header, err := s.readConnectionHeader(teeConn)
|
||||
if err != nil {
|
||||
connLog.Debugf("read connection header: %v", err)
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
if !s.disableAuth {
|
||||
if _, err := s.authenticateSession(header); err != nil {
|
||||
rejectConnection(conn, codeMessage(RejectCodeAuthForbidden, err.Error()))
|
||||
connLog.Warnf("auth rejected: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
s.registerConnAuth(conn, header)
|
||||
|
||||
// Replay buffered header bytes + remaining stream to the agent.
|
||||
replayConn := &prefixConn{
|
||||
Reader: io.MultiReader(&headerBuf, conn),
|
||||
Conn: conn,
|
||||
}
|
||||
proxyToAgent(s.ctx, replayConn, agentPort, sm.AuthToken())
|
||||
}
|
||||
|
||||
// prefixConn wraps a net.Conn, overriding Read to use a different reader.
|
||||
type prefixConn struct {
|
||||
io.Reader
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (p *prefixConn) Read(b []byte) (int, error) {
|
||||
return p.Reader.Read(b)
|
||||
}
|
||||
21
client/vnc/server/server_x11.go
Normal file
21
client/vnc/server/server_x11.go
Normal file
@@ -0,0 +1,21 @@
|
||||
//go:build unix && !darwin && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
func (s *Server) platformInit() {
|
||||
// no-op on X11
|
||||
}
|
||||
|
||||
// serviceAcceptLoop is not supported on Linux.
|
||||
func (s *Server) serviceAcceptLoop() {
|
||||
s.log.Warn("service mode not supported on Linux, falling back to direct mode")
|
||||
s.acceptLoop()
|
||||
}
|
||||
|
||||
func (s *Server) platformSessionManager() virtualSessionManager {
|
||||
return newSessionManager(s.log)
|
||||
}
|
||||
|
||||
func (s *Server) platformShutdown() {
|
||||
// no-op on this platform
|
||||
}
|
||||
633
client/vnc/server/session.go
Normal file
633
client/vnc/server/session.go
Normal file
@@ -0,0 +1,633 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"image"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
readDeadline = 60 * time.Second
|
||||
maxCutTextBytes = 1 << 20 // 1 MiB
|
||||
)
|
||||
|
||||
const tileSize = 64 // pixels per tile for dirty-rect detection
|
||||
|
||||
// fullFramePromoteNum/Den trigger full-frame encoding when the dirty area
|
||||
// exceeds num/den of the screen. Once past the crossover (benchmarks put it
|
||||
// around 60% at 1080p) a single zlib rect is faster than many per-tile
|
||||
// encodes AND produces about the same wire bytes: the per-tile path keeps
|
||||
// restarting zlib dictionaries and re-emitting rect headers.
|
||||
const (
|
||||
fullFramePromoteNum = 60
|
||||
fullFramePromoteDen = 100
|
||||
)
|
||||
|
||||
// bboxPromoteDensityPct collapses the coalesced rect list down to its
|
||||
// bounding box when the dirty pixels occupy at least this fraction of the
|
||||
// bbox. Catches the "windowed video" case where the player area dirties as
|
||||
// a dense block but is split into many sibling rects by overlays or by
|
||||
// non-uniform tile coverage. Sending one JPEG over the bbox beats sending
|
||||
// dozens of small JPEGs that each carry their own header and Tight stream
|
||||
// restart.
|
||||
const (
|
||||
bboxPromoteDensityPct = 70
|
||||
// bboxPromoteMinArea avoids promoting a handful of small scattered
|
||||
// rects whose bbox would span most of the screen and pull in mostly
|
||||
// clean pixels.
|
||||
bboxPromoteMinArea = tileSize * tileSize * 16
|
||||
)
|
||||
|
||||
type session struct {
|
||||
conn net.Conn
|
||||
capturer ScreenCapturer
|
||||
injector InputInjector
|
||||
serverW int
|
||||
serverH int
|
||||
desktopName string
|
||||
log *log.Entry
|
||||
|
||||
writeMu sync.Mutex
|
||||
// encMu guards the negotiated pixel format and encoding state below.
|
||||
// messageLoop writes these on SetPixelFormat/SetEncodings, which RFB
|
||||
// clients may send at any time after the handshake, while encoderLoop
|
||||
// reads them on every frame.
|
||||
encMu sync.RWMutex
|
||||
pf clientPixelFormat
|
||||
useTight bool
|
||||
useCopyRect bool
|
||||
useZlib bool
|
||||
useHextile bool
|
||||
tight *tightState
|
||||
zlib *zlibState
|
||||
copyRectDet *copyRectDetector
|
||||
// Pseudo-encodings the client advertised support for. Updated under
|
||||
// encMu by handleSetEncodings and read by the encoder goroutine.
|
||||
clientSupportsDesktopSize bool
|
||||
clientSupportsExtendedDesktopSize bool
|
||||
clientSupportsDesktopName bool
|
||||
clientSupportsLastRect bool
|
||||
clientSupportsQEMUKey bool
|
||||
clientSupportsExtClipboard bool
|
||||
clientSupportsCursor bool
|
||||
// clientSupportsExtMouseButtons is set when the client advertises the
|
||||
// ExtendedMouseButtons pseudo-encoding (-316). Once the server emits
|
||||
// the ack rect, the client switches its pointer events to the 6-byte
|
||||
// extended format that carries back/forward buttons in a second mask
|
||||
// byte. Without this gate the byte after the type field would still
|
||||
// be a standard 7-bit mask and our parser must not look further.
|
||||
clientSupportsExtMouseButtons bool
|
||||
// extMouseAckSent is set once we've emitted the pseudo-rect ack that
|
||||
// flips the client into extended-pointer mode. Sticky for the
|
||||
// session because the client only needs to see it once.
|
||||
extMouseAckSent bool
|
||||
extClipCapsSent bool
|
||||
// lastCursorSerial is the serial of the cursor sprite last emitted.
|
||||
// The encoder re-queries the source each cycle and only emits when
|
||||
// the serial changes.
|
||||
lastCursorSerial uint64
|
||||
// cursorSourceFailed latches a permanent failure from the cursor
|
||||
// source so the encoder stops polling for the rest of the session.
|
||||
// Reset on SetEncodings so a reconnect can retry.
|
||||
cursorSourceFailed bool
|
||||
// showRemoteCursor switches the encoder to compositing the server
|
||||
// cursor sprite into the captured framebuffer at the remote position
|
||||
// instead of emitting the Cursor pseudo-encoding. Toggled by the
|
||||
// client via clientNetbirdShowRemoteCursor.
|
||||
showRemoteCursor bool
|
||||
// cursorWarnOnce throttles the diagnostic emitted when remote-cursor
|
||||
// compositing falls back to a no-op (capturer cannot supply a sprite
|
||||
// or position). One line per session is enough to point at the cause.
|
||||
cursorWarnOnce sync.Once
|
||||
// clientJPEGQuality and clientZlibLevel hold the 0..9 levels the client
|
||||
// advertised via the QualityLevel / CompressLevel pseudo-encodings, or
|
||||
// -1 when the client has not expressed a preference. Applied to the
|
||||
// tight encoder state after every SetEncodings.
|
||||
clientJPEGQuality int
|
||||
clientZlibLevel int
|
||||
// prevFrame, curFrame and idleFrames live on the encoder goroutine and
|
||||
// must not be touched elsewhere. curFrame holds a session-owned copy of
|
||||
// the capturer's latest frame so the encoder works on a stable buffer
|
||||
// even when the capturer double-buffers and recycles memory underneath.
|
||||
prevFrame *image.RGBA
|
||||
curFrame *image.RGBA
|
||||
idleFrames int
|
||||
|
||||
// captureErrLast throttles "capture (transient)" logs while the
|
||||
// capturer is in a sustained failure state (e.g. X server died but a
|
||||
// client is still connected). Owned by the encoder goroutine.
|
||||
captureErrLast time.Time
|
||||
captureErrSeen bool
|
||||
|
||||
// encodeCh carries framebuffer-update requests from the read loop to the
|
||||
// encoder goroutine. Buffered size 1: RFB clients have one outstanding
|
||||
// request at a time, so a new request always replaces any pending one.
|
||||
encodeCh chan fbRequest
|
||||
|
||||
// pointerMu guards the cached last cursor position used by
|
||||
// releaseStickyInput so the disconnect-time button-release event
|
||||
// targets the cursor's current spot instead of warping to (0, 0).
|
||||
pointerMu sync.Mutex
|
||||
lastPointerX int
|
||||
lastPointerY int
|
||||
}
|
||||
|
||||
type fbRequest struct {
|
||||
incremental bool
|
||||
}
|
||||
|
||||
func (s *session) addr() string { return s.conn.RemoteAddr().String() }
|
||||
|
||||
// serve runs the full RFB session lifecycle.
|
||||
func (s *session) serve() {
|
||||
defer s.conn.Close()
|
||||
s.pf = defaultClientPixelFormat()
|
||||
s.clientJPEGQuality = -1
|
||||
s.clientZlibLevel = -1
|
||||
s.encodeCh = make(chan fbRequest, 1)
|
||||
|
||||
if err := s.handshake(); err != nil {
|
||||
s.log.Warnf("handshake with %s: %v", s.addr(), err)
|
||||
return
|
||||
}
|
||||
s.log.Infof("client connected: %s", s.addr())
|
||||
|
||||
// On any exit path (clean disconnect, transport error, panic) release
|
||||
// modifier keys and mouse buttons so the host doesn't end up with
|
||||
// Shift/Ctrl/Alt or a mouse button stuck because the client dropped
|
||||
// while holding them.
|
||||
defer s.releaseStickyInput()
|
||||
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
go s.clipboardPoll(done)
|
||||
|
||||
encoderDone := make(chan struct{})
|
||||
go s.encoderLoop(encoderDone)
|
||||
defer func() {
|
||||
close(s.encodeCh)
|
||||
<-encoderDone
|
||||
}()
|
||||
|
||||
if err := s.messageLoop(); err != nil && err != io.EOF {
|
||||
s.log.Warnf("client %s disconnected: %v", s.addr(), err)
|
||||
} else {
|
||||
s.log.Infof("client disconnected: %s", s.addr())
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) handshake() error {
|
||||
// Send protocol version.
|
||||
if _, err := io.WriteString(s.conn, rfbProtocolVersion); err != nil {
|
||||
return fmt.Errorf("send version: %w", err)
|
||||
}
|
||||
|
||||
// Read client version.
|
||||
var clientVer [12]byte
|
||||
if _, err := io.ReadFull(s.conn, clientVer[:]); err != nil {
|
||||
return fmt.Errorf("read client version: %w", err)
|
||||
}
|
||||
|
||||
// Send supported security types.
|
||||
if err := s.sendSecurityTypes(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Read chosen security type.
|
||||
var secType [1]byte
|
||||
if _, err := io.ReadFull(s.conn, secType[:]); err != nil {
|
||||
return fmt.Errorf("read security type: %w", err)
|
||||
}
|
||||
|
||||
if err := s.handleSecurity(secType[0]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Read ClientInit.
|
||||
var clientInit [1]byte
|
||||
if _, err := io.ReadFull(s.conn, clientInit[:]); err != nil {
|
||||
return fmt.Errorf("read ClientInit: %w", err)
|
||||
}
|
||||
|
||||
return s.sendServerInit()
|
||||
}
|
||||
|
||||
// sendSecurityTypes advertises only secNone. Authentication and access
|
||||
// control happen in the NetBird connection header (Noise_IK handshake,
|
||||
// mode, username) that precedes the RFB handshake; the protocol-level
|
||||
// password scheme is not supported.
|
||||
func (s *session) sendSecurityTypes() error {
|
||||
_, err := s.conn.Write([]byte{1, secNone})
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *session) handleSecurity(secType byte) error {
|
||||
if secType != secNone {
|
||||
return fmt.Errorf("unsupported security type: %d", secType)
|
||||
}
|
||||
return binary.Write(s.conn, binary.BigEndian, uint32(0))
|
||||
}
|
||||
|
||||
func (s *session) sendServerInit() error {
|
||||
desktop := s.desktopName
|
||||
if desktop == "" {
|
||||
desktop = "NetBird VNC"
|
||||
}
|
||||
name := []byte(desktop)
|
||||
buf := make([]byte, 0, 4+16+4+len(name))
|
||||
|
||||
// Framebuffer width and height.
|
||||
buf = append(buf, byte(s.serverW>>8), byte(s.serverW))
|
||||
buf = append(buf, byte(s.serverH>>8), byte(s.serverH))
|
||||
|
||||
// Server pixel format.
|
||||
buf = append(buf, serverPixelFormat[:]...)
|
||||
|
||||
// Desktop name.
|
||||
buf = append(buf,
|
||||
byte(len(name)>>24), byte(len(name)>>16),
|
||||
byte(len(name)>>8), byte(len(name)),
|
||||
)
|
||||
buf = append(buf, name...)
|
||||
|
||||
_, err := s.conn.Write(buf)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *session) messageLoop() error {
|
||||
for {
|
||||
var msgType [1]byte
|
||||
if err := s.conn.SetDeadline(time.Now().Add(readDeadline)); err != nil {
|
||||
return fmt.Errorf("set deadline: %w", err)
|
||||
}
|
||||
if _, err := io.ReadFull(s.conn, msgType[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var err error
|
||||
switch msgType[0] {
|
||||
case clientSetPixelFormat:
|
||||
err = s.handleSetPixelFormat()
|
||||
case clientSetEncodings:
|
||||
err = s.handleSetEncodings()
|
||||
case clientFramebufferUpdateRequest:
|
||||
err = s.handleFBUpdateRequest()
|
||||
case clientKeyEvent:
|
||||
err = s.handleKeyEvent()
|
||||
case clientPointerEvent:
|
||||
err = s.handlePointerEvent()
|
||||
case clientCutText:
|
||||
err = s.handleCutText()
|
||||
case clientQEMUMessage:
|
||||
err = s.handleQEMUMessage()
|
||||
case clientNetbirdTypeText:
|
||||
err = s.handleTypeText()
|
||||
case clientNetbirdShowRemoteCursor:
|
||||
err = s.handleShowRemoteCursor()
|
||||
default:
|
||||
return fmt.Errorf("unknown client message type: %d", msgType[0])
|
||||
}
|
||||
// Clear the deadline only after the full message has been read and
|
||||
// processed so payload reads in the handlers stay bounded.
|
||||
_ = s.conn.SetDeadline(time.Time{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) handleSetPixelFormat() error {
|
||||
var buf [19]byte // 3 padding + 16 pixel format
|
||||
if _, err := io.ReadFull(s.conn, buf[:]); err != nil {
|
||||
return fmt.Errorf("read SetPixelFormat: %w", err)
|
||||
}
|
||||
pf, err := parsePixelFormat(buf[3:19])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.encMu.Lock()
|
||||
s.pf = pf
|
||||
s.encMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) handleSetEncodings() error {
|
||||
var header [3]byte // 1 padding + 2 number-of-encodings
|
||||
if _, err := io.ReadFull(s.conn, header[:]); err != nil {
|
||||
return fmt.Errorf("read SetEncodings header: %w", err)
|
||||
}
|
||||
numEnc := binary.BigEndian.Uint16(header[1:3])
|
||||
// RFB clients advertise a handful of real encodings plus pseudo-encodings.
|
||||
// Cap to keep a malicious client from forcing a 256 KiB allocation per
|
||||
// SetEncodings message.
|
||||
const maxEncodings = 64
|
||||
if numEnc > maxEncodings {
|
||||
return fmt.Errorf("SetEncodings: too many encodings (%d)", numEnc)
|
||||
}
|
||||
buf := make([]byte, int(numEnc)*4)
|
||||
if _, err := io.ReadFull(s.conn, buf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
encs, sendExtClipCaps, sendExtMouseAck := s.applyEncodings(buf, int(numEnc))
|
||||
if len(encs) > 0 {
|
||||
s.log.Debugf("client supports encodings: %s", strings.Join(encs, ", "))
|
||||
}
|
||||
if sendExtClipCaps {
|
||||
if err := s.writeExtClipMessage(buildExtClipCaps()); err != nil {
|
||||
return fmt.Errorf("send ext clipboard caps: %w", err)
|
||||
}
|
||||
}
|
||||
if sendExtMouseAck {
|
||||
if err := s.sendExtMouseAck(); err != nil {
|
||||
return fmt.Errorf("send ext mouse ack: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyEncodings parses the SetEncodings body, updates capability flags,
|
||||
// rebuilds the tight state if quality/level changed, and reports which
|
||||
// one-shot acknowledgements still need to be sent.
|
||||
func (s *session) applyEncodings(buf []byte, numEnc int) (names []string, sendExtClipCaps, sendExtMouseAck bool) {
|
||||
s.encMu.Lock()
|
||||
defer s.encMu.Unlock()
|
||||
// Per RFC 6143 §7.5.3 each SetEncodings replaces the previous list, so
|
||||
// reset all flags before re-applying. extClipCapsSent stays sticky so
|
||||
// we don't re-emit Caps every refresh.
|
||||
s.resetEncodingCaps()
|
||||
for i := range numEnc {
|
||||
enc := int32(binary.BigEndian.Uint32(buf[i*4 : i*4+4]))
|
||||
if name := s.applyEncoding(enc); name != "" {
|
||||
names = append(names, name)
|
||||
}
|
||||
}
|
||||
s.refreshTightStateLocked()
|
||||
sendExtClipCaps = s.clientSupportsExtClipboard && !s.extClipCapsSent
|
||||
if sendExtClipCaps {
|
||||
s.extClipCapsSent = true
|
||||
}
|
||||
sendExtMouseAck = s.clientSupportsExtMouseButtons && !s.extMouseAckSent
|
||||
if sendExtMouseAck {
|
||||
s.extMouseAckSent = true
|
||||
}
|
||||
return names, sendExtClipCaps, sendExtMouseAck
|
||||
}
|
||||
|
||||
// refreshTightStateLocked reallocates s.tight when the requested quality
|
||||
// or compression level no longer matches the cached state. Caller holds
|
||||
// s.encMu.
|
||||
func (s *session) refreshTightStateLocked() {
|
||||
if !s.useTight {
|
||||
return
|
||||
}
|
||||
if s.tight != nil &&
|
||||
s.tight.qualityLevel == s.clientJPEGQuality &&
|
||||
s.tight.compressLevel == s.clientZlibLevel {
|
||||
return
|
||||
}
|
||||
// When we replace an in-use tightState the client's stream-0
|
||||
// inflater carries dictionary state from the old deflater. Carry
|
||||
// the pending-reset flag so the next Basic rect tells the client
|
||||
// to reset its inflater before decoding.
|
||||
replacing := s.tight != nil
|
||||
s.tight = newTightStateWithLevels(s.clientJPEGQuality, s.clientZlibLevel)
|
||||
if replacing {
|
||||
s.tight.pendingZlibReset = true
|
||||
}
|
||||
}
|
||||
|
||||
// resetEncodingCaps zeroes the encoding capability flags so the next pass
|
||||
// through applyEncoding reflects exactly what the client just advertised.
|
||||
// Caller holds s.encMu. tight / copyRectDet allocations are kept; their
|
||||
// runtime use is gated by the boolean flags here.
|
||||
func (s *session) resetEncodingCaps() {
|
||||
s.useTight = false
|
||||
s.useCopyRect = false
|
||||
s.useZlib = false
|
||||
s.useHextile = false
|
||||
s.clientSupportsDesktopSize = false
|
||||
s.clientSupportsExtendedDesktopSize = false
|
||||
s.clientSupportsDesktopName = false
|
||||
s.clientSupportsLastRect = false
|
||||
s.clientSupportsQEMUKey = false
|
||||
s.clientSupportsExtClipboard = false
|
||||
s.clientSupportsCursor = false
|
||||
s.clientSupportsExtMouseButtons = false
|
||||
s.cursorSourceFailed = false
|
||||
s.clientJPEGQuality = -1
|
||||
s.clientZlibLevel = -1
|
||||
}
|
||||
|
||||
// applyEncoding records a single encoding/pseudo-encoding from a SetEncodings
|
||||
// message. Returns the short name used in the debug log, or "" if the value
|
||||
// is one we don't recognise. Caller holds s.encMu.
|
||||
func (s *session) applyEncoding(enc int32) string {
|
||||
switch enc {
|
||||
case encCopyRect:
|
||||
s.useCopyRect = true
|
||||
if s.copyRectDet == nil {
|
||||
s.copyRectDet = newCopyRectDetector(tileSize)
|
||||
}
|
||||
return "copyrect"
|
||||
case pseudoEncDesktopSize:
|
||||
s.clientSupportsDesktopSize = true
|
||||
return "desktop-size"
|
||||
case pseudoEncExtendedDesktopSize:
|
||||
s.clientSupportsExtendedDesktopSize = true
|
||||
return "ext-desktop-size"
|
||||
case pseudoEncDesktopName:
|
||||
s.clientSupportsDesktopName = true
|
||||
return "desktop-name"
|
||||
case pseudoEncLastRect:
|
||||
s.clientSupportsLastRect = true
|
||||
return "last-rect"
|
||||
case pseudoEncQEMUExtendedKeyEvent:
|
||||
s.clientSupportsQEMUKey = true
|
||||
return "qemu-key"
|
||||
case pseudoEncExtendedClipboard:
|
||||
s.clientSupportsExtClipboard = true
|
||||
return "ext-clipboard"
|
||||
case pseudoEncCursor:
|
||||
s.clientSupportsCursor = true
|
||||
return "cursor"
|
||||
case pseudoEncExtendedMouseButtons:
|
||||
s.clientSupportsExtMouseButtons = true
|
||||
return "ext-mouse-buttons"
|
||||
case encTight:
|
||||
s.useTight = true
|
||||
return "tight"
|
||||
case encZlib:
|
||||
s.useZlib = true
|
||||
if s.zlib == nil {
|
||||
s.zlib = newZlibStateLevel(zlibLevelFor(-1))
|
||||
}
|
||||
return "zlib"
|
||||
case encHextile:
|
||||
s.useHextile = true
|
||||
return "hextile"
|
||||
}
|
||||
if enc >= pseudoEncQualityLevelMin && enc <= pseudoEncQualityLevelMax {
|
||||
s.clientJPEGQuality = int(enc - pseudoEncQualityLevelMin)
|
||||
return fmt.Sprintf("quality=%d", s.clientJPEGQuality)
|
||||
}
|
||||
if enc >= pseudoEncCompressLevelMin && enc <= pseudoEncCompressLevelMax {
|
||||
s.clientZlibLevel = int(enc - pseudoEncCompressLevelMin)
|
||||
return fmt.Sprintf("compress=%d", s.clientZlibLevel)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// handleFBUpdateRequest parses the request and hands it to the encoder
|
||||
// goroutine. It never blocks on capture/encode, so the input dispatch loop
|
||||
// stays responsive even when a previous frame is still being encoded.
|
||||
func (s *session) handleFBUpdateRequest() error {
|
||||
var req [9]byte
|
||||
if _, err := io.ReadFull(s.conn, req[:]); err != nil {
|
||||
return fmt.Errorf("read FBUpdateRequest: %w", err)
|
||||
}
|
||||
r := fbRequest{incremental: req[0] == 1}
|
||||
// Channel is size 1. If a request is already pending, replace it with
|
||||
// this fresher one so the encoder always works on the latest ask.
|
||||
select {
|
||||
case s.encodeCh <- r:
|
||||
default:
|
||||
select {
|
||||
case <-s.encodeCh:
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case s.encodeCh <- r:
|
||||
default:
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendDesktopName pushes a DesktopName pseudo-encoded update to the
|
||||
// client if it advertised support. Lets the client keep its window title
|
||||
// in sync with the active session (e.g. username changes after login on
|
||||
// a virtual session).
|
||||
func (s *session) SendDesktopName(name string) error {
|
||||
s.encMu.RLock()
|
||||
supported := s.clientSupportsDesktopName
|
||||
s.encMu.RUnlock()
|
||||
if !supported {
|
||||
s.desktopName = name
|
||||
return nil
|
||||
}
|
||||
s.desktopName = name
|
||||
header := make([]byte, 4)
|
||||
header[0] = serverFramebufferUpdate
|
||||
binary.BigEndian.PutUint16(header[2:4], 1)
|
||||
|
||||
body := encodeDesktopNameBody(name)
|
||||
s.writeMu.Lock()
|
||||
defer s.writeMu.Unlock()
|
||||
if _, err := s.conn.Write(header); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := s.conn.Write(body)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *session) handleKeyEvent() error {
|
||||
var data [7]byte
|
||||
if _, err := io.ReadFull(s.conn, data[:]); err != nil {
|
||||
return fmt.Errorf("read KeyEvent: %w", err)
|
||||
}
|
||||
down := data[0] == 1
|
||||
keysym := binary.BigEndian.Uint32(data[3:7])
|
||||
s.injector.InjectKey(keysym, down)
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleQEMUMessage parses one QEMU vendor message. Today we only handle
|
||||
// subtype 0 (Extended Key Event); the message itself is 12 bytes total so
|
||||
// reading 11 more after the type byte covers the layout regardless of
|
||||
// subtype, and unknown subtypes are dropped without aborting the session.
|
||||
func (s *session) handleQEMUMessage() error {
|
||||
var data [11]byte // subtype(1) + down(2) + keysym(4) + keycode(4)
|
||||
if _, err := io.ReadFull(s.conn, data[:]); err != nil {
|
||||
return fmt.Errorf("read QEMU message: %w", err)
|
||||
}
|
||||
subtype := data[0]
|
||||
if subtype != qemuSubtypeExtendedKeyEvent {
|
||||
s.log.Tracef("ignoring QEMU subtype %d", subtype)
|
||||
return nil
|
||||
}
|
||||
down := binary.BigEndian.Uint16(data[1:3]) != 0
|
||||
keysym := binary.BigEndian.Uint32(data[3:7])
|
||||
scancode := binary.BigEndian.Uint32(data[7:11])
|
||||
s.injector.InjectKeyScancode(scancode, keysym, down)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) handlePointerEvent() error {
|
||||
var data [5]byte
|
||||
if _, err := io.ReadFull(s.conn, data[:]); err != nil {
|
||||
return fmt.Errorf("read PointerEvent: %w", err)
|
||||
}
|
||||
mask := uint16(data[0])
|
||||
x := int(binary.BigEndian.Uint16(data[1:3]))
|
||||
y := int(binary.BigEndian.Uint16(data[3:5]))
|
||||
|
||||
s.encMu.RLock()
|
||||
extended := s.clientSupportsExtMouseButtons && s.extMouseAckSent
|
||||
s.encMu.RUnlock()
|
||||
if extended && mask&0x80 != 0 {
|
||||
var hi [1]byte
|
||||
if _, err := io.ReadFull(s.conn, hi[:]); err != nil {
|
||||
return fmt.Errorf("read ExtendedPointerEvent tail: %w", err)
|
||||
}
|
||||
// Strip the marker bit; bits 0..6 are the low part of the mask,
|
||||
// hi byte holds bits 7..14 (back at bit 7, forward at bit 8).
|
||||
mask = (mask & 0x7f) | uint16(hi[0])<<7
|
||||
}
|
||||
|
||||
s.pointerMu.Lock()
|
||||
s.lastPointerX = x
|
||||
s.lastPointerY = y
|
||||
s.pointerMu.Unlock()
|
||||
s.injector.InjectPointer(mask, x, y, s.serverW, s.serverH)
|
||||
return nil
|
||||
}
|
||||
|
||||
// stickyModifierKeysyms are the X11 keysyms we send "up" events for on
|
||||
// disconnect. Modifier-up while not held is a no-op on every supported
|
||||
// platform, so we can blanket-release without per-key tracking. This
|
||||
// covers the practical sticky-state bug: client drops while user is
|
||||
// holding Shift / Ctrl / Alt / Meta / Super.
|
||||
var stickyModifierKeysyms = [...]uint32{
|
||||
0xffe1, 0xffe2, // Shift_L, Shift_R
|
||||
0xffe3, 0xffe4, // Control_L, Control_R
|
||||
0xffe9, 0xffea, // Alt_L, Alt_R
|
||||
0xffe7, 0xffe8, // Meta_L, Meta_R
|
||||
0xffeb, 0xffec, // Super_L, Super_R
|
||||
0xff7e, // Mode_switch
|
||||
0xfe03, // ISO_Level3_Shift (AltGr)
|
||||
0xffe5, // Caps_Lock (release if user dropped mid-press)
|
||||
}
|
||||
|
||||
// releaseStickyInput synthesizes key-up for modifier keysyms and a
|
||||
// zero-button PointerEvent so the host doesn't end up with stuck input
|
||||
// when the client disconnects mid-press. Mouse coordinates are reused
|
||||
// from the last PointerEvent so we don't warp the cursor.
|
||||
func (s *session) releaseStickyInput() {
|
||||
for _, ks := range stickyModifierKeysyms {
|
||||
s.injector.InjectKey(ks, false)
|
||||
}
|
||||
s.pointerMu.Lock()
|
||||
x, y := s.lastPointerX, s.lastPointerY
|
||||
s.pointerMu.Unlock()
|
||||
s.injector.InjectPointer(0, x, y, s.serverW, s.serverH)
|
||||
}
|
||||
260
client/vnc/server/session_clipboard.go
Normal file
260
client/vnc/server/session_clipboard.go
Normal file
@@ -0,0 +1,260 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
||||
// clipboardPoll periodically checks the server-side clipboard and sends
|
||||
// changes to the VNC client. Only runs during active sessions.
|
||||
func (s *session) clipboardPoll(done <-chan struct{}) {
|
||||
ticker := time.NewTicker(2 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
var lastClip string
|
||||
for {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case <-ticker.C:
|
||||
text := s.injector.GetClipboard()
|
||||
if len(text) > maxCutTextBytes {
|
||||
text = text[:maxCutTextBytes]
|
||||
}
|
||||
if text == "" || text == lastClip {
|
||||
continue
|
||||
}
|
||||
lastClip = text
|
||||
s.encMu.RLock()
|
||||
ext := s.clientSupportsExtClipboard
|
||||
s.encMu.RUnlock()
|
||||
if ext {
|
||||
if err := s.writeExtClipMessage(buildExtClipNotify(extClipFormatText)); err != nil {
|
||||
s.log.Debugf("send ext clipboard notify: %v", err)
|
||||
return
|
||||
}
|
||||
} else if err := s.sendServerCutText(text); err != nil {
|
||||
s.log.Debugf("send clipboard to client: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) handleCutText() error {
|
||||
var header [7]byte // 3 padding + 4 length
|
||||
if _, err := io.ReadFull(s.conn, header[:]); err != nil {
|
||||
return fmt.Errorf("read CutText header: %w", err)
|
||||
}
|
||||
rawLen := int32(binary.BigEndian.Uint32(header[3:7]))
|
||||
if rawLen < 0 {
|
||||
// Negative length signals ExtendedClipboard; absolute value is the
|
||||
// payload size. Guard against MinInt32 overflow before negating.
|
||||
if rawLen == -2147483648 {
|
||||
return fmt.Errorf("ext clipboard payload too large")
|
||||
}
|
||||
return s.handleExtCutText(uint32(-rawLen))
|
||||
}
|
||||
length := uint32(rawLen)
|
||||
if length > maxCutTextBytes {
|
||||
return fmt.Errorf("cut text too large: %d bytes", length)
|
||||
}
|
||||
buf := make([]byte, length)
|
||||
if _, err := io.ReadFull(s.conn, buf); err != nil {
|
||||
return fmt.Errorf("read CutText payload: %w", err)
|
||||
}
|
||||
s.injector.SetClipboard(latin1ToUTF8(buf))
|
||||
return nil
|
||||
}
|
||||
|
||||
// drainBytes consumes and discards n bytes from the connection. Used to
|
||||
// skip the payload of a malformed clipboard message after we've decided
|
||||
// not to honour it, so the next message stays aligned.
|
||||
func (s *session) drainBytes(n uint32) error {
|
||||
if n == 0 {
|
||||
return nil
|
||||
}
|
||||
if _, err := io.CopyN(io.Discard, s.conn, int64(n)); err != nil {
|
||||
return fmt.Errorf("drain %d bytes: %w", n, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// latin1ToUTF8 converts an RFB ClientCutText payload (ISO 8859-1 per
|
||||
// RFC 6143 §7.5.6) into a UTF-8 string. Bytes 0x80..0xFF map to the
|
||||
// matching U+0080..U+00FF code points; passing them through Go's
|
||||
// `string([]byte)` instead would produce invalid UTF-8 that downstream
|
||||
// clipboard backends mangle.
|
||||
func latin1ToUTF8(b []byte) string {
|
||||
runes := make([]rune, len(b))
|
||||
for i, c := range b {
|
||||
runes[i] = rune(c)
|
||||
}
|
||||
return string(runes)
|
||||
}
|
||||
|
||||
// utf8ToLatin1 converts a UTF-8 string into the Latin-1 byte sequence
|
||||
// required by legacy ServerCutText (RFC 6143 §7.6.4). Runes outside
|
||||
// U+0000..U+00FF are not representable in Latin-1; we substitute '?' so the
|
||||
// peer still receives a coherent message instead of a truncated or
|
||||
// silently mojibake'd payload. ExtendedClipboard clients take a separate
|
||||
// path that preserves full UTF-8.
|
||||
func utf8ToLatin1(s string) []byte {
|
||||
out := make([]byte, 0, len(s))
|
||||
for _, r := range s {
|
||||
if r > 0xFF {
|
||||
out = append(out, '?')
|
||||
continue
|
||||
}
|
||||
out = append(out, byte(r))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// handleExtCutText parses an ExtendedClipboard message (any of Caps,
|
||||
// Notify, Request, Peek, Provide) carried as a negative-length CutText.
|
||||
// Unknown actions, oversized payloads, and formats we don't handle
|
||||
// (RTF/HTML/DIB/Files) are logged and dropped instead of aborting the
|
||||
// session: a malformed clipboard message must never cost the user their
|
||||
// VNC connection. Read errors on the socket itself still propagate.
|
||||
func (s *session) handleExtCutText(payloadLen uint32) error {
|
||||
if payloadLen < 4 {
|
||||
s.log.Debugf("ext clipboard payload too short: %d", payloadLen)
|
||||
return s.drainBytes(payloadLen)
|
||||
}
|
||||
if payloadLen > extClipMaxPayload {
|
||||
s.log.Debugf("ext clipboard payload too large: %d", payloadLen)
|
||||
return s.drainBytes(payloadLen)
|
||||
}
|
||||
buf := make([]byte, payloadLen)
|
||||
if _, err := io.ReadFull(s.conn, buf); err != nil {
|
||||
return fmt.Errorf("read ext clipboard payload: %w", err)
|
||||
}
|
||||
flags := binary.BigEndian.Uint32(buf[0:4])
|
||||
action := flags & extClipActionMask
|
||||
formats := flags & extClipFormatMask
|
||||
rest := buf[4:]
|
||||
|
||||
// A Caps message sets the Caps bit alongside one bit per action the
|
||||
// peer supports, so the action byte is multi-bit. Detect it first; the
|
||||
// remaining actions are single-bit and are dispatched after.
|
||||
if action&extClipActionCaps != 0 {
|
||||
// Client max sizes are informational for us today: we only emit
|
||||
// text and already cap it at extClipMaxText.
|
||||
return nil
|
||||
}
|
||||
|
||||
switch action {
|
||||
case extClipActionRequest:
|
||||
if formats&extClipFormatText != 0 {
|
||||
return s.sendExtClipProvideText()
|
||||
}
|
||||
return nil
|
||||
case extClipActionPeek:
|
||||
return s.writeExtClipMessage(buildExtClipNotify(extClipFormatText))
|
||||
case extClipActionNotify:
|
||||
if formats&extClipFormatText != 0 {
|
||||
return s.writeExtClipMessage(buildExtClipRequest(extClipFormatText))
|
||||
}
|
||||
return nil
|
||||
case extClipActionProvide:
|
||||
s.handleExtClipProvide(flags, rest)
|
||||
return nil
|
||||
default:
|
||||
s.log.Debugf("unknown ext clipboard action 0x%x", action)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// handleExtClipProvide decodes a Provide payload and pushes the recovered
|
||||
// text into the host clipboard. Decode errors and unsupported formats (RTF,
|
||||
// HTML, etc.) are logged and dropped so a malformed message doesn't tear
|
||||
// down the session.
|
||||
func (s *session) handleExtClipProvide(flags uint32, payload []byte) {
|
||||
if len(payload) == 0 {
|
||||
return
|
||||
}
|
||||
text, err := parseExtClipProvideText(flags, payload)
|
||||
if err != nil {
|
||||
s.log.Debugf("parse ext clipboard provide: %v", err)
|
||||
return
|
||||
}
|
||||
if text != "" {
|
||||
s.injector.SetClipboard(text)
|
||||
}
|
||||
}
|
||||
|
||||
// sendExtClipProvideText answers an inbound Request(text) with the current
|
||||
// host clipboard contents, capped to extClipMaxText.
|
||||
func (s *session) sendExtClipProvideText() error {
|
||||
text := s.injector.GetClipboard()
|
||||
if len(text) > extClipMaxText {
|
||||
text = text[:extClipMaxText]
|
||||
}
|
||||
payload, err := buildExtClipProvideText(text)
|
||||
if err != nil {
|
||||
return fmt.Errorf("build provide: %w", err)
|
||||
}
|
||||
return s.writeExtClipMessage(payload)
|
||||
}
|
||||
|
||||
// writeExtClipMessage frames an ExtendedClipboard payload as a ServerCutText
|
||||
// message with a negative length, then writes it under writeMu.
|
||||
func (s *session) writeExtClipMessage(payload []byte) error {
|
||||
if len(payload) == 0 {
|
||||
return nil
|
||||
}
|
||||
buf := make([]byte, 8+len(payload))
|
||||
buf[0] = serverCutText
|
||||
// buf[1:4] = padding (zero)
|
||||
binary.BigEndian.PutUint32(buf[4:8], uint32(-int32(len(payload))))
|
||||
copy(buf[8:], payload)
|
||||
|
||||
s.writeMu.Lock()
|
||||
_, err := s.conn.Write(buf)
|
||||
s.writeMu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
// handleTypeText handles the NetBird-specific PasteAndType message that
|
||||
// pushes host clipboard content as synthesized keystrokes, used to reach
|
||||
// secure desktops where the clipboard is isolated. Wire format mirrors
|
||||
// CutText: 3-byte padding + 4-byte length + text bytes.
|
||||
func (s *session) handleTypeText() error {
|
||||
var header [7]byte
|
||||
if _, err := io.ReadFull(s.conn, header[:]); err != nil {
|
||||
return fmt.Errorf("read TypeText header: %w", err)
|
||||
}
|
||||
length := binary.BigEndian.Uint32(header[3:7])
|
||||
if length > maxCutTextBytes {
|
||||
return fmt.Errorf("type text too large: %d bytes", length)
|
||||
}
|
||||
buf := make([]byte, length)
|
||||
if _, err := io.ReadFull(s.conn, buf); err != nil {
|
||||
return fmt.Errorf("read TypeText payload: %w", err)
|
||||
}
|
||||
s.injector.TypeText(string(buf))
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendServerCutText sends clipboard text from the server to the legacy
|
||||
// (non-ExtendedClipboard) client. The wire encoding is Latin-1; runes that
|
||||
// fall outside U+0000..U+00FF are best-effort replaced with '?' since the
|
||||
// peer cannot represent them.
|
||||
func (s *session) sendServerCutText(text string) error {
|
||||
data := utf8ToLatin1(text)
|
||||
buf := make([]byte, 8+len(data))
|
||||
buf[0] = serverCutText
|
||||
// buf[1:4] = padding (zero)
|
||||
binary.BigEndian.PutUint32(buf[4:8], uint32(len(data)))
|
||||
copy(buf[8:], data)
|
||||
|
||||
s.writeMu.Lock()
|
||||
_, err := s.conn.Write(buf)
|
||||
s.writeMu.Unlock()
|
||||
return err
|
||||
}
|
||||
88
client/vnc/server/session_cursor.go
Normal file
88
client/vnc/server/session_cursor.go
Normal file
@@ -0,0 +1,88 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"image"
|
||||
)
|
||||
|
||||
// pendingCursorRect returns the Cursor pseudo-rect for the current sprite
|
||||
// when the client negotiated the encoding and the platform exposes a
|
||||
// cursor source whose serial has changed since the last emission. A nil
|
||||
// return means "do not include a cursor rect in this FramebufferUpdate".
|
||||
func (s *session) pendingCursorRect() []byte {
|
||||
s.encMu.RLock()
|
||||
supported := s.clientSupportsCursor
|
||||
failed := s.cursorSourceFailed
|
||||
composite := s.showRemoteCursor
|
||||
lastSerial := s.lastCursorSerial
|
||||
s.encMu.RUnlock()
|
||||
if !supported || failed || composite {
|
||||
return nil
|
||||
}
|
||||
src, ok := s.capturer.(cursorSource)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
img, hotX, hotY, serial, err := src.Cursor()
|
||||
if err != nil {
|
||||
s.encMu.Lock()
|
||||
s.cursorSourceFailed = true
|
||||
s.encMu.Unlock()
|
||||
s.log.Debugf("cursor source unavailable: %v", err)
|
||||
return nil
|
||||
}
|
||||
if img == nil || serial == lastSerial {
|
||||
return nil
|
||||
}
|
||||
buf := encodeCursorPseudoRect(img, hotX, hotY)
|
||||
s.encMu.Lock()
|
||||
s.lastCursorSerial = serial
|
||||
s.encMu.Unlock()
|
||||
return buf
|
||||
}
|
||||
|
||||
// encodeCursorPseudoRect packs the cursor sprite into a Cursor pseudo
|
||||
// rectangle (RFB 7.7.4, pseudo-encoding -239). Layout: 12-byte rect header
|
||||
// followed by w*h*4 BGRX pixel bytes and a 1-bit mask of (w+7)/8 bytes per
|
||||
// row, MSB-first, with each row independently padded.
|
||||
func encodeCursorPseudoRect(img *image.RGBA, hotX, hotY int) []byte {
|
||||
w, h := img.Rect.Dx(), img.Rect.Dy()
|
||||
pixelBytes := w * h * 4
|
||||
maskStride := (w + 7) / 8
|
||||
maskBytes := maskStride * h
|
||||
buf := make([]byte, 12+pixelBytes+maskBytes)
|
||||
|
||||
binary.BigEndian.PutUint16(buf[0:2], uint16(hotX))
|
||||
binary.BigEndian.PutUint16(buf[2:4], uint16(hotY))
|
||||
binary.BigEndian.PutUint16(buf[4:6], uint16(w))
|
||||
binary.BigEndian.PutUint16(buf[6:8], uint16(h))
|
||||
enc := int32(pseudoEncCursor)
|
||||
binary.BigEndian.PutUint32(buf[8:12], uint32(enc))
|
||||
|
||||
pix := buf[12 : 12+pixelBytes]
|
||||
mask := buf[12+pixelBytes:]
|
||||
src := img.Pix
|
||||
stride := img.Stride
|
||||
for y := 0; y < h; y++ {
|
||||
row := y * stride
|
||||
dstRow := y * w * 4
|
||||
maskRow := y * maskStride
|
||||
for x := 0; x < w; x++ {
|
||||
r := src[row+x*4+0]
|
||||
g := src[row+x*4+1]
|
||||
b := src[row+x*4+2]
|
||||
a := src[row+x*4+3]
|
||||
off := dstRow + x*4
|
||||
pix[off+0] = b
|
||||
pix[off+1] = g
|
||||
pix[off+2] = r
|
||||
pix[off+3] = 0
|
||||
if a >= 0x80 {
|
||||
mask[maskRow+x/8] |= 0x80 >> (x % 8)
|
||||
}
|
||||
}
|
||||
}
|
||||
return buf
|
||||
}
|
||||
586
client/vnc/server/session_encode.go
Normal file
586
client/vnc/server/session_encode.go
Normal file
@@ -0,0 +1,586 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"image"
|
||||
"time"
|
||||
)
|
||||
|
||||
// encoderLoop owns the capture → diff → encode → write pipeline. Running it
|
||||
// off the read loop prevents a slow encode (zlib full-frame, many dirty
|
||||
// tiles) from blocking inbound input events.
|
||||
func (s *session) encoderLoop(done chan<- struct{}) {
|
||||
defer close(done)
|
||||
for req := range s.encodeCh {
|
||||
if err := s.processFBRequest(req); err != nil {
|
||||
s.log.Debugf("encode: %v", err)
|
||||
// On write/capture error, close the connection so messageLoop
|
||||
// exits and the session terminates cleanly.
|
||||
s.conn.Close()
|
||||
drainRequests(s.encodeCh)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) processFBRequest(req fbRequest) error {
|
||||
// Watch for resolution changes between cycles. When the capturer
|
||||
// reports a new size, tell the client via DesktopSize so it can
|
||||
// reallocate its backing buffer; the next full update will then fill
|
||||
// the new dimensions. Clients that didn't advertise support are stuck
|
||||
// with the original handshake size and just see clipping on resize.
|
||||
if err := s.handleResize(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
busy := s.applyBackpressure()
|
||||
if busy >= backpressureSkipThreshold {
|
||||
return s.sendEmptyUpdate()
|
||||
}
|
||||
|
||||
img, err := s.captureFrame()
|
||||
if errors.Is(err, errFrameUnchanged) {
|
||||
// macOS hashes the raw capture bytes and short-circuits when the
|
||||
// screen is byte-identical. Treat as "no dirty rects" to skip the
|
||||
// diff and send an empty update.
|
||||
s.idleFrames++
|
||||
delay := min(s.idleFrames*5, 100)
|
||||
time.Sleep(time.Duration(delay) * time.Millisecond)
|
||||
return s.sendEmptyUpdate()
|
||||
}
|
||||
if err != nil {
|
||||
// Capture failures are transient on Windows: a Ctrl+Alt+Del or
|
||||
// sign-out switches the OS to the secure desktop, and the DXGI
|
||||
// duplicator on the previous desktop returns an error until the
|
||||
// capturer reattaches on the new desktop. On Linux the X server
|
||||
// behind a virtual session may exit and the capturer reports
|
||||
// "unavailable" on every retry tick. Don't tear down the session
|
||||
// and don't spam the log: emit one line on the first failure, then
|
||||
// throttle further "still failing" lines to once per 5 s.
|
||||
s.captureErrorLog(err)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
return s.sendEmptyUpdate()
|
||||
}
|
||||
s.captureRecovered()
|
||||
|
||||
s.maybeCompositeCursor(img)
|
||||
|
||||
if req.incremental && s.prevFrame != nil {
|
||||
return s.processIncremental(img)
|
||||
}
|
||||
|
||||
// Full update.
|
||||
s.idleFrames = 0
|
||||
if err := s.sendFullUpdate(img); err != nil {
|
||||
return err
|
||||
}
|
||||
s.swapPrevCur()
|
||||
s.refreshCopyRectIndex()
|
||||
return nil
|
||||
}
|
||||
|
||||
// processIncremental handles the diff/encode path for a non-initial frame.
|
||||
// Returns nil after writing either an empty update (no changes) or a mix of
|
||||
// CopyRect moves and pixel-encoded dirty rects.
|
||||
func (s *session) processIncremental(img *image.RGBA) error {
|
||||
tiles := diffTiles(s.prevFrame, img, s.serverW, s.serverH, tileSize)
|
||||
if len(tiles) == 0 {
|
||||
// Nothing changed. Back off briefly before responding to reduce
|
||||
// CPU usage when the screen is static. The client re-requests
|
||||
// immediately after receiving our empty response, so without
|
||||
// this delay we'd spin at ~1000fps checking for changes.
|
||||
s.idleFrames++
|
||||
delay := min(s.idleFrames*5, 100) // 5ms → 100ms adaptive backoff
|
||||
time.Sleep(time.Duration(delay) * time.Millisecond)
|
||||
s.swapPrevCur()
|
||||
return s.sendEmptyUpdate()
|
||||
}
|
||||
s.idleFrames = 0
|
||||
|
||||
// Snapshot the dirty set before extractCopyRectTiles consumes it.
|
||||
// extract mutates in place, so without the copy we lose the
|
||||
// move-destination positions needed to incrementally update the
|
||||
// CopyRect index after the swap.
|
||||
dirty := make([][4]int, len(tiles))
|
||||
copy(dirty, tiles)
|
||||
|
||||
var moves []copyRectMove
|
||||
if s.useCopyRect && s.copyRectDet != nil {
|
||||
moves, tiles = s.copyRectDet.extractCopyRectTiles(img, tiles)
|
||||
}
|
||||
|
||||
rects := coalesceRects(tiles)
|
||||
if s.shouldPromoteToFullFrame(rects) && len(moves) == 0 {
|
||||
if err := s.sendFullUpdate(img); err != nil {
|
||||
return err
|
||||
}
|
||||
s.swapPrevCur()
|
||||
s.refreshCopyRectIndex()
|
||||
return nil
|
||||
}
|
||||
if len(moves) == 0 {
|
||||
if bb, ok := promoteToBoundingBox(rects); ok {
|
||||
rects = bb
|
||||
}
|
||||
}
|
||||
if err := s.sendDirtyAndMoves(img, moves, rects); err != nil {
|
||||
return err
|
||||
}
|
||||
s.swapPrevCur()
|
||||
s.updateCopyRectIndex(dirty)
|
||||
return nil
|
||||
}
|
||||
|
||||
// backpressureSkipThreshold is the BusyFraction at and above which we drop
|
||||
// the next encode entirely and respond with an empty FramebufferUpdate.
|
||||
// Above this level the encoder would only stack more bytes behind a socket
|
||||
// that is already write-blocked, raising end-to-end latency.
|
||||
const backpressureSkipThreshold = 0.65
|
||||
|
||||
// backpressureRampStart is where adaptive quality begins clipping. Below
|
||||
// this fraction the honoured client quality is used as-is.
|
||||
const backpressureRampStart = 0.2
|
||||
|
||||
// backpressureMinQuality is the floor JPEG quality picked when the socket
|
||||
// is fully saturated short of the skip threshold.
|
||||
const backpressureMinQuality = 25
|
||||
|
||||
// applyBackpressure samples the socket BusyFraction (if available) and, if
|
||||
// Tight is in use, ramps the active JPEG quality from the client-honoured
|
||||
// value down to backpressureMinQuality as the fraction climbs from
|
||||
// backpressureRampStart toward backpressureSkipThreshold. Returns the
|
||||
// observed fraction so the caller can decide whether to skip the frame.
|
||||
func (s *session) applyBackpressure() float64 {
|
||||
type busyReporter interface{ BusyFraction() float64 }
|
||||
bs, ok := s.conn.(busyReporter)
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
frac := bs.BusyFraction()
|
||||
|
||||
s.encMu.RLock()
|
||||
tight := s.tight
|
||||
s.encMu.RUnlock()
|
||||
if tight == nil {
|
||||
return frac
|
||||
}
|
||||
|
||||
base := jpegQualityForLevel(tight.qualityLevel)
|
||||
if base == 0 {
|
||||
// No client-negotiated quality; let tightQualityFor pick the
|
||||
// area-based default and skip backpressure adjustments that
|
||||
// would otherwise lock in a wrong starting point.
|
||||
tight.jpegQualityOverride = 0
|
||||
return frac
|
||||
}
|
||||
q := base
|
||||
if frac > backpressureRampStart {
|
||||
span := backpressureSkipThreshold - backpressureRampStart
|
||||
t := (frac - backpressureRampStart) / span
|
||||
if t > 1 {
|
||||
t = 1
|
||||
}
|
||||
q = base - int(float64(base-backpressureMinQuality)*t)
|
||||
if q < backpressureMinQuality {
|
||||
q = backpressureMinQuality
|
||||
}
|
||||
}
|
||||
tight.jpegQualityOverride = q
|
||||
return frac
|
||||
}
|
||||
|
||||
// captureErrorLog emits one log line on the first failure after success,
|
||||
// then at most once every captureErrThrottle while the capturer keeps
|
||||
// failing. The "recovered" transition is logged once when err is nil and
|
||||
// captureErrSeen was set.
|
||||
func (s *session) captureErrorLog(err error) {
|
||||
const captureErrThrottle = 5 * time.Second
|
||||
now := time.Now()
|
||||
if !s.captureErrSeen || now.Sub(s.captureErrLast) >= captureErrThrottle {
|
||||
s.log.Debugf("capture (transient): %v", err)
|
||||
s.captureErrLast = now
|
||||
}
|
||||
s.captureErrSeen = true
|
||||
}
|
||||
|
||||
// captureRecovered emits a one-shot debug line when capture works again
|
||||
// after a failure streak. Called by the success paths.
|
||||
func (s *session) captureRecovered() {
|
||||
if s.captureErrSeen {
|
||||
s.log.Debugf("capture recovered")
|
||||
s.captureErrSeen = false
|
||||
}
|
||||
}
|
||||
|
||||
// handleResize detects framebuffer-size changes between encode cycles and
|
||||
// notifies the client via the DesktopSize pseudo-encoding. Returns an
|
||||
// error only on write failure; capturers that don't expose Width/Height
|
||||
// yet (zero values during early startup) are silently ignored.
|
||||
func (s *session) handleResize() error {
|
||||
w, h := s.capturer.Width(), s.capturer.Height()
|
||||
if w <= 0 || h <= 0 {
|
||||
return nil
|
||||
}
|
||||
if w > maxFramebufferDim || h > maxFramebufferDim {
|
||||
s.log.Warnf("ignoring resize: %dx%d exceeds cap %d", w, h, maxFramebufferDim)
|
||||
return nil
|
||||
}
|
||||
if w == s.serverW && h == s.serverH {
|
||||
return nil
|
||||
}
|
||||
s.log.Debugf("framebuffer resized: %dx%d -> %dx%d", s.serverW, s.serverH, w, h)
|
||||
s.serverW = w
|
||||
s.serverH = h
|
||||
// Drop the prev frame so the next encode produces a full update at
|
||||
// the new dimensions rather than diffing against a stale-sized buffer.
|
||||
s.prevFrame = nil
|
||||
s.curFrame = nil
|
||||
if s.copyRectDet != nil {
|
||||
// Tile geometry changed; let updateDirty rebuild from scratch on
|
||||
// the next pass instead of reusing stale hashes keyed on old
|
||||
// (cols, rows).
|
||||
s.copyRectDet.prevTiles = nil
|
||||
s.copyRectDet.tileHash = nil
|
||||
}
|
||||
if err := s.sendDesktopSize(w, h); err != nil {
|
||||
return fmt.Errorf("send desktop size: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendDesktopSize emits a single-rect FramebufferUpdate carrying the
|
||||
// DesktopSize pseudo-encoding. No-op if the client did not negotiate it,
|
||||
// in which case the client just sees the new dimensions on the next full
|
||||
// update and will likely clip or scale.
|
||||
func (s *session) sendDesktopSize(w, h int) error {
|
||||
s.encMu.RLock()
|
||||
supported := s.clientSupportsDesktopSize || s.clientSupportsExtendedDesktopSize
|
||||
s.encMu.RUnlock()
|
||||
if !supported {
|
||||
return nil
|
||||
}
|
||||
header := make([]byte, 4)
|
||||
header[0] = serverFramebufferUpdate
|
||||
binary.BigEndian.PutUint16(header[2:4], 1)
|
||||
|
||||
body := encodeDesktopSizeBody(w, h)
|
||||
s.writeMu.Lock()
|
||||
defer s.writeMu.Unlock()
|
||||
if _, err := s.conn.Write(header); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := s.conn.Write(body)
|
||||
return err
|
||||
}
|
||||
|
||||
// sendExtMouseAck emits the pseudo-rect that flips the client into
|
||||
// ExtendedMouseButtons mode, where mouse-back and mouse-forward are
|
||||
// carried in a second mask byte. The rect has zero geometry and no
|
||||
// body; the encoding number alone is the signal.
|
||||
func (s *session) sendExtMouseAck() error {
|
||||
header := make([]byte, 4)
|
||||
header[0] = serverFramebufferUpdate
|
||||
binary.BigEndian.PutUint16(header[2:4], 1)
|
||||
|
||||
rect := make([]byte, 12)
|
||||
enc := int32(pseudoEncExtendedMouseButtons)
|
||||
binary.BigEndian.PutUint32(rect[8:12], uint32(enc))
|
||||
|
||||
s.writeMu.Lock()
|
||||
defer s.writeMu.Unlock()
|
||||
if _, err := s.conn.Write(header); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := s.conn.Write(rect)
|
||||
return err
|
||||
}
|
||||
|
||||
// refreshCopyRectIndex does a full hash sweep of the just-swapped prevFrame.
|
||||
// Used after full-frame sends, where we don't have a per-tile dirty list to
|
||||
// drive an incremental update.
|
||||
func (s *session) refreshCopyRectIndex() {
|
||||
if s.copyRectDet == nil || s.prevFrame == nil {
|
||||
return
|
||||
}
|
||||
s.copyRectDet.rebuild(s.prevFrame, s.serverW, s.serverH)
|
||||
}
|
||||
|
||||
// updateCopyRectIndex incrementally updates the CopyRect detector's hash
|
||||
// tables for the tiles that just changed. On first use (or after resize)
|
||||
// updateDirty internally falls back to a full rebuild.
|
||||
func (s *session) updateCopyRectIndex(dirty [][4]int) {
|
||||
if s.copyRectDet == nil || s.prevFrame == nil {
|
||||
return
|
||||
}
|
||||
s.copyRectDet.updateDirty(s.prevFrame, s.serverW, s.serverH, dirty)
|
||||
}
|
||||
|
||||
// captureFrame returns a session-owned frame for this encode cycle.
|
||||
// Capturers that implement captureIntoer (Linux X11, macOS) write directly
|
||||
// into curFrame, saving a per-frame full-screen memcpy. Capturers that
|
||||
// don't (Windows DXGI) return their own buffer which we copy into curFrame
|
||||
// to keep the encoder's prevFrame stable across the next capture cycle.
|
||||
func (s *session) captureFrame() (*image.RGBA, error) {
|
||||
w, h := s.serverW, s.serverH
|
||||
if s.curFrame == nil || s.curFrame.Rect.Dx() != w || s.curFrame.Rect.Dy() != h {
|
||||
s.curFrame = image.NewRGBA(image.Rect(0, 0, w, h))
|
||||
}
|
||||
|
||||
if ci, ok := s.capturer.(captureIntoer); ok {
|
||||
if err := ci.CaptureInto(s.curFrame); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.curFrame, nil
|
||||
}
|
||||
|
||||
src, err := s.capturer.Capture()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.curFrame.Rect != src.Rect {
|
||||
s.curFrame = image.NewRGBA(src.Rect)
|
||||
}
|
||||
copy(s.curFrame.Pix, src.Pix)
|
||||
return s.curFrame, nil
|
||||
}
|
||||
|
||||
// promoteToBoundingBox replaces the rect list with a single rect covering
|
||||
// the bounding box of all inputs, provided the bbox is at least
|
||||
// bboxPromoteMinArea and the dirty pixels fill at least
|
||||
// bboxPromoteDensityPct of it. Returns the new rect list and true when the
|
||||
// promotion fires; otherwise returns nil, false and the caller keeps the
|
||||
// original list.
|
||||
func promoteToBoundingBox(rects [][4]int) ([][4]int, bool) {
|
||||
if len(rects) < 2 {
|
||||
return nil, false
|
||||
}
|
||||
x0, y0 := rects[0][0], rects[0][1]
|
||||
x1, y1 := x0+rects[0][2], y0+rects[0][3]
|
||||
dirty := 0
|
||||
for _, r := range rects {
|
||||
if r[0] < x0 {
|
||||
x0 = r[0]
|
||||
}
|
||||
if r[1] < y0 {
|
||||
y0 = r[1]
|
||||
}
|
||||
if r[0]+r[2] > x1 {
|
||||
x1 = r[0] + r[2]
|
||||
}
|
||||
if r[1]+r[3] > y1 {
|
||||
y1 = r[1] + r[3]
|
||||
}
|
||||
dirty += r[2] * r[3]
|
||||
}
|
||||
w, h := x1-x0, y1-y0
|
||||
bbox := w * h
|
||||
if bbox < bboxPromoteMinArea {
|
||||
return nil, false
|
||||
}
|
||||
if dirty*100 < bbox*bboxPromoteDensityPct {
|
||||
return nil, false
|
||||
}
|
||||
return [][4]int{{x0, y0, w, h}}, true
|
||||
}
|
||||
|
||||
// shouldPromoteToFullFrame returns true when the dirty rect set covers a
|
||||
// large enough fraction of the screen that a single full-frame zlib rect
|
||||
// beats per-tile encoding on both CPU time and wire bytes. The crossover
|
||||
// is measured via BenchmarkEncodeManyTilesVsFullFrame.
|
||||
func (s *session) shouldPromoteToFullFrame(rects [][4]int) bool {
|
||||
if s.serverW == 0 || s.serverH == 0 {
|
||||
return false
|
||||
}
|
||||
var dirty int
|
||||
for _, r := range rects {
|
||||
dirty += r[2] * r[3]
|
||||
}
|
||||
return dirty*fullFramePromoteDen > s.serverW*s.serverH*fullFramePromoteNum
|
||||
}
|
||||
|
||||
// swapPrevCur makes the just-encoded frame the new prevFrame (for the next
|
||||
// diff) and lets the old prevFrame buffer become the next curFrame. Avoids
|
||||
// an 8 MB copy per frame compared to the old savePrevFrame path.
|
||||
func (s *session) swapPrevCur() {
|
||||
s.prevFrame, s.curFrame = s.curFrame, s.prevFrame
|
||||
}
|
||||
|
||||
// sendEmptyUpdate sends a FramebufferUpdate with zero pixel rectangles.
|
||||
// When the cursor source reports a fresh sprite we still slip the Cursor
|
||||
// pseudo-rect into the same message so a shape change (e.g. hovering onto
|
||||
// a resize handle) reaches the client without waiting for a dirty frame.
|
||||
func (s *session) sendEmptyUpdate() error {
|
||||
cursorRect := s.pendingCursorRect()
|
||||
if cursorRect == nil {
|
||||
var buf [4]byte
|
||||
buf[0] = serverFramebufferUpdate
|
||||
return s.writeFramed(buf[:])
|
||||
}
|
||||
buf := make([]byte, 4+len(cursorRect))
|
||||
buf[0] = serverFramebufferUpdate
|
||||
binary.BigEndian.PutUint16(buf[2:4], 1)
|
||||
copy(buf[4:], cursorRect)
|
||||
return s.writeFramed(buf)
|
||||
}
|
||||
|
||||
func (s *session) sendFullUpdate(img *image.RGBA) error {
|
||||
w, h := s.serverW, s.serverH
|
||||
|
||||
s.encMu.RLock()
|
||||
pf := s.pf
|
||||
useTight := s.useTight
|
||||
tight := s.tight
|
||||
useZlib := s.useZlib
|
||||
zlib := s.zlib
|
||||
s.encMu.RUnlock()
|
||||
|
||||
cursorRect := s.pendingCursorRect()
|
||||
rectCount := uint16(1)
|
||||
if cursorRect != nil {
|
||||
rectCount++
|
||||
}
|
||||
|
||||
var rectBuf []byte
|
||||
switch {
|
||||
case useTight && tight != nil && pfIsTightCompatible(pf):
|
||||
rectBuf = encodeTightRect(img, pf, 0, 0, w, h, tight)
|
||||
case useZlib && zlib != nil:
|
||||
// encodeZlibRect bakes in its own FBU header; reuse it for the
|
||||
// single-rect path when there is no cursor to prepend.
|
||||
if cursorRect == nil {
|
||||
return s.writeFramed(encodeZlibRect(img, pf, 0, 0, w, h, zlib))
|
||||
}
|
||||
rectBuf = encodeZlibRect(img, pf, 0, 0, w, h, zlib)[4:]
|
||||
default:
|
||||
if cursorRect == nil {
|
||||
return s.writeFramed(encodeRawRect(img, pf, 0, 0, w, h))
|
||||
}
|
||||
rectBuf = encodeRawRect(img, pf, 0, 0, w, h)[4:]
|
||||
}
|
||||
|
||||
buf := make([]byte, 4+len(cursorRect)+len(rectBuf))
|
||||
buf[0] = serverFramebufferUpdate
|
||||
binary.BigEndian.PutUint16(buf[2:4], rectCount)
|
||||
off := 4
|
||||
off += copy(buf[off:], cursorRect)
|
||||
copy(buf[off:], rectBuf)
|
||||
return s.writeFramed(buf)
|
||||
}
|
||||
|
||||
func (s *session) writeFramed(buf []byte) error {
|
||||
s.writeMu.Lock()
|
||||
_, err := s.conn.Write(buf)
|
||||
s.writeMu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
// sendDirtyAndMoves writes one FramebufferUpdate combining CopyRect moves
|
||||
// (cheap, 16 bytes each) and pixel-encoded dirty rects. Moves come first so
|
||||
// their source tiles are read from the client's pre-update framebuffer state,
|
||||
// before any subsequent rect overwrites them.
|
||||
func (s *session) sendDirtyAndMoves(img *image.RGBA, moves []copyRectMove, rects [][4]int) error {
|
||||
cursorRect := s.pendingCursorRect()
|
||||
if len(moves) == 0 && len(rects) == 0 && cursorRect == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
total := len(moves) + len(rects)
|
||||
if cursorRect != nil {
|
||||
total++
|
||||
}
|
||||
header := make([]byte, 4)
|
||||
header[0] = serverFramebufferUpdate
|
||||
binary.BigEndian.PutUint16(header[2:4], uint16(total))
|
||||
|
||||
s.writeMu.Lock()
|
||||
defer s.writeMu.Unlock()
|
||||
|
||||
if _, err := s.conn.Write(header); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if cursorRect != nil {
|
||||
if _, err := s.conn.Write(cursorRect); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
ts := tileSize
|
||||
for _, m := range moves {
|
||||
body := encodeCopyRectBody(m.srcX, m.srcY, m.dstX, m.dstY, ts, ts)
|
||||
if _, err := s.conn.Write(body); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for _, r := range rects {
|
||||
x, y, w, h := r[0], r[1], r[2], r[3]
|
||||
rectBuf := s.encodeTile(img, x, y, w, h)
|
||||
if _, err := s.conn.Write(rectBuf); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// encodeTile produces the on-wire rect bytes for a single dirty tile. Tight
|
||||
// is the only non-Raw encoding we negotiate: uniform tiles collapse to its
|
||||
// Fill subencoding (~16 bytes), photo-like rects route to JPEG, and the
|
||||
// rest take the Basic+zlib path. Raw is the fallback when Tight is not
|
||||
// negotiated or the negotiated pixel format is incompatible with Tight's
|
||||
// mandatory 24-bit RGB TPIXEL encoding.
|
||||
//
|
||||
// Output omits the 4-byte FramebufferUpdate header; callers combine multiple
|
||||
// tiles into one message.
|
||||
func (s *session) encodeTile(img *image.RGBA, x, y, w, h int) []byte {
|
||||
s.encMu.RLock()
|
||||
pf := s.pf
|
||||
useHextile := s.useHextile
|
||||
useTight := s.useTight
|
||||
tight := s.tight
|
||||
useZlib := s.useZlib
|
||||
zlib := s.zlib
|
||||
s.encMu.RUnlock()
|
||||
|
||||
if useHextile {
|
||||
if pixel, uniform := tileIsUniform(img, x, y, w, h); uniform {
|
||||
r := byte(pixel)
|
||||
g := byte(pixel >> 8)
|
||||
b := byte(pixel >> 16)
|
||||
return encodeHextileSolidRect(r, g, b, pf, rect{x, y, w, h})
|
||||
}
|
||||
}
|
||||
if useTight && tight != nil && pfIsTightCompatible(pf) {
|
||||
return encodeTightRect(img, pf, x, y, w, h, tight)
|
||||
}
|
||||
if useZlib && zlib != nil {
|
||||
return encodeZlibRect(img, pf, x, y, w, h, zlib)[4:]
|
||||
}
|
||||
return encodeRawRect(img, pf, x, y, w, h)[4:]
|
||||
}
|
||||
|
||||
// drainRequests consumes any pending requests so the sender's close completes
|
||||
// cleanly after the encoder loop has decided to exit on error. Returns the
|
||||
// number of drained requests to defeat empty-block lints; callers ignore it.
|
||||
func drainRequests(ch chan fbRequest) int {
|
||||
var drained int
|
||||
for range ch {
|
||||
drained++
|
||||
}
|
||||
return drained
|
||||
}
|
||||
|
||||
// pfIsTightCompatible reports whether the negotiated client pixel format
|
||||
// satisfies Tight's TPIXEL constraint (RFB 7.7.6): the three RGB shifts form
|
||||
// a permutation of {0, 8, 16} so the colour values live in the low 24 bits.
|
||||
// bpp, endianness, and 8-bit channels are already enforced at SetPixelFormat
|
||||
// time. Any permutation works because Tight always emits a three-byte R, G,
|
||||
// B triple regardless of where the client stores each channel.
|
||||
func pfIsTightCompatible(pf clientPixelFormat) bool {
|
||||
shifts := uint32(1)<<pf.rShift | uint32(1)<<pf.gShift | uint32(1)<<pf.bShift
|
||||
return shifts == 1<<0|1<<8|1<<16
|
||||
}
|
||||
120
client/vnc/server/session_remote_cursor.go
Normal file
120
client/vnc/server/session_remote_cursor.go
Normal file
@@ -0,0 +1,120 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
"io"
|
||||
)
|
||||
|
||||
// handleShowRemoteCursor handles the NetBird-specific RFB message that
|
||||
// toggles "show remote cursor" mode. Wire format: 1-byte enable flag
|
||||
// (0/1) plus 6 padding bytes reserved for future arguments.
|
||||
func (s *session) handleShowRemoteCursor() error {
|
||||
var data [7]byte
|
||||
if _, err := io.ReadFull(s.conn, data[:]); err != nil {
|
||||
return fmt.Errorf("read showRemoteCursor: %w", err)
|
||||
}
|
||||
enable := data[0] != 0
|
||||
s.encMu.Lock()
|
||||
s.showRemoteCursor = enable
|
||||
s.encMu.Unlock()
|
||||
s.log.Debugf("show remote cursor: %v", enable)
|
||||
return nil
|
||||
}
|
||||
|
||||
// maybeCompositeCursor blends the current server cursor into img when the
|
||||
// client has enabled "show remote cursor" mode. Returns silently in every
|
||||
// error path: a failed compositing must not stop the regular encode flow.
|
||||
func (s *session) maybeCompositeCursor(img *image.RGBA) {
|
||||
s.encMu.RLock()
|
||||
enabled := s.showRemoteCursor
|
||||
s.encMu.RUnlock()
|
||||
if !enabled || img == nil {
|
||||
return
|
||||
}
|
||||
src, ok := s.capturer.(cursorSource)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
pos, ok := s.capturer.(cursorPositionSource)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
cursorImg, hotX, hotY, _, err := src.Cursor()
|
||||
if err != nil || cursorImg == nil {
|
||||
s.cursorWarnOnce.Do(func() {
|
||||
s.log.Warnf("remote cursor unavailable: %v", err)
|
||||
})
|
||||
return
|
||||
}
|
||||
posX, posY, err := pos.CursorPos()
|
||||
if err != nil {
|
||||
s.cursorWarnOnce.Do(func() {
|
||||
s.log.Warnf("remote cursor position unavailable: %v", err)
|
||||
})
|
||||
return
|
||||
}
|
||||
compositeCursor(img, cursorImg, posX-hotX, posY-hotY)
|
||||
}
|
||||
|
||||
// compositeCursor alpha-blends sprite onto frame at (dstX, dstY).
|
||||
// sprite is assumed to use premultiplied RGBA, which is what every
|
||||
// cursorSource implementation in this package produces (X11 XFixes and
|
||||
// macOS CG return premultiplied bytes natively; the Windows path
|
||||
// premultiplies during decodeColorCursor). Out-of-bounds destinations are
|
||||
// clipped.
|
||||
func compositeCursor(frame, sprite *image.RGBA, dstX, dstY int) {
|
||||
fw, fh := frame.Rect.Dx(), frame.Rect.Dy()
|
||||
sw, sh := sprite.Rect.Dx(), sprite.Rect.Dy()
|
||||
if sw == 0 || sh == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
x0, y0 := dstX, dstY
|
||||
x1, y1 := dstX+sw, dstY+sh
|
||||
if x0 < 0 {
|
||||
x0 = 0
|
||||
}
|
||||
if y0 < 0 {
|
||||
y0 = 0
|
||||
}
|
||||
if x1 > fw {
|
||||
x1 = fw
|
||||
}
|
||||
if y1 > fh {
|
||||
y1 = fh
|
||||
}
|
||||
if x0 >= x1 || y0 >= y1 {
|
||||
return
|
||||
}
|
||||
|
||||
fStride := frame.Stride
|
||||
sStride := sprite.Stride
|
||||
for y := y0; y < y1; y++ {
|
||||
sy := y - dstY
|
||||
fbRow := y * fStride
|
||||
sRow := sy * sStride
|
||||
for x := x0; x < x1; x++ {
|
||||
sx := x - dstX
|
||||
fbOff := fbRow + x*4
|
||||
sOff := sRow + sx*4
|
||||
a := uint32(sprite.Pix[sOff+3])
|
||||
if a == 0 {
|
||||
continue
|
||||
}
|
||||
if a == 255 {
|
||||
frame.Pix[fbOff+0] = sprite.Pix[sOff+0]
|
||||
frame.Pix[fbOff+1] = sprite.Pix[sOff+1]
|
||||
frame.Pix[fbOff+2] = sprite.Pix[sOff+2]
|
||||
continue
|
||||
}
|
||||
// Premultiplied compositing: dst = src + dst*(1-srcA).
|
||||
inv := 255 - a
|
||||
frame.Pix[fbOff+0] = sprite.Pix[sOff+0] + byte((uint32(frame.Pix[fbOff+0])*inv)/255)
|
||||
frame.Pix[fbOff+1] = sprite.Pix[sOff+1] + byte((uint32(frame.Pix[fbOff+1])*inv)/255)
|
||||
frame.Pix[fbOff+2] = sprite.Pix[sOff+2] + byte((uint32(frame.Pix[fbOff+2])*inv)/255)
|
||||
}
|
||||
}
|
||||
}
|
||||
80
client/vnc/server/shutdown_state.go
Normal file
80
client/vnc/server/shutdown_state.go
Normal file
@@ -0,0 +1,80 @@
|
||||
//go:build unix
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ShutdownState tracks VNC virtual session processes for crash recovery.
|
||||
// Persisted by the state manager; on restart, residual processes are killed.
|
||||
type ShutdownState struct {
|
||||
// Processes maps a description to its PID (e.g., "xvfb:50" -> 1234).
|
||||
Processes map[string]int `json:"processes,omitempty"`
|
||||
}
|
||||
|
||||
// Name returns the state name for the state manager.
|
||||
func (s *ShutdownState) Name() string {
|
||||
return "vnc_sessions_state"
|
||||
}
|
||||
|
||||
// Cleanup kills any residual VNC session processes left from a crash.
|
||||
func (s *ShutdownState) Cleanup() error {
|
||||
if len(s.Processes) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for desc, pid := range s.Processes {
|
||||
if pid <= 0 {
|
||||
continue
|
||||
}
|
||||
if !isOurProcess(pid, desc) {
|
||||
log.Debugf("cleanup:skipping PID %d (%s), not ours", pid, desc)
|
||||
continue
|
||||
}
|
||||
log.Infof("cleanup:killing residual process %d (%s)", pid, desc)
|
||||
// Kill the process group (negative PID) to get children too.
|
||||
if err := syscall.Kill(-pid, syscall.SIGTERM); err != nil {
|
||||
// Try individual process if group kill fails.
|
||||
if killErr := syscall.Kill(pid, syscall.SIGKILL); killErr != nil {
|
||||
log.Debugf("cleanup: kill pid %d (%s): group kill: %v, single kill: %v", pid, desc, err, killErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.Processes = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// isOurProcess verifies the PID still belongs to a VNC-related process
|
||||
// by checking /proc/<pid>/cmdline (Linux) or the process name.
|
||||
func isOurProcess(pid int, desc string) bool {
|
||||
// Check if the process exists at all.
|
||||
if err := syscall.Kill(pid, 0); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// On Linux, verify via /proc cmdline.
|
||||
cmdline, err := os.ReadFile(fmt.Sprintf("/proc/%d/cmdline", pid))
|
||||
if err != nil {
|
||||
log.Debugf("cleanup: cannot read /proc/%d/cmdline: %v, treating PID as foreign", pid, err)
|
||||
return false
|
||||
}
|
||||
|
||||
cmd := string(cmdline)
|
||||
// Match against expected process types.
|
||||
if strings.Contains(desc, "xvfb") || strings.Contains(desc, "xorg") {
|
||||
return strings.Contains(cmd, "Xvfb") || strings.Contains(cmd, "Xorg")
|
||||
}
|
||||
if strings.Contains(desc, "desktop") {
|
||||
return strings.Contains(cmd, "session") || strings.Contains(cmd, "plasma") ||
|
||||
strings.Contains(cmd, "gnome") || strings.Contains(cmd, "xfce") ||
|
||||
strings.Contains(cmd, "dbus-launch")
|
||||
}
|
||||
return false
|
||||
}
|
||||
53
client/vnc/server/stubs.go
Normal file
53
client/vnc/server/stubs.go
Normal file
@@ -0,0 +1,53 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
)
|
||||
|
||||
// StubCapturer is a placeholder for platforms without screen capture support.
|
||||
type StubCapturer struct{}
|
||||
|
||||
// Width returns 0 on unsupported platforms.
|
||||
func (c *StubCapturer) Width() int { return 0 }
|
||||
|
||||
// Height returns 0 on unsupported platforms.
|
||||
func (c *StubCapturer) Height() int { return 0 }
|
||||
|
||||
// Capture returns an error on unsupported platforms.
|
||||
func (c *StubCapturer) Capture() (*image.RGBA, error) {
|
||||
return nil, fmt.Errorf("screen capture not supported on this platform")
|
||||
}
|
||||
|
||||
// StubInputInjector is a placeholder for platforms without input injection support.
|
||||
type StubInputInjector struct{}
|
||||
|
||||
// InjectKey is a no-op on unsupported platforms.
|
||||
func (s *StubInputInjector) InjectKey(_ uint32, _ bool) {
|
||||
// no-op
|
||||
}
|
||||
|
||||
// InjectKeyScancode is a no-op on unsupported platforms.
|
||||
func (s *StubInputInjector) InjectKeyScancode(_ uint32, _ uint32, _ bool) {
|
||||
// no-op
|
||||
}
|
||||
|
||||
// InjectPointer is a no-op on unsupported platforms.
|
||||
func (s *StubInputInjector) InjectPointer(_ uint16, _, _, _, _ int) {
|
||||
// no-op
|
||||
}
|
||||
|
||||
// SetClipboard is a no-op on unsupported platforms.
|
||||
func (s *StubInputInjector) SetClipboard(_ string) {
|
||||
// no-op
|
||||
}
|
||||
|
||||
// GetClipboard returns empty on unsupported platforms.
|
||||
func (s *StubInputInjector) GetClipboard() string { return "" }
|
||||
|
||||
// TypeText is a no-op on unsupported platforms.
|
||||
func (s *StubInputInjector) TypeText(_ string) {
|
||||
// no-op
|
||||
}
|
||||
30
client/vnc/server/swizzle.go
Normal file
30
client/vnc/server/swizzle.go
Normal file
@@ -0,0 +1,30 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import "unsafe"
|
||||
|
||||
// swizzleBGRAtoRGBA swaps B and R channels in a BGRA pixel buffer and copies
|
||||
// into dst in-place (dst and src may alias). Operates on uint32 words: one
|
||||
// read-modify-write per pixel, which is meaningfully faster than the naive
|
||||
// three-byte-store per pixel for large buffers like framebuffers.
|
||||
//
|
||||
// The alpha byte is forced to 0xff so callers that capture from X11 GetImage
|
||||
// (where the X server leaves the pad byte as zero) still get an opaque image.
|
||||
func swizzleBGRAtoRGBA(dst, src []byte) {
|
||||
n := len(dst) / 4
|
||||
if len(src)/4 < n {
|
||||
n = len(src) / 4
|
||||
}
|
||||
if n == 0 {
|
||||
return
|
||||
}
|
||||
dp := unsafe.Slice((*uint32)(unsafe.Pointer(&dst[0])), n)
|
||||
sp := unsafe.Slice((*uint32)(unsafe.Pointer(&src[0])), n)
|
||||
for i := range n {
|
||||
p := sp[i]
|
||||
// p in memory: B, G, R, A -> as uint32 little-endian: 0xAARRGGBB
|
||||
// Want memory: R, G, B, 0xFF -> uint32 little-endian: 0xFFBBGGRR
|
||||
dp[i] = 0xFF000000 | (p & 0x0000FF00) | ((p & 0x00FF0000) >> 16) | ((p & 0x000000FF) << 16)
|
||||
}
|
||||
}
|
||||
111
client/vnc/server/tight_test.go
Normal file
111
client/vnc/server/tight_test.go
Normal file
@@ -0,0 +1,111 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"image"
|
||||
"image/jpeg"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func makeUniformImage(w, h int, r, g, b byte) *image.RGBA {
|
||||
img := image.NewRGBA(image.Rect(0, 0, w, h))
|
||||
for i := 0; i < len(img.Pix); i += 4 {
|
||||
img.Pix[i+0] = r
|
||||
img.Pix[i+1] = g
|
||||
img.Pix[i+2] = b
|
||||
img.Pix[i+3] = 0xff
|
||||
}
|
||||
return img
|
||||
}
|
||||
|
||||
func makeTwoColorImage(w, h int) *image.RGBA {
|
||||
img := makeUniformImage(w, h, 0x10, 0x20, 0x30)
|
||||
fg := [3]byte{0xa0, 0xb0, 0xc0}
|
||||
for y := 0; y < h; y++ {
|
||||
for x := w / 4; x < w/2; x++ {
|
||||
i := y*img.Stride + x*4
|
||||
img.Pix[i+0] = fg[0]
|
||||
img.Pix[i+1] = fg[1]
|
||||
img.Pix[i+2] = fg[2]
|
||||
}
|
||||
}
|
||||
return img
|
||||
}
|
||||
|
||||
func decodeTightLength(buf []byte) (n, consumed int) {
|
||||
b0 := buf[0]
|
||||
n = int(b0 & 0x7f)
|
||||
if b0&0x80 == 0 {
|
||||
return n, 1
|
||||
}
|
||||
b1 := buf[1]
|
||||
n |= int(b1&0x7f) << 7
|
||||
if b1&0x80 == 0 {
|
||||
return n, 2
|
||||
}
|
||||
b2 := buf[2]
|
||||
n |= int(b2) << 14
|
||||
return n, 3
|
||||
}
|
||||
|
||||
func TestEncodeTightFill(t *testing.T) {
|
||||
pf := defaultClientPixelFormat()
|
||||
img := makeUniformImage(64, 64, 0x12, 0x34, 0x56)
|
||||
tstate := newTightState()
|
||||
buf := encodeTightRect(img, pf, 0, 0, 64, 64, tstate)
|
||||
if len(buf) != 12+1+3 {
|
||||
t.Fatalf("fill rect should be 16 bytes, got %d", len(buf))
|
||||
}
|
||||
if buf[12] != tightFillSubenc {
|
||||
t.Fatalf("expected fill subenc, got 0x%02x", buf[12])
|
||||
}
|
||||
if buf[13] != 0x12 || buf[14] != 0x34 || buf[15] != 0x56 {
|
||||
t.Fatalf("wrong fill colour: %v", buf[13:16])
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeTightBasic(t *testing.T) {
|
||||
pf := defaultClientPixelFormat()
|
||||
img := makeTwoColorImage(64, 64)
|
||||
tstate := newTightState()
|
||||
buf := encodeTightRect(img, pf, 0, 0, 64, 64, tstate)
|
||||
if buf[12]&0xf0 != tightBasicFilter {
|
||||
t.Fatalf("expected basic+filter subenc, got 0x%02x", buf[12])
|
||||
}
|
||||
if buf[13] != tightFilterCopy {
|
||||
t.Fatalf("expected copy filter, got 0x%02x", buf[13])
|
||||
}
|
||||
// Length prefix and zlib stream follow.
|
||||
n, _ := decodeTightLength(buf[14:])
|
||||
if n == 0 {
|
||||
t.Fatalf("zero-length basic stream")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeTightJPEG(t *testing.T) {
|
||||
pf := defaultClientPixelFormat()
|
||||
img := makeBenchImage(128, 128, 7) // random → many colours
|
||||
tstate := newTightState()
|
||||
buf := encodeTightRect(img, pf, 0, 0, 128, 128, tstate)
|
||||
if buf[12] != tightJPEGSubenc {
|
||||
t.Fatalf("expected JPEG subenc, got 0x%02x", buf[12])
|
||||
}
|
||||
n, consumed := decodeTightLength(buf[13:])
|
||||
jpegBytes := buf[13+consumed : 13+consumed+n]
|
||||
if _, err := jpeg.Decode(bytes.NewReader(jpegBytes)); err != nil {
|
||||
t.Fatalf("emitted JPEG bytes do not decode: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSampledColorCount(t *testing.T) {
|
||||
uniform := makeUniformImage(64, 64, 0x10, 0x20, 0x30)
|
||||
if c := sampledColorCountInto(map[uint32]struct{}{}, uniform, 0, 0, 64, 64, 32); c != 1 {
|
||||
t.Fatalf("uniform should be 1 colour, got %d", c)
|
||||
}
|
||||
rnd := makeBenchImage(128, 128, 1)
|
||||
if c := sampledColorCountInto(map[uint32]struct{}{}, rnd, 0, 0, 128, 128, 16); c <= 16 {
|
||||
t.Fatalf("random image should exceed colour cap, got %d", c)
|
||||
}
|
||||
}
|
||||
736
client/vnc/server/virtual_x11.go
Normal file
736
client/vnc/server/virtual_x11.go
Normal file
@@ -0,0 +1,736 @@
|
||||
//go:build unix && !darwin && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// VirtualSession manages a virtual X11 display (Xvfb) with a desktop session
|
||||
// running as a target user. It implements ScreenCapturer and InputInjector by
|
||||
// delegating to an X11Capturer/X11InputInjector pointed at the virtual display.
|
||||
const (
|
||||
sessionIdleTimeout = 5 * time.Minute
|
||||
|
||||
defaultSessionWidth uint16 = 1280
|
||||
defaultSessionHeight uint16 = 800
|
||||
)
|
||||
|
||||
type VirtualSession struct {
|
||||
mu sync.Mutex
|
||||
display string
|
||||
user *user.User
|
||||
uid uint32
|
||||
gid uint32
|
||||
groups []uint32
|
||||
width uint16
|
||||
height uint16
|
||||
xvfb *exec.Cmd
|
||||
desktop *exec.Cmd
|
||||
poller *X11Poller
|
||||
injector *X11InputInjector
|
||||
log *log.Entry
|
||||
stopped bool
|
||||
clients int
|
||||
idleTimer *time.Timer
|
||||
onIdle func() // called when idle timeout fires or Xvfb dies
|
||||
}
|
||||
|
||||
// StartVirtualSession creates and starts a virtual X11 session for the given
|
||||
// user. Requires root privileges to create sessions as other users. width and
|
||||
// height request the virtual display geometry; 0 values fall back to the
|
||||
// defaults.
|
||||
func StartVirtualSession(username string, width, height uint16, logger *log.Entry) (*VirtualSession, error) {
|
||||
if os.Getuid() != 0 {
|
||||
return nil, fmt.Errorf("virtual sessions require root privileges")
|
||||
}
|
||||
|
||||
if _, err := exec.LookPath("Xvfb"); err != nil {
|
||||
if _, err := exec.LookPath("Xorg"); err != nil {
|
||||
return nil, fmt.Errorf("neither Xvfb nor Xorg found (install xvfb or xserver-xorg)")
|
||||
}
|
||||
if !hasDummyDriver() {
|
||||
return nil, fmt.Errorf("xvfb not found and xorg dummy driver not installed (install xvfb or xf86-video-dummy)")
|
||||
}
|
||||
}
|
||||
|
||||
u, err := user.Lookup(username)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("lookup user %s: %w", username, err)
|
||||
}
|
||||
|
||||
uid, err := strconv.ParseUint(u.Uid, 10, 32)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse uid: %w", err)
|
||||
}
|
||||
gid, err := strconv.ParseUint(u.Gid, 10, 32)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse gid: %w", err)
|
||||
}
|
||||
|
||||
groups, err := supplementaryGroups(u)
|
||||
if err != nil {
|
||||
logger.Debugf("supplementary groups for %s: %v", username, err)
|
||||
}
|
||||
|
||||
if width == 0 {
|
||||
width = defaultSessionWidth
|
||||
}
|
||||
if height == 0 {
|
||||
height = defaultSessionHeight
|
||||
}
|
||||
|
||||
vs := &VirtualSession{
|
||||
user: u,
|
||||
uid: uint32(uid),
|
||||
gid: uint32(gid),
|
||||
groups: groups,
|
||||
width: width,
|
||||
height: height,
|
||||
log: logger.WithField("vnc_user", username),
|
||||
}
|
||||
|
||||
if err := vs.start(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return vs, nil
|
||||
}
|
||||
|
||||
func (vs *VirtualSession) start() error {
|
||||
display, err := findFreeDisplay()
|
||||
if err != nil {
|
||||
return fmt.Errorf("find free display: %w", err)
|
||||
}
|
||||
vs.display = display
|
||||
|
||||
if err := vs.startXvfb(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
socketPath := fmt.Sprintf("%s/X%s", x11SocketDir, vs.display[1:])
|
||||
if err := waitForPath(socketPath, 5*time.Second); err != nil {
|
||||
vs.stopXvfb()
|
||||
return fmt.Errorf("wait for X11 socket %s: %w", socketPath, err)
|
||||
}
|
||||
|
||||
// Grant the target user access to the display via xhost.
|
||||
xhostCmd := exec.Command("xhost", "+SI:localuser:"+vs.user.Username)
|
||||
xhostCmd.Env = []string{envDisplay + "=" + vs.display}
|
||||
if out, err := xhostCmd.CombinedOutput(); err != nil {
|
||||
vs.log.Debugf("xhost: %s (%v)", strings.TrimSpace(string(out)), err)
|
||||
}
|
||||
|
||||
vs.poller = NewX11Poller(vs.display)
|
||||
|
||||
injector, err := NewX11InputInjector(vs.display)
|
||||
if err != nil {
|
||||
vs.stopXvfb()
|
||||
return fmt.Errorf("create X11 injector for %s: %w", vs.display, err)
|
||||
}
|
||||
vs.injector = injector
|
||||
|
||||
if err := vs.startDesktop(); err != nil {
|
||||
vs.injector.Close()
|
||||
vs.stopXvfb()
|
||||
return fmt.Errorf("start desktop: %w", err)
|
||||
}
|
||||
|
||||
vs.log.Infof("virtual session started: display=%s user=%s", vs.display, vs.user.Username)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClientConnect increments the client count and cancels any idle timer.
|
||||
func (vs *VirtualSession) ClientConnect() {
|
||||
vs.mu.Lock()
|
||||
defer vs.mu.Unlock()
|
||||
vs.clients++
|
||||
if vs.idleTimer != nil {
|
||||
vs.idleTimer.Stop()
|
||||
vs.idleTimer = nil
|
||||
}
|
||||
}
|
||||
|
||||
// ClientDisconnect decrements the client count. When the last client
|
||||
// disconnects, starts an idle timer that destroys the session.
|
||||
func (vs *VirtualSession) ClientDisconnect() {
|
||||
vs.mu.Lock()
|
||||
defer vs.mu.Unlock()
|
||||
vs.clients--
|
||||
if vs.clients <= 0 {
|
||||
vs.clients = 0
|
||||
vs.log.Infof("no VNC clients connected, session will be destroyed in %s", sessionIdleTimeout)
|
||||
vs.idleTimer = time.AfterFunc(sessionIdleTimeout, vs.idleExpired)
|
||||
}
|
||||
}
|
||||
|
||||
// idleExpired is called by the idle timer. It stops the session and
|
||||
// notifies the session manager via onIdle so it removes us from the map.
|
||||
// Bails out early if a client reconnected before the timer callback won
|
||||
// the race (Stop() doesn't cancel an already-firing AfterFunc, so the
|
||||
// state check has to happen here under vs.mu).
|
||||
func (vs *VirtualSession) idleExpired() {
|
||||
vs.mu.Lock()
|
||||
if vs.stopped || vs.clients > 0 {
|
||||
vs.mu.Unlock()
|
||||
return
|
||||
}
|
||||
vs.mu.Unlock()
|
||||
|
||||
vs.log.Info("idle timeout reached, destroying virtual session")
|
||||
vs.Stop()
|
||||
if vs.onIdle != nil {
|
||||
vs.onIdle()
|
||||
}
|
||||
}
|
||||
|
||||
// isAlive returns true if the session is running and its X server socket exists.
|
||||
func (vs *VirtualSession) isAlive() bool {
|
||||
vs.mu.Lock()
|
||||
stopped := vs.stopped
|
||||
display := vs.display
|
||||
vs.mu.Unlock()
|
||||
|
||||
if stopped {
|
||||
return false
|
||||
}
|
||||
// Verify the X socket still exists on disk.
|
||||
socketPath := fmt.Sprintf("%s/X%s", x11SocketDir, display[1:])
|
||||
if _, err := os.Stat(socketPath); err != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Capturer returns the screen capturer for this virtual session.
|
||||
func (vs *VirtualSession) Capturer() ScreenCapturer {
|
||||
return vs.poller
|
||||
}
|
||||
|
||||
// Injector returns the input injector for this virtual session.
|
||||
func (vs *VirtualSession) Injector() InputInjector {
|
||||
return vs.injector
|
||||
}
|
||||
|
||||
// Display returns the X11 display string (e.g., ":99").
|
||||
func (vs *VirtualSession) Display() string {
|
||||
return vs.display
|
||||
}
|
||||
|
||||
// Stop terminates the virtual session, killing the desktop and Xvfb.
|
||||
func (vs *VirtualSession) Stop() {
|
||||
vs.mu.Lock()
|
||||
defer vs.mu.Unlock()
|
||||
|
||||
if vs.stopped {
|
||||
return
|
||||
}
|
||||
vs.stopped = true
|
||||
|
||||
if vs.injector != nil {
|
||||
vs.injector.Close()
|
||||
}
|
||||
if vs.poller != nil {
|
||||
vs.poller.Close()
|
||||
vs.poller = nil
|
||||
}
|
||||
|
||||
vs.stopDesktop()
|
||||
vs.stopXvfb()
|
||||
|
||||
vs.log.Info("virtual session stopped")
|
||||
}
|
||||
|
||||
func (vs *VirtualSession) startXvfb() error {
|
||||
if _, err := exec.LookPath("Xvfb"); err == nil {
|
||||
return vs.startXvfbDirect()
|
||||
}
|
||||
return vs.startXorgDummy()
|
||||
}
|
||||
|
||||
func (vs *VirtualSession) startXvfbDirect() error {
|
||||
geom := fmt.Sprintf("%dx%dx24", vs.width, vs.height)
|
||||
vs.xvfb = exec.Command("Xvfb", vs.display,
|
||||
"-screen", "0", geom,
|
||||
"-nolisten", "tcp",
|
||||
)
|
||||
vs.xvfb.SysProcAttr = &syscall.SysProcAttr{Setsid: true, Pdeathsig: syscall.SIGTERM}
|
||||
|
||||
if err := vs.xvfb.Start(); err != nil {
|
||||
return fmt.Errorf("start Xvfb on %s: %w", vs.display, err)
|
||||
}
|
||||
vs.log.Infof("Xvfb started on %s (pid=%d)", vs.display, vs.xvfb.Process.Pid)
|
||||
|
||||
go vs.monitorXvfb()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// startXorgDummy starts Xorg with the dummy video driver as a fallback when
|
||||
// Xvfb is not installed. Most systems with a desktop have Xorg available.
|
||||
func (vs *VirtualSession) startXorgDummy() error {
|
||||
conf := fmt.Sprintf(`Section "Device"
|
||||
Identifier "dummy"
|
||||
Driver "dummy"
|
||||
VideoRam 256000
|
||||
EndSection
|
||||
Section "Screen"
|
||||
Identifier "screen"
|
||||
Device "dummy"
|
||||
DefaultDepth 24
|
||||
SubSection "Display"
|
||||
Depth 24
|
||||
Modes "%dx%d"
|
||||
EndSubSection
|
||||
EndSection
|
||||
`, vs.width, vs.height)
|
||||
f, err := os.CreateTemp("", fmt.Sprintf("nbvnc-dummy-%s-*.conf", vs.display[1:]))
|
||||
if err != nil {
|
||||
return fmt.Errorf("create Xorg dummy config: %w", err)
|
||||
}
|
||||
confPath := f.Name()
|
||||
if _, err := f.WriteString(conf); err != nil {
|
||||
f.Close()
|
||||
os.Remove(confPath)
|
||||
return fmt.Errorf("write Xorg dummy config: %w", err)
|
||||
}
|
||||
if err := f.Chmod(0600); err != nil {
|
||||
f.Close()
|
||||
os.Remove(confPath)
|
||||
return fmt.Errorf("chmod Xorg dummy config: %w", err)
|
||||
}
|
||||
if err := f.Close(); err != nil {
|
||||
os.Remove(confPath)
|
||||
return fmt.Errorf("close Xorg dummy config: %w", err)
|
||||
}
|
||||
|
||||
vs.xvfb = exec.Command("Xorg", vs.display,
|
||||
"-config", confPath,
|
||||
"-noreset",
|
||||
"-nolisten", "tcp",
|
||||
)
|
||||
vs.xvfb.SysProcAttr = &syscall.SysProcAttr{Setsid: true, Pdeathsig: syscall.SIGTERM}
|
||||
|
||||
if err := vs.xvfb.Start(); err != nil {
|
||||
os.Remove(confPath)
|
||||
return fmt.Errorf("start Xorg dummy on %s: %w", vs.display, err)
|
||||
}
|
||||
vs.log.Infof("Xorg (dummy driver) started on %s (pid=%d)", vs.display, vs.xvfb.Process.Pid)
|
||||
|
||||
go func() {
|
||||
vs.monitorXvfb()
|
||||
os.Remove(confPath)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// monitorXvfb waits for the Xvfb/Xorg process to exit. If it exits
|
||||
// unexpectedly (not via Stop), the session is marked as dead and the
|
||||
// onIdle callback fires so the session manager removes it from the map.
|
||||
// The next GetOrCreate call for this user will create a fresh session.
|
||||
func (vs *VirtualSession) monitorXvfb() {
|
||||
if err := vs.xvfb.Wait(); err != nil {
|
||||
vs.log.Debugf("X server exited: %v", err)
|
||||
}
|
||||
|
||||
vs.mu.Lock()
|
||||
alreadyStopped := vs.stopped
|
||||
if !alreadyStopped {
|
||||
vs.log.Warn("X server exited unexpectedly, marking session as dead")
|
||||
vs.stopped = true
|
||||
if vs.idleTimer != nil {
|
||||
vs.idleTimer.Stop()
|
||||
vs.idleTimer = nil
|
||||
}
|
||||
if vs.injector != nil {
|
||||
vs.injector.Close()
|
||||
}
|
||||
vs.stopDesktop()
|
||||
}
|
||||
onIdle := vs.onIdle
|
||||
vs.mu.Unlock()
|
||||
|
||||
if !alreadyStopped && onIdle != nil {
|
||||
onIdle()
|
||||
}
|
||||
}
|
||||
|
||||
func (vs *VirtualSession) stopXvfb() {
|
||||
if vs.xvfb == nil || vs.xvfb.Process == nil {
|
||||
return
|
||||
}
|
||||
if err := syscall.Kill(-vs.xvfb.Process.Pid, syscall.SIGTERM); err != nil {
|
||||
vs.log.Debugf("SIGTERM xvfb group: %v", err)
|
||||
}
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
if err := syscall.Kill(-vs.xvfb.Process.Pid, syscall.SIGKILL); err != nil {
|
||||
vs.log.Debugf("SIGKILL xvfb group: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (vs *VirtualSession) startDesktop() error {
|
||||
session := detectDesktopSession()
|
||||
|
||||
// Wrap the desktop command with dbus-launch to provide a session bus.
|
||||
// Without this, most desktop environments (XFCE, MATE, etc.) fail immediately.
|
||||
var args []string
|
||||
if _, err := exec.LookPath("dbus-launch"); err == nil {
|
||||
args = append([]string{"dbus-launch", "--exit-with-session"}, session...)
|
||||
} else {
|
||||
args = session
|
||||
}
|
||||
|
||||
vs.desktop = exec.Command(args[0], args[1:]...)
|
||||
vs.desktop.Dir = vs.user.HomeDir
|
||||
vs.desktop.Env = vs.buildUserEnv()
|
||||
vs.desktop.SysProcAttr = &syscall.SysProcAttr{
|
||||
Credential: &syscall.Credential{
|
||||
Uid: vs.uid,
|
||||
Gid: vs.gid,
|
||||
Groups: vs.groups,
|
||||
},
|
||||
Setsid: true,
|
||||
Pdeathsig: syscall.SIGTERM,
|
||||
}
|
||||
|
||||
if err := vs.desktop.Start(); err != nil {
|
||||
return fmt.Errorf("start desktop session (%v): %w", args, err)
|
||||
}
|
||||
vs.log.Infof("desktop session started: %v (pid=%d)", args, vs.desktop.Process.Pid)
|
||||
|
||||
go vs.monitorDesktop()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// monitorDesktop waits for the desktop-session process to exit. When the user
|
||||
// logs out of GNOME/KDE/XFCE/etc., the session process terminates while Xvfb
|
||||
// keeps running, leaving a blank root window. Tear the whole virtual session
|
||||
// down so the next connect starts fresh with a login.
|
||||
func (vs *VirtualSession) monitorDesktop() {
|
||||
if err := vs.desktop.Wait(); err != nil {
|
||||
vs.log.Debugf("desktop session exited: %v", err)
|
||||
}
|
||||
|
||||
vs.mu.Lock()
|
||||
alreadyStopped := vs.stopped
|
||||
if !alreadyStopped {
|
||||
vs.log.Info("desktop session exited (logout), tearing down virtual session")
|
||||
vs.stopped = true
|
||||
if vs.idleTimer != nil {
|
||||
vs.idleTimer.Stop()
|
||||
vs.idleTimer = nil
|
||||
}
|
||||
if vs.injector != nil {
|
||||
vs.injector.Close()
|
||||
}
|
||||
vs.stopXvfb()
|
||||
}
|
||||
onIdle := vs.onIdle
|
||||
vs.mu.Unlock()
|
||||
|
||||
if !alreadyStopped && onIdle != nil {
|
||||
onIdle()
|
||||
}
|
||||
}
|
||||
|
||||
func (vs *VirtualSession) stopDesktop() {
|
||||
if vs.desktop == nil || vs.desktop.Process == nil {
|
||||
return
|
||||
}
|
||||
if err := syscall.Kill(-vs.desktop.Process.Pid, syscall.SIGTERM); err != nil {
|
||||
vs.log.Debugf("SIGTERM desktop group: %v", err)
|
||||
}
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
if err := syscall.Kill(-vs.desktop.Process.Pid, syscall.SIGKILL); err != nil {
|
||||
vs.log.Debugf("SIGKILL desktop group: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (vs *VirtualSession) buildUserEnv() []string {
|
||||
return []string{
|
||||
envDisplay + "=" + vs.display,
|
||||
"HOME=" + vs.user.HomeDir,
|
||||
"USER=" + vs.user.Username,
|
||||
"LOGNAME=" + vs.user.Username,
|
||||
"SHELL=" + getUserShell(vs.user.Uid),
|
||||
"PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin",
|
||||
"XDG_RUNTIME_DIR=/run/user/" + vs.user.Uid,
|
||||
"DBUS_SESSION_BUS_ADDRESS=unix:path=/run/user/" + vs.user.Uid + "/bus",
|
||||
}
|
||||
}
|
||||
|
||||
// detectDesktopSession discovers available desktop sessions from the standard
|
||||
// /usr/share/xsessions/*.desktop files (FreeDesktop standard, used by all
|
||||
// display managers). Falls back to a hardcoded list if no .desktop files found.
|
||||
func detectDesktopSession() []string {
|
||||
// Scan xsessions directories (Linux: /usr/share, FreeBSD: /usr/local/share).
|
||||
for _, dir := range []string{"/usr/share/xsessions", "/usr/local/share/xsessions"} {
|
||||
if cmd := findXSession(dir); cmd != nil {
|
||||
return cmd
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: try common session commands directly.
|
||||
fallbacks := [][]string{
|
||||
{"startplasma-x11"},
|
||||
{"gnome-session"},
|
||||
{"xfce4-session"},
|
||||
{"mate-session"},
|
||||
{"cinnamon-session"},
|
||||
{"openbox-session"},
|
||||
{"xterm"},
|
||||
}
|
||||
for _, s := range fallbacks {
|
||||
if _, err := exec.LookPath(s[0]); err == nil {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return []string{"xterm"}
|
||||
}
|
||||
|
||||
// sessionPriority defines preference order for desktop environments.
|
||||
// Lower number = higher priority. Unknown sessions get 100.
|
||||
var sessionPriority = map[string]int{
|
||||
"plasma": 1, // KDE
|
||||
"gnome": 2,
|
||||
"xfce": 3,
|
||||
"mate": 4,
|
||||
"cinnamon": 5,
|
||||
"lxqt": 6,
|
||||
"lxde": 7,
|
||||
"budgie": 8,
|
||||
"openbox": 20,
|
||||
"fluxbox": 21,
|
||||
"i3": 22,
|
||||
"xinit": 50, // generic user session
|
||||
"lightdm": 50,
|
||||
"default": 50,
|
||||
}
|
||||
|
||||
func findXSession(dir string) []string {
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
candidates := collectSessionCandidates(dir, entries)
|
||||
if len(candidates) == 0 {
|
||||
return nil
|
||||
}
|
||||
best := bestSessionCandidate(candidates)
|
||||
parts := strings.Fields(best.cmd)
|
||||
if _, err := exec.LookPath(parts[0]); err != nil {
|
||||
return nil
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
type sessionCandidate struct {
|
||||
cmd string
|
||||
priority int
|
||||
}
|
||||
|
||||
func collectSessionCandidates(dir string, entries []os.DirEntry) []sessionCandidate {
|
||||
var out []sessionCandidate
|
||||
for _, e := range entries {
|
||||
c, ok := parseSessionEntry(dir, e)
|
||||
if ok {
|
||||
out = append(out, c)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// parseSessionEntry reads a single .desktop file and extracts its Exec
|
||||
// command plus the priority hint to be used when picking the best session.
|
||||
func parseSessionEntry(dir string, e os.DirEntry) (sessionCandidate, bool) {
|
||||
if !strings.HasSuffix(e.Name(), ".desktop") {
|
||||
return sessionCandidate{}, false
|
||||
}
|
||||
data, err := os.ReadFile(filepath.Join(dir, e.Name()))
|
||||
if err != nil {
|
||||
return sessionCandidate{}, false
|
||||
}
|
||||
execCmd := extractExecLine(data)
|
||||
if execCmd == "" || execCmd == "default" {
|
||||
return sessionCandidate{}, false
|
||||
}
|
||||
return sessionCandidate{cmd: execCmd, priority: sessionPriorityFor(e.Name(), execCmd)}, true
|
||||
}
|
||||
|
||||
func extractExecLine(data []byte) string {
|
||||
for _, line := range strings.Split(string(data), "\n") {
|
||||
if strings.HasPrefix(line, "Exec=") {
|
||||
return strings.TrimSpace(strings.TrimPrefix(line, "Exec="))
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func sessionPriorityFor(name, execCmd string) int {
|
||||
pri := 100
|
||||
lower := strings.ToLower(name + " " + execCmd)
|
||||
for keyword, p := range sessionPriority {
|
||||
if strings.Contains(lower, keyword) && p < pri {
|
||||
pri = p
|
||||
}
|
||||
}
|
||||
return pri
|
||||
}
|
||||
|
||||
func bestSessionCandidate(candidates []sessionCandidate) sessionCandidate {
|
||||
best := candidates[0]
|
||||
for _, c := range candidates[1:] {
|
||||
if c.priority < best.priority {
|
||||
best = c
|
||||
}
|
||||
}
|
||||
return best
|
||||
}
|
||||
|
||||
// findFreeDisplay scans for an unused X11 display number.
|
||||
func findFreeDisplay() (string, error) {
|
||||
for n := 50; n < 200; n++ {
|
||||
lockFile := fmt.Sprintf("/tmp/.X%d-lock", n)
|
||||
socketFile := fmt.Sprintf("%s/X%d", x11SocketDir, n)
|
||||
if _, err := os.Stat(lockFile); err == nil {
|
||||
continue
|
||||
}
|
||||
if _, err := os.Stat(socketFile); err == nil {
|
||||
continue
|
||||
}
|
||||
return fmt.Sprintf(":%d", n), nil
|
||||
}
|
||||
return "", fmt.Errorf("no free X11 display found (checked :50-:199)")
|
||||
}
|
||||
|
||||
// waitForPath polls until a filesystem path exists or the timeout expires.
|
||||
func waitForPath(path string, timeout time.Duration) error {
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
return nil
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
return fmt.Errorf("timeout waiting for %s", path)
|
||||
}
|
||||
|
||||
// getUserShell returns the login shell for the given UID.
|
||||
func getUserShell(uid string) string {
|
||||
data, err := os.ReadFile("/etc/passwd")
|
||||
if err != nil {
|
||||
return "/bin/sh"
|
||||
}
|
||||
for _, line := range strings.Split(string(data), "\n") {
|
||||
fields := strings.Split(line, ":")
|
||||
if len(fields) >= 7 && fields[2] == uid {
|
||||
return fields[6]
|
||||
}
|
||||
}
|
||||
return "/bin/sh"
|
||||
}
|
||||
|
||||
// supplementaryGroups returns the supplementary group IDs for a user.
|
||||
func supplementaryGroups(u *user.User) ([]uint32, error) {
|
||||
gids, err := u.GroupIds()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var groups []uint32
|
||||
for _, g := range gids {
|
||||
id, err := strconv.ParseUint(g, 10, 32)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
groups = append(groups, uint32(id))
|
||||
}
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
// sessionManager tracks active virtual sessions by username.
|
||||
type sessionManager struct {
|
||||
mu sync.Mutex
|
||||
sessions map[string]*VirtualSession
|
||||
log *log.Entry
|
||||
}
|
||||
|
||||
func newSessionManager(logger *log.Entry) *sessionManager {
|
||||
return &sessionManager{
|
||||
sessions: make(map[string]*VirtualSession),
|
||||
log: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// GetOrCreate returns an existing virtual session or creates a new one with
|
||||
// the requested geometry. If a previous session for this user is alive it is
|
||||
// reused regardless of the requested geometry; the first caller's size wins
|
||||
// until the session idles out. If a previous session is stopped or its X
|
||||
// server died, it is replaced.
|
||||
func (sm *sessionManager) GetOrCreate(username string, width, height uint16) (vncSession, error) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
if vs, ok := sm.sessions[username]; ok {
|
||||
if vs.isAlive() {
|
||||
return vs, nil
|
||||
}
|
||||
sm.log.Infof("replacing dead virtual session for %s", username)
|
||||
vs.Stop()
|
||||
delete(sm.sessions, username)
|
||||
}
|
||||
|
||||
vs, err := StartVirtualSession(username, width, height, sm.log)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
vs.onIdle = func() {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
if cur, ok := sm.sessions[username]; ok && cur == vs {
|
||||
delete(sm.sessions, username)
|
||||
sm.log.Infof("removed idle virtual session for %s", username)
|
||||
}
|
||||
}
|
||||
sm.sessions[username] = vs
|
||||
return vs, nil
|
||||
}
|
||||
|
||||
// hasDummyDriver checks common paths for the Xorg dummy video driver.
|
||||
func hasDummyDriver() bool {
|
||||
paths := []string{
|
||||
"/usr/lib/xorg/modules/drivers/dummy_drv.so", // Debian/Ubuntu
|
||||
"/usr/lib64/xorg/modules/drivers/dummy_drv.so", // RHEL/Fedora
|
||||
"/usr/local/lib/xorg/modules/drivers/dummy_drv.so", // FreeBSD
|
||||
"/usr/lib/x86_64-linux-gnu/xorg/modules/drivers/dummy_drv.so", // Debian multiarch
|
||||
}
|
||||
for _, p := range paths {
|
||||
if _, err := os.Stat(p); err == nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// StopAll terminates all active virtual sessions.
|
||||
func (sm *sessionManager) StopAll() {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
for username, vs := range sm.sessions {
|
||||
vs.Stop()
|
||||
delete(sm.sessions, username)
|
||||
sm.log.Infof("stopped virtual session for %s", username)
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,7 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
@@ -21,6 +22,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/http"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/rdp"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/ssh"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/vnc"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -38,6 +40,7 @@ const (
|
||||
|
||||
func main() {
|
||||
js.Global().Set("NetBirdClient", js.FuncOf(netBirdClientConstructor))
|
||||
js.Global().Set("netbirdGenerateVNCSessionKey", createGenerateVNCSessionKeyMethod())
|
||||
|
||||
select {}
|
||||
}
|
||||
@@ -387,6 +390,156 @@ func createRDPProxyMethod(client *netbird.Client) js.Func {
|
||||
})
|
||||
}
|
||||
|
||||
// createGenerateVNCSessionKeyMethod returns a JS func that mints a fresh
|
||||
// X25519 keypair, stashes the private half inside wasm under a random
|
||||
// session id, and returns { publicKey, sessionId } to JS. The private
|
||||
// key never leaves the wasm heap.
|
||||
func createGenerateVNCSessionKeyMethod() js.Func {
|
||||
return js.FuncOf(func(_ js.Value, _ []js.Value) any {
|
||||
id, pub, err := vnc.NewSessionKey()
|
||||
if err != nil {
|
||||
return js.ValueOf(err.Error())
|
||||
}
|
||||
out := js.Global().Get("Object").New()
|
||||
out.Set("sessionId", id)
|
||||
out.Set("publicKey", base64.StdEncoding.EncodeToString(pub))
|
||||
return out
|
||||
})
|
||||
}
|
||||
|
||||
// createVNCProxyMethod creates the VNC proxy method for raw TCP-over-WebSocket bridging.
|
||||
// JS signature: createVNCProxy(hostname, port, mode?, username?, keySessionID?, sessionID?, width?, height?, peerPublicKey?)
|
||||
// mode: "attach" (default) or "session"
|
||||
// username: required when mode is "session"
|
||||
// keySessionID: handle for the wasm-resident session keypair minted by netbirdGenerateVNCSessionKey
|
||||
// sessionID: Windows session ID (0 = console/auto)
|
||||
// width/height: requested viewport size for session mode (0 = server default)
|
||||
// peerPublicKey: base64 X25519 static pubkey of the destination peer (required for auth)
|
||||
func createVNCProxyMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
params, err := parseVNCProxyArgs(args)
|
||||
if err != nil {
|
||||
if params.rejectViaPromise {
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
reject.Invoke(js.ValueOf(err.Error()))
|
||||
})
|
||||
}
|
||||
return js.ValueOf(err.Error())
|
||||
}
|
||||
proxy := vnc.NewVNCProxy(client)
|
||||
return proxy.CreateProxy(vnc.ProxyRequest{
|
||||
Hostname: params.hostname,
|
||||
Port: params.port,
|
||||
Mode: params.mode,
|
||||
Username: params.username,
|
||||
SessionID: params.sessionID,
|
||||
Width: params.width,
|
||||
Height: params.height,
|
||||
PeerPublicKey: params.peerPublicKey,
|
||||
KeySessionID: params.keySessionID,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
type vncProxyParams struct {
|
||||
hostname string
|
||||
port string
|
||||
mode string
|
||||
username string
|
||||
keySessionID string
|
||||
sessionID uint32
|
||||
width uint16
|
||||
height uint16
|
||||
peerPublicKey string
|
||||
rejectViaPromise bool
|
||||
}
|
||||
|
||||
// parseVNCProxyArgs validates JS args for createVNCProxyMethod and returns
|
||||
// the parsed params plus the first validation error (nil on success).
|
||||
// vncProxyParams.rejectViaPromise tells the caller which JS-side response
|
||||
// path to use for the returned error.
|
||||
func parseVNCProxyArgs(args []js.Value) (vncProxyParams, error) {
|
||||
var p vncProxyParams
|
||||
if err := parseVNCProxyRequiredArgs(args, &p); err != nil {
|
||||
return p, err
|
||||
}
|
||||
if err := parseVNCProxyOptionalStrings(args, &p); err != nil {
|
||||
return p, err
|
||||
}
|
||||
if err := parseVNCProxyOptionalNumbers(args, &p); err != nil {
|
||||
return p, err
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func parseVNCProxyRequiredArgs(args []js.Value, p *vncProxyParams) error {
|
||||
if len(args) < 2 {
|
||||
return fmt.Errorf("hostname and port required")
|
||||
}
|
||||
if args[0].Type() != js.TypeString {
|
||||
p.rejectViaPromise = true
|
||||
return fmt.Errorf("hostname parameter must be a string")
|
||||
}
|
||||
if args[1].Type() != js.TypeString {
|
||||
p.rejectViaPromise = true
|
||||
return fmt.Errorf("port parameter must be a string")
|
||||
}
|
||||
p.hostname = args[0].String()
|
||||
p.port = args[1].String()
|
||||
p.mode = "attach"
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseVNCProxyOptionalStrings(args []js.Value, p *vncProxyParams) error {
|
||||
if len(args) > 2 && args[2].Type() == js.TypeString {
|
||||
p.mode = args[2].String()
|
||||
}
|
||||
if p.mode != "attach" && p.mode != "session" {
|
||||
p.rejectViaPromise = true
|
||||
return fmt.Errorf("invalid mode %q: expected \"attach\" or \"session\"", p.mode)
|
||||
}
|
||||
if len(args) > 3 && args[3].Type() == js.TypeString {
|
||||
p.username = args[3].String()
|
||||
}
|
||||
if len(args) > 4 && args[4].Type() == js.TypeString {
|
||||
p.keySessionID = args[4].String()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseVNCProxyOptionalNumbers(args []js.Value, p *vncProxyParams) error {
|
||||
if len(args) > 5 && args[5].Type() == js.TypeNumber {
|
||||
v := args[5].Int()
|
||||
if v < 0 || v > 0xFFFFFFFF {
|
||||
p.rejectViaPromise = true
|
||||
return fmt.Errorf("invalid sessionID %d: must be 0..0xFFFFFFFF", v)
|
||||
}
|
||||
p.sessionID = uint32(v)
|
||||
}
|
||||
// width=0 / height=0 mean "use server default"; reject only out-of-range
|
||||
// non-zero values so attach mode (which omits width/height) still works.
|
||||
if len(args) > 6 && args[6].Type() == js.TypeNumber {
|
||||
v := args[6].Int()
|
||||
if v < 0 || v > 0xFFFF {
|
||||
p.rejectViaPromise = true
|
||||
return fmt.Errorf("invalid width %d: must be 0..65535", v)
|
||||
}
|
||||
p.width = uint16(v)
|
||||
}
|
||||
if len(args) > 7 && args[7].Type() == js.TypeNumber {
|
||||
v := args[7].Int()
|
||||
if v < 0 || v > 0xFFFF {
|
||||
p.rejectViaPromise = true
|
||||
return fmt.Errorf("invalid height %d: must be 0..65535", v)
|
||||
}
|
||||
p.height = uint16(v)
|
||||
}
|
||||
if len(args) > 8 && args[8].Type() == js.TypeString {
|
||||
p.peerPublicKey = args[8].String()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getStatusOverview is a helper to get the status overview
|
||||
func getStatusOverview(client *netbird.Client) (nbstatus.OutputOverview, error) {
|
||||
fullStatus, err := client.Status()
|
||||
@@ -563,10 +716,10 @@ func createStartCaptureMethod(client *netbird.Client) js.Func {
|
||||
//
|
||||
// Usage from browser devtools console:
|
||||
//
|
||||
// await client.capture() // capture all packets
|
||||
// await client.capture("tcp") // capture with filter
|
||||
// await client.capture({filter: "host 10.0.0.1", verbose: true})
|
||||
// client.stopCapture() // stop and print stats
|
||||
// await netbird.capture() // capture all packets
|
||||
// await netbird.capture("tcp") // capture with filter
|
||||
// await netbird.capture({filter: "host 10.0.0.1", verbose: true})
|
||||
// netbird.stopCapture() // stop and print stats
|
||||
func captureMethods(client *netbird.Client) (startFn, stopFn js.Func) {
|
||||
var mu sync.Mutex
|
||||
var active *wasmcapture.Handle
|
||||
@@ -594,7 +747,7 @@ func captureMethods(client *netbird.Client) (startFn, stopFn js.Func) {
|
||||
active = h
|
||||
|
||||
console := js.Global().Get("console")
|
||||
console.Call("log", "[capture] started, call client.stopCapture() to stop")
|
||||
console.Call("log", "[capture] started, call netbird.stopCapture() to stop")
|
||||
resolve.Invoke(js.Undefined())
|
||||
})
|
||||
})
|
||||
@@ -677,6 +830,7 @@ func createClientObject(client *netbird.Client) js.Value {
|
||||
obj["createSSHConnection"] = createSSHMethod(client)
|
||||
obj["proxyRequest"] = createProxyRequestMethod(client)
|
||||
obj["createRDPProxy"] = createRDPProxyMethod(client)
|
||||
obj["createVNCProxy"] = createVNCProxyMethod(client)
|
||||
obj["status"] = createStatusMethod(client)
|
||||
obj["statusSummary"] = createStatusSummaryMethod(client)
|
||||
obj["statusDetail"] = createStatusDetailMethod(client)
|
||||
|
||||
586
client/wasm/internal/vnc/proxy.go
Normal file
586
client/wasm/internal/vnc/proxy.go
Normal file
@@ -0,0 +1,586 @@
|
||||
//go:build js
|
||||
|
||||
package vnc
|
||||
|
||||
import (
|
||||
"context"
|
||||
crand "crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall/js"
|
||||
"time"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var cryptoRandRead = crand.Read
|
||||
|
||||
// vncIdentityMagic mirrors the server side in client/vnc/server/server.go.
|
||||
var vncIdentityMagic = []byte("NBV3")
|
||||
|
||||
// Noise_IK_25519_ChaChaPoly_SHA256 message sizes (with empty payloads).
|
||||
const (
|
||||
noiseInitiatorMsgLen = 96
|
||||
noiseResponderMsgLen = 48
|
||||
)
|
||||
|
||||
var vncNoiseSuite = noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
|
||||
|
||||
// sessionKeyStore retains per-session X25519 keypairs so the JS layer
|
||||
// only sees an opaque session id + the public key; the private key never
|
||||
// leaves wasm.
|
||||
var sessionKeyStore = struct {
|
||||
mu sync.Mutex
|
||||
keys map[string]noise.DHKey
|
||||
}{keys: map[string]noise.DHKey{}}
|
||||
|
||||
// NewSessionKey mints an X25519 keypair, stores the private half under a
|
||||
// fresh random session id, and returns (id, pubkey).
|
||||
func NewSessionKey() (string, []byte, error) {
|
||||
kp, err := noise.DH25519.GenerateKeypair(nil)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("generate keypair: %w", err)
|
||||
}
|
||||
idBytes := make([]byte, 16)
|
||||
if _, err := cryptoRandRead(idBytes); err != nil {
|
||||
return "", nil, fmt.Errorf("session id randomness: %w", err)
|
||||
}
|
||||
id := base64.RawURLEncoding.EncodeToString(idBytes)
|
||||
sessionKeyStore.mu.Lock()
|
||||
sessionKeyStore.keys[id] = kp
|
||||
sessionKeyStore.mu.Unlock()
|
||||
return id, kp.Public, nil
|
||||
}
|
||||
|
||||
// consumeSessionKey atomically retrieves and removes the keypair for id.
|
||||
// A session handle is single-use; combining lookup and delete under one
|
||||
// critical section prevents concurrent callers from observing the same key.
|
||||
func consumeSessionKey(id string) (noise.DHKey, bool) {
|
||||
sessionKeyStore.mu.Lock()
|
||||
defer sessionKeyStore.mu.Unlock()
|
||||
kp, ok := sessionKeyStore.keys[id]
|
||||
if ok {
|
||||
delete(sessionKeyStore.keys, id)
|
||||
}
|
||||
return kp, ok
|
||||
}
|
||||
|
||||
const (
|
||||
vncProxyHost = "vnc.proxy.local"
|
||||
vncProxyScheme = "ws"
|
||||
vncDialTimeout = 15 * time.Second
|
||||
|
||||
// Connection modes matching server/server.go constants.
|
||||
modeAttach byte = 0
|
||||
modeSession byte = 1
|
||||
|
||||
// WebSocket close codes the dashboard branches on. Codes 1000-1015
|
||||
// are reserved by RFC 6455; 4000-4999 are application-defined.
|
||||
wsCodeNormal = 1000
|
||||
wsCodeAbnormal = 1006
|
||||
wsCodeDialTimeout = 4001
|
||||
wsCodeDialFailure = 4002
|
||||
wsCodeSessionSetup = 4003
|
||||
wsCodeTransport = 4004
|
||||
)
|
||||
|
||||
// VNCProxy bridges WebSocket connections from noVNC in the browser
|
||||
// to TCP VNC server connections through the NetBird tunnel.
|
||||
type vncNBClient interface {
|
||||
Dial(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
||||
type VNCProxy struct {
|
||||
nbClient vncNBClient
|
||||
activeConnections map[string]*vncConnection
|
||||
destinations map[string]vncDestination
|
||||
// pendingHandlers holds the js.Func for handleVNCWebSocket_<id> between
|
||||
// CreateProxy and handleWebSocketConnection so we can move it onto the
|
||||
// vncConnection for later release.
|
||||
pendingHandlers map[string]js.Func
|
||||
mu sync.Mutex
|
||||
nextID atomic.Uint64
|
||||
}
|
||||
|
||||
type vncDestination struct {
|
||||
address string
|
||||
mode byte
|
||||
username string
|
||||
sessionPriv []byte
|
||||
sessionPub []byte
|
||||
sessionID uint32
|
||||
width uint16
|
||||
height uint16
|
||||
peerPubKey []byte
|
||||
}
|
||||
|
||||
type vncConnection struct {
|
||||
id string
|
||||
destination vncDestination
|
||||
mu sync.Mutex
|
||||
vncConn net.Conn
|
||||
wsHandlers js.Value
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
// Go-side callbacks exposed to JS. js.FuncOf pins the Go closure in a
|
||||
// global handle map and MUST be released, otherwise every connection
|
||||
// leaks the Go memory the closure captures.
|
||||
wsHandlerFn js.Func
|
||||
onMessageFn js.Func
|
||||
onCloseFn js.Func
|
||||
}
|
||||
|
||||
// NewVNCProxy creates a new VNC proxy.
|
||||
func NewVNCProxy(client vncNBClient) *VNCProxy {
|
||||
return &VNCProxy{
|
||||
nbClient: client,
|
||||
activeConnections: make(map[string]*vncConnection),
|
||||
}
|
||||
}
|
||||
|
||||
// ProxyRequest bundles the per-call parameters for CreateProxy so the JS
|
||||
// boundary doesn't drown callers in a wide positional argument list.
|
||||
type ProxyRequest struct {
|
||||
Hostname string
|
||||
Port string
|
||||
Mode string
|
||||
Username string
|
||||
SessionID uint32
|
||||
Width uint16
|
||||
Height uint16
|
||||
// PeerPublicKey is the destination peer's base64 X25519 public key,
|
||||
// used as the responder static in the Noise_IK handshake.
|
||||
PeerPublicKey string
|
||||
// KeySessionID is the handle returned by generateVNCSessionKey. The
|
||||
// matching private key is looked up inside wasm and never crosses
|
||||
// the JS boundary.
|
||||
KeySessionID string
|
||||
}
|
||||
|
||||
// CreateProxy creates a new proxy endpoint for the given VNC destination.
|
||||
// req.Mode is "attach" (capture current display) or "session" (virtual session).
|
||||
// req.Username is required for session mode. req.Width/Height request the
|
||||
// virtual display geometry for session mode; 0 means use the server default.
|
||||
// Returns a JS Promise that resolves to the WebSocket proxy URL.
|
||||
func (p *VNCProxy) CreateProxy(req ProxyRequest) js.Value {
|
||||
hostname, port, mode, username := req.Hostname, req.Port, req.Mode, req.Username
|
||||
sessionID, width, height := req.SessionID, req.Width, req.Height
|
||||
address := net.JoinHostPort(hostname, port)
|
||||
|
||||
var m byte
|
||||
if mode == "session" {
|
||||
m = modeSession
|
||||
}
|
||||
|
||||
dest := vncDestination{
|
||||
address: address,
|
||||
mode: m,
|
||||
username: username,
|
||||
sessionID: sessionID,
|
||||
width: width,
|
||||
height: height,
|
||||
}
|
||||
if req.KeySessionID != "" {
|
||||
kp, ok := consumeSessionKey(req.KeySessionID)
|
||||
if !ok {
|
||||
return rejectedPromise("unknown VNC session id")
|
||||
}
|
||||
dest.sessionPriv = kp.Private
|
||||
dest.sessionPub = kp.Public
|
||||
pub, err := decodePeerPubKey(req.PeerPublicKey)
|
||||
if err != nil {
|
||||
return rejectedPromise(fmt.Sprintf("invalid peer public key: %v", err))
|
||||
}
|
||||
dest.peerPubKey = pub
|
||||
}
|
||||
return p.newProxyPromise(address, mode, username, dest)
|
||||
}
|
||||
|
||||
// decodePeerPubKey parses a base64-encoded 32-byte X25519 public key.
|
||||
func decodePeerPubKey(b64 string) ([]byte, error) {
|
||||
if b64 == "" {
|
||||
return nil, errors.New("peer public key missing")
|
||||
}
|
||||
raw, err := base64.StdEncoding.DecodeString(b64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("base64 decode: %w", err)
|
||||
}
|
||||
if len(raw) != 32 {
|
||||
return nil, fmt.Errorf("expected 32 bytes, got %d", len(raw))
|
||||
}
|
||||
return raw, nil
|
||||
}
|
||||
|
||||
// rejectedPromise returns a resolved Promise carrying msg as an error
|
||||
// string, mirroring how CreateProxy reports earlier validation failures.
|
||||
func rejectedPromise(msg string) js.Value {
|
||||
promise := js.Global().Get("Promise")
|
||||
return promise.Call("resolve", js.ValueOf(msg))
|
||||
}
|
||||
|
||||
// newProxyPromise wraps the JS Promise creation + executor lifecycle so
|
||||
// CreateProxy stays a thin parameter-bundling entrypoint.
|
||||
func (p *VNCProxy) newProxyPromise(address, mode, username string, dest vncDestination) js.Value {
|
||||
|
||||
var executor js.Func
|
||||
executor = js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
resolve := args[0]
|
||||
|
||||
go func() {
|
||||
defer executor.Release()
|
||||
|
||||
proxyID := fmt.Sprintf("vnc_proxy_%d", p.nextID.Add(1))
|
||||
|
||||
p.mu.Lock()
|
||||
if p.destinations == nil {
|
||||
p.destinations = make(map[string]vncDestination)
|
||||
}
|
||||
p.destinations[proxyID] = dest
|
||||
p.mu.Unlock()
|
||||
|
||||
proxyURL := fmt.Sprintf("%s://%s/%s", vncProxyScheme, vncProxyHost, proxyID)
|
||||
|
||||
handlerFn := js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
if len(args) < 1 {
|
||||
return js.ValueOf("error: requires WebSocket argument")
|
||||
}
|
||||
p.handleWebSocketConnection(args[0], proxyID)
|
||||
return nil
|
||||
})
|
||||
p.mu.Lock()
|
||||
if p.pendingHandlers == nil {
|
||||
p.pendingHandlers = make(map[string]js.Func)
|
||||
}
|
||||
p.pendingHandlers[proxyID] = handlerFn
|
||||
p.mu.Unlock()
|
||||
js.Global().Set(fmt.Sprintf("handleVNCWebSocket_%s", proxyID), handlerFn)
|
||||
|
||||
log.Infof("created VNC proxy: %s -> %s (mode=%s, user=%s)", proxyURL, address, mode, username)
|
||||
resolve.Invoke(proxyURL)
|
||||
}()
|
||||
|
||||
return nil
|
||||
})
|
||||
return js.Global().Get("Promise").New(executor)
|
||||
}
|
||||
|
||||
func (p *VNCProxy) handleWebSocketConnection(ws js.Value, proxyID string) {
|
||||
p.mu.Lock()
|
||||
dest, ok := p.destinations[proxyID]
|
||||
handlerFn := p.pendingHandlers[proxyID]
|
||||
delete(p.pendingHandlers, proxyID)
|
||||
p.mu.Unlock()
|
||||
|
||||
if !ok {
|
||||
log.Errorf("no destination for VNC proxy %s", proxyID)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
conn := &vncConnection{
|
||||
id: proxyID,
|
||||
destination: dest,
|
||||
wsHandlers: ws,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
wsHandlerFn: handlerFn,
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
p.activeConnections[proxyID] = conn
|
||||
p.mu.Unlock()
|
||||
|
||||
p.setupWebSocketHandlers(ws, conn)
|
||||
go p.connectToVNC(conn)
|
||||
|
||||
log.Infof("VNC proxy WebSocket connection established for %s", proxyID)
|
||||
}
|
||||
|
||||
func (p *VNCProxy) setupWebSocketHandlers(ws js.Value, conn *vncConnection) {
|
||||
conn.onMessageFn = js.FuncOf(func(_ js.Value, args []js.Value) any {
|
||||
if len(args) < 1 {
|
||||
return nil
|
||||
}
|
||||
data := args[0]
|
||||
go p.handleWebSocketMessage(conn, data)
|
||||
return nil
|
||||
})
|
||||
ws.Set("onGoMessage", conn.onMessageFn)
|
||||
|
||||
conn.onCloseFn = js.FuncOf(func(_ js.Value, _ []js.Value) any {
|
||||
log.Debug("VNC WebSocket closed by JavaScript")
|
||||
conn.cancel()
|
||||
return nil
|
||||
})
|
||||
ws.Set("onGoClose", conn.onCloseFn)
|
||||
}
|
||||
|
||||
func (p *VNCProxy) handleWebSocketMessage(conn *vncConnection, data js.Value) {
|
||||
if !data.InstanceOf(js.Global().Get("Uint8Array")) {
|
||||
return
|
||||
}
|
||||
|
||||
length := data.Get("length").Int()
|
||||
buf := make([]byte, length)
|
||||
js.CopyBytesToGo(buf, data)
|
||||
|
||||
conn.mu.Lock()
|
||||
vncConn := conn.vncConn
|
||||
conn.mu.Unlock()
|
||||
|
||||
if vncConn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := vncConn.Write(buf); err != nil {
|
||||
log.Debugf("write to VNC server: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *VNCProxy) connectToVNC(conn *vncConnection) {
|
||||
ctx, cancel := context.WithTimeout(conn.ctx, vncDialTimeout)
|
||||
defer cancel()
|
||||
|
||||
vncConn, err := p.nbClient.Dial(ctx, "tcp", conn.destination.address)
|
||||
if err != nil {
|
||||
log.Errorf("VNC connect to %s: %v", conn.destination.address, err)
|
||||
// Close the WebSocket so noVNC fires a disconnect event.
|
||||
code := wsCodeDialFailure
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
code = wsCodeDialTimeout
|
||||
}
|
||||
if conn.wsHandlers.Get("close").Truthy() {
|
||||
conn.wsHandlers.Call("close", code, fmt.Sprintf("connect to peer: %v", err))
|
||||
}
|
||||
p.cleanupConnection(conn)
|
||||
return
|
||||
}
|
||||
conn.mu.Lock()
|
||||
conn.vncConn = vncConn
|
||||
conn.mu.Unlock()
|
||||
|
||||
// Send the NetBird VNC session header before the RFB handshake.
|
||||
if err := p.sendSessionHeader(vncConn, conn.destination); err != nil {
|
||||
log.Errorf("send VNC session header: %v", err)
|
||||
if conn.wsHandlers.Get("close").Truthy() {
|
||||
conn.wsHandlers.Call("close", wsCodeSessionSetup, fmt.Sprintf("send session header: %v", err))
|
||||
}
|
||||
p.cleanupConnection(conn)
|
||||
return
|
||||
}
|
||||
|
||||
// WS→TCP is handled by the onGoMessage handler set in setupWebSocketHandlers,
|
||||
// which writes directly to the VNC connection as data arrives from JS.
|
||||
// Only the TCP→WS direction needs a read loop here.
|
||||
go p.forwardConnToWS(conn)
|
||||
|
||||
<-conn.ctx.Done()
|
||||
p.cleanupConnection(conn)
|
||||
}
|
||||
|
||||
// sendSessionHeader writes the NetBird VNC connection header: mode +
|
||||
// username prefix, an optional Noise_IK handshake that authenticates the
|
||||
// client and the server, then the trailing sessionID / width / height
|
||||
// fields the daemon needs once auth is settled.
|
||||
func (p *VNCProxy) sendSessionHeader(conn net.Conn, dest vncDestination) error {
|
||||
usernameBytes := []byte(dest.username)
|
||||
if len(usernameBytes) > 0xFFFF {
|
||||
return fmt.Errorf("username too long: %d bytes (max %d)", len(usernameBytes), 0xFFFF)
|
||||
}
|
||||
prefix := make([]byte, 3+len(usernameBytes))
|
||||
prefix[0] = dest.mode
|
||||
prefix[1] = byte(len(usernameBytes) >> 8)
|
||||
prefix[2] = byte(len(usernameBytes))
|
||||
copy(prefix[3:], usernameBytes)
|
||||
if err := writeAll(conn, prefix); err != nil {
|
||||
return fmt.Errorf("write header prefix: %w", err)
|
||||
}
|
||||
|
||||
if dest.sessionPriv == nil {
|
||||
return p.writeHeaderTail(conn, dest)
|
||||
}
|
||||
if err := p.runNoiseHandshake(conn, dest); err != nil {
|
||||
return fmt.Errorf("noise handshake: %w", err)
|
||||
}
|
||||
return p.writeHeaderTail(conn, dest)
|
||||
}
|
||||
|
||||
// writeHeaderTail writes the post-auth trailing fields (sessionID,
|
||||
// width, height) the daemon reads regardless of whether the Noise
|
||||
// handshake was performed.
|
||||
func (p *VNCProxy) writeHeaderTail(conn net.Conn, dest vncDestination) error {
|
||||
tail := make([]byte, 4+4)
|
||||
tail[0] = byte(dest.sessionID >> 24)
|
||||
tail[1] = byte(dest.sessionID >> 16)
|
||||
tail[2] = byte(dest.sessionID >> 8)
|
||||
tail[3] = byte(dest.sessionID)
|
||||
tail[4] = byte(dest.width >> 8)
|
||||
tail[5] = byte(dest.width)
|
||||
tail[6] = byte(dest.height >> 8)
|
||||
tail[7] = byte(dest.height)
|
||||
if err := writeAll(conn, tail); err != nil {
|
||||
return fmt.Errorf("write header tail: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// runNoiseHandshake performs the initiator side of a Noise_IK handshake
|
||||
// against the destination daemon. The session keypair authenticates the
|
||||
// client; the daemon's pre-known peer pubkey authenticates the server.
|
||||
func (p *VNCProxy) runNoiseHandshake(conn net.Conn, dest vncDestination) error {
|
||||
state, err := noise.NewHandshakeState(noise.Config{
|
||||
CipherSuite: vncNoiseSuite,
|
||||
Pattern: noise.HandshakeIK,
|
||||
Initiator: true,
|
||||
StaticKeypair: noise.DHKey{Private: dest.sessionPriv, Public: dest.sessionPub},
|
||||
PeerStatic: dest.peerPubKey,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("noise initiator init: %w", err)
|
||||
}
|
||||
msg1, _, _, err := state.WriteMessage(nil, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("noise write msg1: %w", err)
|
||||
}
|
||||
out := make([]byte, 0, len(vncIdentityMagic)+len(msg1))
|
||||
out = append(out, vncIdentityMagic...)
|
||||
out = append(out, msg1...)
|
||||
if err := writeAll(conn, out); err != nil {
|
||||
return fmt.Errorf("send noise msg1: %w", err)
|
||||
}
|
||||
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
return fmt.Errorf("set noise deadline: %w", err)
|
||||
}
|
||||
defer conn.SetReadDeadline(time.Time{}) //nolint:errcheck
|
||||
msg2 := make([]byte, noiseResponderMsgLen)
|
||||
if _, err := io.ReadFull(conn, msg2); err != nil {
|
||||
return fmt.Errorf("read noise msg2: %w", err)
|
||||
}
|
||||
if _, _, _, err := state.ReadMessage(nil, msg2); err != nil {
|
||||
return fmt.Errorf("noise read msg2: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeAll(conn net.Conn, buf []byte) error {
|
||||
for off := 0; off < len(buf); {
|
||||
n, err := conn.Write(buf[off:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
off += n
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *VNCProxy) forwardConnToWS(conn *vncConnection) {
|
||||
buf := make([]byte, 32*1024)
|
||||
|
||||
for {
|
||||
if conn.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
vc, ok := conn.snapshotVNC()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := vc.SetReadDeadline(time.Now().Add(30 * time.Second)); err != nil {
|
||||
log.Debugf("set VNC read deadline: %v", err)
|
||||
}
|
||||
n, err := vc.Read(buf)
|
||||
if err != nil {
|
||||
if p.handleConnReadError(conn, err) {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
if n > 0 {
|
||||
p.sendToWebSocket(conn, buf[:n])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// snapshotVNC returns the current vncConn under conn.mu, with ok=false when
|
||||
// the connection has already been cleaned up.
|
||||
func (c *vncConnection) snapshotVNC() (net.Conn, bool) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.vncConn == nil {
|
||||
return nil, false
|
||||
}
|
||||
return c.vncConn, true
|
||||
}
|
||||
|
||||
// handleConnReadError classifies an error from the VNC read loop. Returns
|
||||
// true if the caller should exit and trigger the cleanup path. A read
|
||||
// timeout counts as a fatal error: in a healthy session the server emits
|
||||
// empty FramebufferUpdate responses several times per second, so a full
|
||||
// idleReadDeadline of silence means the peer is dead (process gone,
|
||||
// machine off, network partition) and the in-browser TCP stack will
|
||||
// never surface that on its own.
|
||||
func (p *VNCProxy) handleConnReadError(conn *vncConnection, err error) bool {
|
||||
if conn.ctx.Err() != nil {
|
||||
return true
|
||||
}
|
||||
if netErr, ok := err.(interface{ Timeout() bool }); ok && netErr.Timeout() {
|
||||
log.Debugf("VNC read deadline expired; treating peer as dead")
|
||||
} else if err != io.EOF {
|
||||
log.Debugf("read from VNC connection: %v", err)
|
||||
}
|
||||
if conn.wsHandlers.Get("close").Truthy() {
|
||||
conn.wsHandlers.Call("close", wsCodeTransport, "VNC connection lost")
|
||||
}
|
||||
conn.cancel()
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *VNCProxy) sendToWebSocket(conn *vncConnection, data []byte) {
|
||||
if conn.wsHandlers.Get("receiveFromGo").Truthy() {
|
||||
uint8Array := js.Global().Get("Uint8Array").New(len(data))
|
||||
js.CopyBytesToJS(uint8Array, data)
|
||||
conn.wsHandlers.Call("receiveFromGo", uint8Array.Get("buffer"))
|
||||
} else if conn.wsHandlers.Get("send").Truthy() {
|
||||
uint8Array := js.Global().Get("Uint8Array").New(len(data))
|
||||
js.CopyBytesToJS(uint8Array, data)
|
||||
conn.wsHandlers.Call("send", uint8Array.Get("buffer"))
|
||||
}
|
||||
}
|
||||
|
||||
func (p *VNCProxy) cleanupConnection(conn *vncConnection) {
|
||||
log.Debugf("cleaning up VNC connection %s", conn.id)
|
||||
conn.cancel()
|
||||
|
||||
conn.mu.Lock()
|
||||
vncConn := conn.vncConn
|
||||
conn.vncConn = nil
|
||||
conn.mu.Unlock()
|
||||
|
||||
if vncConn != nil {
|
||||
if err := vncConn.Close(); err != nil {
|
||||
log.Debugf("close VNC connection: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove the global JS handler registered in CreateProxy.
|
||||
globalName := fmt.Sprintf("handleVNCWebSocket_%s", conn.id)
|
||||
js.Global().Delete(globalName)
|
||||
|
||||
// Release all js.Func handles; js.FuncOf pins the Go closure and the
|
||||
// allocations it captures until Release is called.
|
||||
conn.wsHandlerFn.Release()
|
||||
conn.onMessageFn.Release()
|
||||
conn.onCloseFn.Release()
|
||||
|
||||
p.mu.Lock()
|
||||
delete(p.activeConnections, conn.id)
|
||||
delete(p.destinations, conn.id)
|
||||
delete(p.pendingHandlers, conn.id)
|
||||
p.mu.Unlock()
|
||||
}
|
||||
3
go.mod
3
go.mod
@@ -51,6 +51,7 @@ require (
|
||||
github.com/eko/gocache/lib/v4 v4.2.0
|
||||
github.com/eko/gocache/store/go_cache/v4 v4.2.2
|
||||
github.com/eko/gocache/store/redis/v4 v4.2.2
|
||||
github.com/flynn/noise v1.1.0
|
||||
github.com/fsnotify/fsnotify v1.9.0
|
||||
github.com/gliderlabs/ssh v0.3.8
|
||||
github.com/go-jose/go-jose/v4 v4.1.4
|
||||
@@ -66,6 +67,8 @@ require (
|
||||
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
|
||||
github.com/hashicorp/go-version v1.7.0
|
||||
github.com/jackc/pgx/v5 v5.5.5
|
||||
github.com/jezek/xgb v1.3.0
|
||||
github.com/kirides/go-d3d v1.0.1
|
||||
github.com/libdns/route53 v1.5.0
|
||||
github.com/libp2p/go-nat v0.2.0
|
||||
github.com/libp2p/go-netroute v0.4.0
|
||||
|
||||
8
go.sum
8
go.sum
@@ -162,6 +162,8 @@ github.com/felixge/fgprof v0.9.3 h1:VvyZxILNuCiUCSXtPtYmmtGvb65nqXh2QFWc0Wpf2/g=
|
||||
github.com/felixge/fgprof v0.9.3/go.mod h1:RdbpDgzqYVh/T9fPELJyV7EYJuHB55UTEULNun8eiPw=
|
||||
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
|
||||
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||
github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
|
||||
github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
|
||||
github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw=
|
||||
github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g=
|
||||
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
||||
@@ -378,6 +380,8 @@ github.com/jcmturner/rpc/v2 v2.0.3 h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZ
|
||||
github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc=
|
||||
github.com/jeandeaual/go-locale v0.0.0-20250612000132-0ef82f21eade h1:FmusiCI1wHw+XQbvL9M+1r/C3SPqKrmBaIOYwVfQoDE=
|
||||
github.com/jeandeaual/go-locale v0.0.0-20250612000132-0ef82f21eade/go.mod h1:ZDXo8KHryOWSIqnsb/CiDq7hQUYryCgdVnxbj8tDG7o=
|
||||
github.com/jezek/xgb v1.3.0 h1:Wa1pn4GVtcmNVAVB6/pnQVJ7xPFZVZ/W1Tc27msDhgI=
|
||||
github.com/jezek/xgb v1.3.0/go.mod h1:nrhwO0FX/enq75I7Y7G8iN1ubpSGZEiA3v9e9GyRFlk=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
@@ -396,6 +400,8 @@ github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25/go.mod h1:kLgvv7o6U
|
||||
github.com/juju/gnuflag v0.0.0-20171113085948-2ce1bb71843d/go.mod h1:2PavIy+JPciBPrBUjwbNvtwB6RQlve+hkpll6QSNmOE=
|
||||
github.com/kelseyhightower/envconfig v1.4.0 h1:Im6hONhd3pLkfDFsbRgu68RDNkGF1r3dvMUtDTo2cv8=
|
||||
github.com/kelseyhightower/envconfig v1.4.0/go.mod h1:cccZRl6mQpaq41TPp5QxidR+Sa3axMbJDNb//FQX6Gg=
|
||||
github.com/kirides/go-d3d v1.0.1 h1:ZDANfvo34vskBMET1uwUUMNw8545Kbe8qYSiRwlNIuA=
|
||||
github.com/kirides/go-d3d v1.0.1/go.mod h1:99AjD+5mRTFEnkpRWkwq8UYMQDljGIIvLn2NyRdVImY=
|
||||
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||
@@ -408,6 +414,7 @@ github.com/koron/go-ssdp v0.0.4/go.mod h1:oDXq+E5IL5q0U8uSBcoAXzTzInwy5lEgC91HoK
|
||||
github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8=
|
||||
github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
@@ -752,6 +759,7 @@ goauthentik.io/api/v3 v3.2023051.3/go.mod h1:nYECml4jGbp/541hj8GcylKQG1gVBsKppHy
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE=
|
||||
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
|
||||
|
||||
@@ -17,7 +17,7 @@ type store interface {
|
||||
UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error
|
||||
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
||||
GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
|
||||
GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error)
|
||||
GetProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error)
|
||||
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||
|
||||
@@ -57,7 +57,7 @@ func (m *mockStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockStore) GetActiveProxyClusters(_ context.Context, _ string) ([]proxy.Cluster, error) {
|
||||
func (m *mockStore) GetProxyClusters(_ context.Context, _ string) ([]proxy.Cluster, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockStore) CleanupStaleProxies(ctx context.Context, d time.Duration) error {
|
||||
|
||||
@@ -42,10 +42,35 @@ func (Proxy) TableName() string {
|
||||
return "proxies"
|
||||
}
|
||||
|
||||
// ClusterType is the source of a proxy cluster.
|
||||
type ClusterType string
|
||||
|
||||
const (
|
||||
// ClusterTypeAccount is a cluster operated by the account itself (BYOP) —
|
||||
// at least one proxy row in the cluster carries a non-NULL account_id.
|
||||
ClusterTypeAccount ClusterType = "account"
|
||||
// ClusterTypeShared is a cluster operated by NetBird and shared across
|
||||
// accounts — all proxy rows in the cluster have account_id IS NULL.
|
||||
ClusterTypeShared ClusterType = "shared"
|
||||
)
|
||||
|
||||
// Cluster represents a group of proxy nodes serving the same address.
|
||||
//
|
||||
// Online and ConnectedProxies derive from the same 2-min active window
|
||||
// the rest of the module uses, but Cluster rows are not gated on it —
|
||||
// the cluster listing surfaces offline clusters too so operators can
|
||||
// see and clean them up. The 1-hour heartbeat reaper still bounds the
|
||||
// table eventually.
|
||||
type Cluster struct {
|
||||
ID string
|
||||
Address string
|
||||
Type ClusterType
|
||||
Online bool
|
||||
ConnectedProxies int
|
||||
SelfHosted bool
|
||||
// Capability flags. *bool because nil means "no proxy reported a
|
||||
// capability for this cluster" — the dashboard renders these as
|
||||
// unknown rather than false.
|
||||
SupportsCustomPorts *bool
|
||||
RequireSubdomain *bool
|
||||
SupportsCrowdSec *bool
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
)
|
||||
|
||||
type Manager interface {
|
||||
GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error)
|
||||
GetClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error)
|
||||
DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error
|
||||
GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error)
|
||||
GetService(ctx context.Context, accountID, userID, serviceID string) (*Service, error)
|
||||
|
||||
@@ -65,20 +65,6 @@ func (mr *MockManagerMockRecorder) CreateServiceFromPeer(ctx, accountID, peerID,
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateServiceFromPeer", reflect.TypeOf((*MockManager)(nil).CreateServiceFromPeer), ctx, accountID, peerID, req)
|
||||
}
|
||||
|
||||
// DeleteAllServices mocks base method.
|
||||
func (m *MockManager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteAllServices", ctx, accountID, userID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteAllServices indicates an expected call of DeleteAllServices.
|
||||
func (mr *MockManagerMockRecorder) DeleteAllServices(ctx, accountID, userID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllServices", reflect.TypeOf((*MockManager)(nil).DeleteAllServices), ctx, accountID, userID)
|
||||
}
|
||||
|
||||
// DeleteAccountCluster mocks base method.
|
||||
func (m *MockManager) DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -93,6 +79,20 @@ func (mr *MockManagerMockRecorder) DeleteAccountCluster(ctx, accountID, userID,
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAccountCluster", reflect.TypeOf((*MockManager)(nil).DeleteAccountCluster), ctx, accountID, userID, clusterAddress)
|
||||
}
|
||||
|
||||
// DeleteAllServices mocks base method.
|
||||
func (m *MockManager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteAllServices", ctx, accountID, userID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteAllServices indicates an expected call of DeleteAllServices.
|
||||
func (mr *MockManagerMockRecorder) DeleteAllServices(ctx, accountID, userID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllServices", reflect.TypeOf((*MockManager)(nil).DeleteAllServices), ctx, accountID, userID)
|
||||
}
|
||||
|
||||
// DeleteService mocks base method.
|
||||
func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -122,21 +122,6 @@ func (mr *MockManagerMockRecorder) GetAccountServices(ctx, accountID interface{}
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountServices", reflect.TypeOf((*MockManager)(nil).GetAccountServices), ctx, accountID)
|
||||
}
|
||||
|
||||
// GetActiveClusters mocks base method.
|
||||
func (m *MockManager) GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetActiveClusters", ctx, accountID, userID)
|
||||
ret0, _ := ret[0].([]proxy.Cluster)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetActiveClusters indicates an expected call of GetActiveClusters.
|
||||
func (mr *MockManagerMockRecorder) GetActiveClusters(ctx, accountID, userID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusters", reflect.TypeOf((*MockManager)(nil).GetActiveClusters), ctx, accountID, userID)
|
||||
}
|
||||
|
||||
// GetAllServices mocks base method.
|
||||
func (m *MockManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -152,19 +137,19 @@ func (mr *MockManagerMockRecorder) GetAllServices(ctx, accountID, userID interfa
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllServices", reflect.TypeOf((*MockManager)(nil).GetAllServices), ctx, accountID, userID)
|
||||
}
|
||||
|
||||
// GetServiceByDomain mocks base method.
|
||||
func (m *MockManager) GetServiceByDomain(ctx context.Context, domain string) (*Service, error) {
|
||||
// GetClusters mocks base method.
|
||||
func (m *MockManager) GetClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, domain)
|
||||
ret0, _ := ret[0].(*Service)
|
||||
ret := m.ctrl.Call(m, "GetClusters", ctx, accountID, userID)
|
||||
ret0, _ := ret[0].([]proxy.Cluster)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetServiceByDomain indicates an expected call of GetServiceByDomain.
|
||||
func (mr *MockManagerMockRecorder) GetServiceByDomain(ctx, domain interface{}) *gomock.Call {
|
||||
// GetClusters indicates an expected call of GetClusters.
|
||||
func (mr *MockManagerMockRecorder) GetClusters(ctx, accountID, userID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockManager)(nil).GetServiceByDomain), ctx, domain)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusters", reflect.TypeOf((*MockManager)(nil).GetClusters), ctx, accountID, userID)
|
||||
}
|
||||
|
||||
// GetGlobalServices mocks base method.
|
||||
@@ -197,6 +182,21 @@ func (mr *MockManagerMockRecorder) GetService(ctx, accountID, userID, serviceID
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetService", reflect.TypeOf((*MockManager)(nil).GetService), ctx, accountID, userID, serviceID)
|
||||
}
|
||||
|
||||
// GetServiceByDomain mocks base method.
|
||||
func (m *MockManager) GetServiceByDomain(ctx context.Context, domain string) (*Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, domain)
|
||||
ret0, _ := ret[0].(*Service)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetServiceByDomain indicates an expected call of GetServiceByDomain.
|
||||
func (mr *MockManagerMockRecorder) GetServiceByDomain(ctx, domain interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByDomain", reflect.TypeOf((*MockManager)(nil).GetServiceByDomain), ctx, domain)
|
||||
}
|
||||
|
||||
// GetServiceByID mocks base method.
|
||||
func (m *MockManager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -187,7 +187,7 @@ func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
clusters, err := h.manager.GetActiveClusters(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
clusters, err := h.manager.GetClusters(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -196,10 +196,14 @@ func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) {
|
||||
apiClusters := make([]api.ProxyCluster, 0, len(clusters))
|
||||
for _, c := range clusters {
|
||||
apiClusters = append(apiClusters, api.ProxyCluster{
|
||||
Id: c.ID,
|
||||
Address: c.Address,
|
||||
ConnectedProxies: c.ConnectedProxies,
|
||||
SelfHosted: c.SelfHosted,
|
||||
Id: c.ID,
|
||||
Address: c.Address,
|
||||
Type: api.ProxyClusterType(c.Type),
|
||||
Online: c.Online,
|
||||
ConnectedProxies: c.ConnectedProxies,
|
||||
SupportsCustomPorts: c.SupportsCustomPorts,
|
||||
RequireSubdomain: c.RequireSubdomain,
|
||||
SupportsCrowdsec: c.SupportsCrowdSec,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -81,6 +81,7 @@ type ClusterDeriver interface {
|
||||
type CapabilityProvider interface {
|
||||
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
@@ -112,8 +113,12 @@ func (m *Manager) StartExposeReaper(ctx context.Context) {
|
||||
m.exposeReaper.StartExposeReaper(ctx)
|
||||
}
|
||||
|
||||
// GetActiveClusters returns all active proxy clusters with their connected proxy count.
|
||||
func (m *Manager) GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) {
|
||||
// GetClusters returns every proxy cluster visible to the account
|
||||
// (shared + its own BYOP), regardless of whether any proxy in the
|
||||
// cluster is currently heartbeating. Each cluster is enriched with the
|
||||
// capability flags reported by its active proxies so the dashboard can
|
||||
// render feature support without a second round-trip.
|
||||
func (m *Manager) GetClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
@@ -122,7 +127,18 @@ func (m *Manager) GetActiveClusters(ctx context.Context, accountID, userID strin
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetActiveProxyClusters(ctx, accountID)
|
||||
clusters, err := m.store.GetProxyClusters(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := range clusters {
|
||||
clusters[i].SupportsCustomPorts = m.capabilities.ClusterSupportsCustomPorts(ctx, clusters[i].Address)
|
||||
clusters[i].RequireSubdomain = m.capabilities.ClusterRequireSubdomain(ctx, clusters[i].Address)
|
||||
clusters[i].SupportsCrowdSec = m.capabilities.ClusterSupportsCrowdSec(ctx, clusters[i].Address)
|
||||
}
|
||||
|
||||
return clusters, nil
|
||||
}
|
||||
|
||||
// DeleteAccountCluster removes all proxy registrations for the given cluster address
|
||||
|
||||
@@ -2,6 +2,7 @@ package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
@@ -98,10 +99,7 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set
|
||||
|
||||
sshConfig := &proto.SSHConfig{
|
||||
SshEnabled: peer.SSHEnabled || enableSSH,
|
||||
}
|
||||
|
||||
if sshConfig.SshEnabled {
|
||||
sshConfig.JwtConfig = buildJWTConfig(httpConfig, deviceFlowConfig)
|
||||
JwtConfig: buildJWTConfig(httpConfig, deviceFlowConfig),
|
||||
}
|
||||
|
||||
peerConfig := &proto.PeerConfig{
|
||||
@@ -134,13 +132,14 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
||||
includeIPv6 := peer.SupportsIPv6() && peer.IPv6.IsValid()
|
||||
useSourcePrefixes := peer.SupportsSourcePrefixes()
|
||||
|
||||
peerConfig := toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH)
|
||||
response := &proto.SyncResponse{
|
||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
|
||||
PeerConfig: peerConfig,
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: networkMap.Network.CurrentSerial(),
|
||||
Routes: toProtocolRoutes(networkMap.Routes),
|
||||
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
|
||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
|
||||
PeerConfig: peerConfig,
|
||||
},
|
||||
Checks: toProtocolChecks(ctx, checks),
|
||||
}
|
||||
@@ -149,8 +148,6 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
||||
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings)
|
||||
response.NetbirdConfig = extendedConfig
|
||||
|
||||
response.NetworkMap.PeerConfig = response.PeerConfig
|
||||
|
||||
remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
|
||||
remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName, includeIPv6)
|
||||
response.RemotePeers = remotePeers
|
||||
@@ -176,18 +173,59 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
||||
response.NetworkMap.ForwardingRules = forwardingRules
|
||||
}
|
||||
|
||||
userIDClaim := auth.DefaultUserIDClaim
|
||||
if httpConfig != nil && httpConfig.AuthUserIDClaim != "" {
|
||||
userIDClaim = httpConfig.AuthUserIDClaim
|
||||
}
|
||||
|
||||
if networkMap.AuthorizedUsers != nil {
|
||||
hashedUsers, machineUsers := buildAuthorizedUsersProto(ctx, networkMap.AuthorizedUsers)
|
||||
userIDClaim := auth.DefaultUserIDClaim
|
||||
if httpConfig != nil && httpConfig.AuthUserIDClaim != "" {
|
||||
userIDClaim = httpConfig.AuthUserIDClaim
|
||||
}
|
||||
response.NetworkMap.SshAuth = &proto.SSHAuth{AuthorizedUsers: hashedUsers, MachineUsers: machineUsers, UserIDClaim: userIDClaim}
|
||||
}
|
||||
|
||||
if networkMap.VNCAuthorizedUsers != nil {
|
||||
hashedUsers, machineUsers := buildAuthorizedUsersProto(ctx, networkMap.VNCAuthorizedUsers)
|
||||
response.NetworkMap.VncAuth = &proto.VNCAuth{
|
||||
AuthorizedUsers: hashedUsers,
|
||||
MachineUsers: machineUsers,
|
||||
SessionPubKeys: buildSessionPubKeysProto(ctx, networkMap.VNCSessionPubKeys),
|
||||
}
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
// buildSessionPubKeysProto decodes base64 X25519 session pubkeys and
|
||||
// hashes the user IDs they belong to, emitting the proto entries the
|
||||
// daemon's authorizer indexes by pubkey.
|
||||
func buildSessionPubKeysProto(ctx context.Context, in []types.VNCSessionPubKey) []*proto.SessionPubKey {
|
||||
if len(in) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*proto.SessionPubKey, 0, len(in))
|
||||
for _, e := range in {
|
||||
pub, err := base64.StdEncoding.DecodeString(e.PubKey)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("decode VNC session pubkey: %v", err)
|
||||
continue
|
||||
}
|
||||
if len(pub) != 32 {
|
||||
log.WithContext(ctx).Warnf("VNC session pubkey wrong length: %d", len(pub))
|
||||
continue
|
||||
}
|
||||
hash, err := sshauth.HashUserID(e.UserID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("hash VNC session user id: %v", err)
|
||||
continue
|
||||
}
|
||||
out = append(out, &proto.SessionPubKey{
|
||||
PubKey: pub,
|
||||
UserIdHash: hash[:],
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func buildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]map[string]struct{}) ([][]byte, map[string]*proto.MachineUserIndexes) {
|
||||
userIDToIndex := make(map[string]uint32)
|
||||
var hashedUsers [][]byte
|
||||
|
||||
@@ -109,7 +109,7 @@ func (m *mockReverseProxyManager) GetServiceByDomain(_ context.Context, domain s
|
||||
return nil, errors.New("service not found for domain: " + domain)
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
|
||||
func (m *mockReverseProxyManager) GetClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user