Address CodeRabbit feedback on VNC server

This commit is contained in:
Viktor Liu
2026-05-20 17:16:45 +02:00
parent 17359cdc1e
commit 640a267556
14 changed files with 187 additions and 69 deletions

View File

@@ -10,6 +10,7 @@ import (
"net"
"os"
"os/exec"
"path/filepath"
"strconv"
"sync"
"syscall"
@@ -95,10 +96,13 @@ func (m *darwinAgentManager) ensure(ctx context.Context) (string, error) {
return m.authToken, nil
}
m.killLocked()
// Reap any stray external vnc-agent so the new token is the only one
// the freshly spawned agent will accept on the loopback port.
killAllVNCAgents()
token := generateAuthToken()
if token == "" {
return "", fmt.Errorf("generate agent auth token")
token, err := generateAuthToken()
if err != nil {
return "", fmt.Errorf("generate agent auth token: %w", err)
}
if err := spawnAgentForUser(consoleUID, m.port, token); err != nil {
return "", err
@@ -248,13 +252,16 @@ func killAllVNCAgents() {
}
}
// vncAgentPIDs returns the pids of every process whose argv contains
// "vnc-agent". Skips pid 0 and 1 defensively.
// vncAgentPIDs returns the pids of vnc-agent subprocesses spawned from
// this binary. Matches on (argv[0] basename == our own basename) AND
// argv contains the "vnc-agent" subcommand. Skips pid 0 and 1 defensively.
func vncAgentPIDs() ([]int, error) {
procs, err := unix.SysctlKinfoProcSlice("kern.proc.all")
if err != nil {
return nil, fmt.Errorf("sysctl kern.proc.all: %w", err)
}
ownExe, _ := os.Executable()
ownBase := filepath.Base(ownExe)
var out []int
for i := range procs {
pid := int(procs[i].Proc.P_pid)
@@ -262,7 +269,7 @@ func vncAgentPIDs() ([]int, error) {
continue
}
argv, err := procArgv(pid)
if err != nil || !argvIsVNCAgent(argv) {
if err != nil || !argvIsVNCAgent(argv, ownBase) {
continue
}
out = append(out, pid)
@@ -305,8 +312,17 @@ func procArgv(pid int) ([]string, error) {
return args, nil
}
func argvIsVNCAgent(argv []string) bool {
for _, a := range argv {
// argvIsVNCAgent reports whether argv belongs to a vnc-agent subprocess
// spawned from our binary. Requires argv[0]'s basename to match ownBase
// and the "vnc-agent" subcommand to appear among the positional args.
func argvIsVNCAgent(argv []string, ownBase string) bool {
if len(argv) < 2 || ownBase == "" {
return false
}
if filepath.Base(argv[0]) != ownBase {
return false
}
for _, a := range argv[1:] {
if a == "vnc-agent" {
return true
}

View File

@@ -36,15 +36,13 @@ const (
// generateAuthToken returns a fresh hex-encoded random token for one
// daemon→agent session. The daemon hands this to the spawned agent
// out-of-band (env var on Windows) and verifies it on every connection
// the agent accepts. Returns the empty string on a randomness failure;
// callers should treat that as an error.
func generateAuthToken() string {
// the agent accepts.
func generateAuthToken() (string, error) {
b := make([]byte, agentTokenLen)
if _, err := crand.Read(b); err != nil {
log.Warnf("generate agent auth token: %v", err)
return ""
return "", fmt.Errorf("read random: %w", err)
}
return hex.EncodeToString(b)
return hex.EncodeToString(b), nil
}
// proxyToAgent dials the per-session agent on TCP loopback, writes the
@@ -63,7 +61,11 @@ func proxyToAgent(client net.Conn, port uint16, authToken string) {
}
defer agentConn.Close()
tokenBytes, _ := hex.DecodeString(authToken)
tokenBytes, err := hex.DecodeString(authToken)
if err != nil || len(tokenBytes) != agentTokenLen {
log.Warnf("invalid auth token (len=%d): %v", len(tokenBytes), err)
return
}
if _, err := agentConn.Write(tokenBytes); err != nil {
log.Warnf("send auth token to agent: %v", err)
return

View File

@@ -626,7 +626,12 @@ func (m *sessionManager) maybeSpawnAgent(sid uint32) bool {
if !m.everSpawned {
reapOrphanOnPort(m.port)
}
m.authToken = generateAuthToken()
token, err := generateAuthToken()
if err != nil {
log.Warnf("generate agent auth token: %v", err)
return true
}
m.authToken = token
h, err := spawnAgentInSession(sid, m.port, m.authToken, m.jobHandle)
if err != nil {
m.authToken = ""
@@ -657,11 +662,11 @@ func (m *sessionManager) killAgent() {
}
// relogAgentOutput reads log lines from the agent's stderr pipe and
// relogs them with the service's formatter.
// relogs them with the service's formatter. The *os.File owns the
// underlying handle, so closing it suffices.
func relogAgentOutput(pipe windows.Handle) {
defer func() { _ = windows.CloseHandle(pipe) }()
f := os.NewFile(uintptr(pipe), "vnc-agent-stderr")
defer f.Close()
defer func() { _ = f.Close() }()
relogAgentStream(f)
}

View File

@@ -7,6 +7,7 @@ import (
"image"
"os"
"os/exec"
"strconv"
"strings"
"sync"
"sync/atomic"
@@ -94,24 +95,35 @@ func detectX11FromSockets() bool {
return false
}
// Find the lowest display number.
// Pick the lowest numeric display rather than the lexically first
// entry, so X10 doesn't win over X2.
minDisplay := -1
for _, e := range entries {
name := e.Name()
if len(name) < 2 || name[0] != 'X' {
continue
}
display := ":" + name[1:]
os.Setenv("DISPLAY", display)
auth := findXorgAuthFromPS()
if auth != "" {
os.Setenv("XAUTHORITY", auth)
log.Infof("auto-detected DISPLAY=%s (from socket) XAUTHORITY=%s (from ps)", display, auth)
} else {
log.Infof("auto-detected DISPLAY=%s (from socket)", display)
n, err := strconv.Atoi(name[1:])
if err != nil {
continue
}
if minDisplay < 0 || n < minDisplay {
minDisplay = n
}
return true
}
return false
if minDisplay < 0 {
return false
}
display := ":" + strconv.Itoa(minDisplay)
os.Setenv("DISPLAY", display)
auth := findXorgAuthFromPS()
if auth != "" {
os.Setenv("XAUTHORITY", auth)
log.Infof("auto-detected DISPLAY=%s (from socket) XAUTHORITY=%s (from ps)", display, auth)
} else {
log.Infof("auto-detected DISPLAY=%s (from socket)", display)
}
return true
}
// findXorgAuthFromPS runs ps to find Xorg and extract its -auth argument.

View File

@@ -173,10 +173,10 @@ func decodeCursor(hCur windows.Handle) (*image.RGBA, int, int, error) {
}
defer func() {
if info.HbmMask != 0 {
procDeleteObject.Call(uintptr(info.HbmMask))
_, _, _ = procDeleteObject.Call(uintptr(info.HbmMask))
}
if info.HbmColor != 0 {
procDeleteObject.Call(uintptr(info.HbmColor))
_, _, _ = procDeleteObject.Call(uintptr(info.HbmColor))
}
}()
hotX, hotY := int(info.XHotspot), int(info.YHotspot)
@@ -212,12 +212,12 @@ func dibCopy(hbm windows.Handle, w, h int32) ([]byte, error) {
if hdcScreen == 0 {
return nil, fmt.Errorf("GetDC: failed")
}
defer procReleaseDC.Call(0, hdcScreen)
defer func() { _, _, _ = procReleaseDC.Call(0, hdcScreen) }()
hdcMem, _, _ := procCreateCompatDC.Call(hdcScreen)
if hdcMem == 0 {
return nil, fmt.Errorf("CreateCompatibleDC: failed")
}
defer procDeleteDC.Call(hdcMem)
defer func() { _, _, _ = procDeleteDC.Call(hdcMem) }()
var bih winBitmapInfoHeader
bih.BiSize = dibSectionBytes
@@ -268,6 +268,16 @@ func decodeColorCursor(hbmColor, hbmMask windows.Handle) (*image.RGBA, error) {
g := color[si+1]
r := color[si+2]
a := pixelAlpha(color[si+3], si, mask, hasAlpha)
// Premultiply so the shared compositor can use the same
// formula on every platform (X11 XFixes and macOS CG return
// premultiplied bytes natively).
if a != 255 && a != 0 {
r = byte(uint32(r) * uint32(a) / 255)
g = byte(uint32(g) * uint32(a) / 255)
b = byte(uint32(b) * uint32(a) / 255)
} else if a == 0 {
r, g, b = 0, 0, 0
}
img.Pix[si+0] = r
img.Pix[si+1] = g
img.Pix[si+2] = b

View File

@@ -91,8 +91,8 @@ func buildExtClipRequest(formats uint32) []byte {
// per the extension spec. Rejects oversized input so a caller bug can't
// produce a payload larger than the size advertised in our Caps.
func buildExtClipProvideText(text string) ([]byte, error) {
if len(text) > extClipMaxText {
return nil, fmt.Errorf("clipboard text exceeds extClipMaxText (%d > %d)", len(text), extClipMaxText)
if len(text)+1 > extClipMaxText {
return nil, fmt.Errorf("clipboard text exceeds extClipMaxText (%d > %d)", len(text)+1, extClipMaxText)
}
body := make([]byte, 0, 4+len(text)+1)
var lenBuf [4]byte

View File

@@ -110,7 +110,7 @@ func NewUInputInjector(w, h int) (*UInputInjector, error) {
return nil, fmt.Errorf("UI_SET_KEYBIT %d: %w", key, err)
}
}
for _, btn := range []uint16{btnLeft, btnRight, btnMiddle} {
for _, btn := range []uint16{btnLeft, btnRight, btnMiddle, btnSide, btnExtra} {
if err := setBit(fd, uiSetKeyBit, uint32(btn)); err != nil {
unix.Close(fd)
return nil, fmt.Errorf("UI_SET_KEYBIT btn %d: %w", btn, err)

View File

@@ -134,8 +134,15 @@ type WindowsInputInjector struct {
closed chan struct{}
closeOnce sync.Once
prevButtonMask uint16
ctrlDown bool
altDown bool
// lastQueuedButtonMask is the most recent buttonMask submitted to ch
// by InjectPointer. Compared against the incoming sample to decide
// whether the new event is move-only (lossy enqueue) or carries a
// button/wheel transition (reliable enqueue).
lastQueuedButtonMask uint16
lastQueuedMaskValid bool
queueMu sync.Mutex
ctrlDown bool
altDown bool
}
// NewWindowsInputInjector creates a desktop-aware input injector.
@@ -171,6 +178,21 @@ func (w *WindowsInputInjector) tryEnqueue(cmd inputCmd) {
}
}
// enqueueReliable posts a command and blocks until it's accepted or the
// injector closes. Used for edge-triggered events (button/wheel) where a
// drop would desynchronize prevButtonMask in dispatch().
func (w *WindowsInputInjector) enqueueReliable(cmd inputCmd) {
select {
case <-w.closed:
return
default:
}
select {
case w.ch <- cmd:
case <-w.closed:
}
}
func (w *WindowsInputInjector) loop() {
runtime.LockOSThread()
@@ -223,11 +245,22 @@ func (w *WindowsInputInjector) InjectKeyScancode(scancode uint32, keysym uint32,
}
// InjectPointer queues a pointer event for injection on the input desktop
// thread. Pointer events coalesce: when the channel is full (slow desktop
// switch, hung SendInput), drop the new sample so the read loop never
// blocks. The next mouse event carries fresher position anyway.
// thread. Move-only updates use lossy enqueue (next sample carries fresher
// position anyway), but any sample whose buttonMask differs from the last
// queued mask is enqueued reliably so wheel ticks and button transitions
// can't be dropped under backpressure.
func (w *WindowsInputInjector) InjectPointer(buttonMask uint16, x, y, serverW, serverH int) {
w.tryEnqueue(inputCmd{buttonMask: buttonMask, x: x, y: y, serverW: serverW, serverH: serverH})
cmd := inputCmd{buttonMask: buttonMask, x: x, y: y, serverW: serverW, serverH: serverH}
w.queueMu.Lock()
transition := !w.lastQueuedMaskValid || w.lastQueuedButtonMask != buttonMask
w.lastQueuedButtonMask = buttonMask
w.lastQueuedMaskValid = true
w.queueMu.Unlock()
if transition {
w.enqueueReliable(cmd)
return
}
w.tryEnqueue(cmd)
}
// doInjectKeyScancode injects a key event using the QEMU scancode directly,

View File

@@ -146,7 +146,7 @@ func (x *X11InputInjector) InjectPointer(buttonMask uint16, px, py, serverW, ser
for _, b := range buttons {
pressed := buttonMask&b.rfbBit != 0
wasPressed := x.lastButtons&b.rfbBit != 0
if b.x11Btn >= 4 {
if b.x11Btn == 4 || b.x11Btn == 5 {
// Scroll: send press+release on each scroll event.
if pressed {
xtest.FakeInput(x.conn, xproto.ButtonPress, b.x11Btn, 0, x.root, 0, 0, 0)

View File

@@ -182,6 +182,12 @@ type Server struct {
sessionSeq uint64
sessions map[uint64]ActiveSessionInfo
sessionConns map[uint64]net.Conn
// acceptedConns tracks every connection between Accept() and handler
// return, including connections still in the connection-header /
// handshake phase that have not yet been registered in sessionConns.
// closeActiveSessions iterates this set so Stop() can interrupt
// handshaking peers, not just post-handshake sessions.
acceptedConns map[net.Conn]struct{}
// sessionRecorder, when non-nil, receives a SessionTick periodically
// during each VNC session and on session close. The engine wires
@@ -219,12 +225,13 @@ type virtualSessionManager interface {
// header; the protocol-level VNC password scheme is not supported.
func New(capturer ScreenCapturer, injector InputInjector) *Server {
return &Server{
capturer: capturer,
injector: injector,
authorizer: sshauth.NewAuthorizer(),
log: log.WithField("component", "vnc-server"),
sessions: make(map[uint64]ActiveSessionInfo),
sessionConns: make(map[uint64]net.Conn),
capturer: capturer,
injector: injector,
authorizer: sshauth.NewAuthorizer(),
log: log.WithField("component", "vnc-server"),
sessions: make(map[uint64]ActiveSessionInfo),
sessionConns: make(map[uint64]net.Conn),
acceptedConns: make(map[net.Conn]struct{}),
}
}
@@ -256,15 +263,15 @@ func (s *Server) removeSession(id uint64) {
delete(s.sessionConns, id)
}
// closeActiveSessions closes every active session's connection so the
// per-session serve goroutines unblock from their Read loops and exit.
// Called from Stop to make sure clients see an immediate disconnect when
// the server is brought down, instead of waiting for the OS to reclaim
// the sockets after process exit.
// closeActiveSessions closes every accepted connection so per-connection
// goroutines unblock from their Read loops and exit. Called from Stop to
// make sure clients see an immediate disconnect when the server is brought
// down. Iterates acceptedConns so handshaking connections that have not
// yet registered in sessionConns are also closed.
func (s *Server) closeActiveSessions() {
s.sessionsMu.Lock()
conns := make([]net.Conn, 0, len(s.sessionConns))
for _, c := range s.sessionConns {
conns := make([]net.Conn, 0, len(s.acceptedConns))
for c := range s.acceptedConns {
conns = append(conns, c)
}
s.sessionsMu.Unlock()
@@ -273,6 +280,21 @@ func (s *Server) closeActiveSessions() {
}
}
// trackConn registers a freshly accepted connection so Stop() can close
// it even before the session is registered in sessionConns.
func (s *Server) trackConn(c net.Conn) {
s.sessionsMu.Lock()
s.acceptedConns[c] = struct{}{}
s.sessionsMu.Unlock()
}
// untrackConn forgets a connection once its handler is returning.
func (s *Server) untrackConn(c net.Conn) {
s.sessionsMu.Lock()
delete(s.acceptedConns, c)
s.sessionsMu.Unlock()
}
// SetServiceMode enables proxy-to-agent mode for Windows service operation.
func (s *Server) SetServiceMode(enabled bool) {
s.serviceMode = enabled
@@ -442,7 +464,11 @@ func (s *Server) acceptLoop() {
}
enableTCPKeepAlive(conn, s.log)
go s.handleConnection(conn)
s.trackConn(conn)
go func(c net.Conn) {
defer s.untrackConn(c)
s.handleConnection(c)
}(conn)
}
}

View File

@@ -49,7 +49,11 @@ func (s *Server) serviceAcceptLoop() {
enableTCPKeepAlive(conn, s.log)
conn = newMetricsConn(conn, s.sessionRecorder)
go s.handleServiceConnectionDarwin(conn, mgr)
s.trackConn(conn)
go func(c net.Conn) {
defer s.untrackConn(c)
s.handleServiceConnectionDarwin(c, mgr)
}(conn)
}
}

View File

@@ -257,7 +257,11 @@ func (s *Server) serviceAcceptLoop() {
enableTCPKeepAlive(conn, s.log)
conn = newMetricsConn(conn, s.sessionRecorder)
go s.handleServiceConnection(conn, sm)
s.trackConn(conn)
go func(c net.Conn) {
defer s.untrackConn(c)
s.handleServiceConnection(c, sm)
}(conn)
}
}

View File

@@ -171,7 +171,11 @@ func (s *session) applyBackpressure() float64 {
base := jpegQualityForLevel(tight.qualityLevel)
if base == 0 {
base = tightJPEGQuality
// No client-negotiated quality; let tightQualityFor pick the
// area-based default and skip backpressure adjustments that
// would otherwise lock in a wrong starting point.
tight.jpegQualityOverride = 0
return frac
}
q := base
if frac > backpressureRampStart {

View File

@@ -59,11 +59,12 @@ func (s *session) maybeCompositeCursor(img *image.RGBA) {
compositeCursor(img, cursorImg, posX-hotX, posY-hotY)
}
// compositeCursor alpha-blends sprite onto frame at (dstX, dstY) using
// straight (non-premultiplied) alpha. Out-of-bounds destinations are
// clipped. Frames captured by our X11/Windows/macOS paths all advertise
// RGBA with a 255-only alpha channel, so the result keeps the framebuffer
// invariant ("opaque pixels everywhere") that the encoder depends on.
// compositeCursor alpha-blends sprite onto frame at (dstX, dstY).
// sprite is assumed to use premultiplied RGBA, which is what every
// cursorSource implementation in this package produces (X11 XFixes and
// macOS CG return premultiplied bytes natively; the Windows path
// premultiplies during decodeColorCursor). Out-of-bounds destinations are
// clipped.
func compositeCursor(frame, sprite *image.RGBA, dstX, dstY int) {
fw, fh := frame.Rect.Dx(), frame.Rect.Dy()
sw, sh := sprite.Rect.Dx(), sprite.Rect.Dy()
@@ -109,10 +110,11 @@ func compositeCursor(frame, sprite *image.RGBA, dstX, dstY int) {
frame.Pix[fbOff+2] = sprite.Pix[sOff+2]
continue
}
// Premultiplied compositing: dst = src + dst*(1-srcA).
inv := 255 - a
frame.Pix[fbOff+0] = byte((uint32(sprite.Pix[sOff+0])*a + uint32(frame.Pix[fbOff+0])*inv) / 255)
frame.Pix[fbOff+1] = byte((uint32(sprite.Pix[sOff+1])*a + uint32(frame.Pix[fbOff+1])*inv) / 255)
frame.Pix[fbOff+2] = byte((uint32(sprite.Pix[sOff+2])*a + uint32(frame.Pix[fbOff+2])*inv) / 255)
frame.Pix[fbOff+0] = sprite.Pix[sOff+0] + byte((uint32(frame.Pix[fbOff+0])*inv)/255)
frame.Pix[fbOff+1] = sprite.Pix[sOff+1] + byte((uint32(frame.Pix[fbOff+1])*inv)/255)
frame.Pix[fbOff+2] = sprite.Pix[sOff+2] + byte((uint32(frame.Pix[fbOff+2])*inv)/255)
}
}
}