Files
netbird/client/vnc/server/session.go

447 lines
13 KiB
Go

//go:build !js && !ios && !android
package server
import (
"encoding/binary"
"fmt"
"image"
"io"
"net"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
const (
	// readDeadline is how long messageLoop waits for each complete client
	// message (type byte plus payload) before giving up on the connection.
	readDeadline = time.Minute

	// maxCutTextBytes is the clipboard (cut-text) payload cap, 1 MiB.
	// NOTE(review): enforced by code outside this view — confirm.
	maxCutTextBytes = 1024 * 1024
)
// tileSize is the edge length, in pixels, of the square tiles used for
// dirty-rectangle detection (also handed to newCopyRectDetector).
const tileSize = 64
// fullFramePromoteNum / fullFramePromoteDen form the fraction of the screen
// (60/100) above which a dirty update is promoted to a single full-frame
// encode. Past that crossover, which benchmarks placed near 60% at 1080p, one
// zlib rectangle encodes faster than many per-tile rectangles while costing
// roughly the same bytes on the wire. The per-tile path keeps restarting zlib
// dictionaries and re-emitting rectangle headers.
const (
	fullFramePromoteNum, fullFramePromoteDen = 60, 100
)
// session is the per-connection state of one RFB (VNC) client: the socket,
// the screen/input backends, the negotiated pixel format and encodings, and
// the encoder goroutine's frame buffers. serve drives its lifecycle.
type session struct {
// conn is the client connection; serve closes it when it returns.
conn net.Conn
// capturer is the screen source.
// NOTE(review): consumed by the encoder outside this view — confirm.
capturer ScreenCapturer
// injector replays client input (InjectKey, InjectKeyScancode,
// InjectPointer).
injector InputInjector
// serverW/serverH are the framebuffer size advertised in ServerInit; they
// are also passed to InjectPointer alongside the client's x/y.
serverW int
serverH int
// desktopName is sent in ServerInit ("NetBird VNC" when empty) and is
// replaced by SendDesktopName.
// NOTE(review): SendDesktopName assigns this without holding any lock while
// sendServerInit reads it; confirm the two cannot run concurrently.
desktopName string
log *log.Entry
// writeMu guards writes to conn; SendDesktopName holds it across its
// update (presumably every other writer does too — confirm).
writeMu sync.Mutex
// encMu guards the negotiated pixel format and encoding state below.
// messageLoop writes these on SetPixelFormat/SetEncodings, which RFB
// clients may send at any time after the handshake, while encoderLoop
// reads them on every frame.
encMu sync.RWMutex
pf clientPixelFormat
useTight bool
useCopyRect bool
tight *tightState
copyRectDet *copyRectDetector
// Pseudo-encodings the client advertised support for. Updated under
// encMu by handleSetEncodings and read by the encoder goroutine.
clientSupportsDesktopSize bool
clientSupportsExtendedDesktopSize bool
clientSupportsDesktopName bool
clientSupportsLastRect bool
clientSupportsQEMUKey bool
clientSupportsExtClipboard bool
// extClipCapsSent records that the ExtendedClipboard caps message has
// already gone out, so handleSetEncodings sends it at most once.
extClipCapsSent bool
// clientJPEGQuality and clientZlibLevel hold the 0..9 levels the client
// advertised via the QualityLevel / CompressLevel pseudo-encodings, or
// -1 when the client has not expressed a preference. Applied to the
// tight encoder state after every SetEncodings.
clientJPEGQuality int
clientZlibLevel int
// prevFrame, curFrame and idleFrames live on the encoder goroutine and
// must not be touched elsewhere. curFrame holds a session-owned copy of
// the capturer's latest frame so the encoder works on a stable buffer
// even when the capturer double-buffers and recycles memory underneath.
prevFrame *image.RGBA
curFrame *image.RGBA
idleFrames int
// captureErrLast throttles "capture (transient)" logs while the
// capturer is in a sustained failure state (e.g. X server died but a
// noVNC tab is still open). Owned by the encoder goroutine.
captureErrLast time.Time
captureErrSeen bool
// encodeCh carries framebuffer-update requests from the read loop to the
// encoder goroutine. Buffered size 1: RFB clients have one outstanding
// request at a time, so a new request always replaces any pending one.
encodeCh chan fbRequest
}
// fbRequest is one FramebufferUpdateRequest as passed from messageLoop to
// the encoder goroutine over encodeCh. Only the incremental flag is carried;
// the region the client asked for is not forwarded.
type fbRequest struct {
// incremental mirrors the request's incremental flag (RFB: only changed
// areas are wanted, as opposed to a full refresh).
incremental bool
}
// addr renders the client's remote address for use in log lines.
func (s *session) addr() string {
	remote := s.conn.RemoteAddr()
	return remote.String()
}
// serve runs the full RFB session lifecycle. It performs the handshake,
// starts the clipboard poller and the encoder goroutine, blocks in
// messageLoop until the client goes away or misbehaves, and then tears
// everything down. conn is always closed before serve returns.
func (s *session) serve() {
	defer s.conn.Close()
	s.pf = defaultClientPixelFormat()
	s.clientJPEGQuality = -1
	s.clientZlibLevel = -1
	s.encodeCh = make(chan fbRequest, 1)
	if err := s.handshake(); err != nil {
		s.log.Warnf("handshake with %s: %v", s.addr(), err)
		return
	}
	s.log.Infof("client connected: %s", s.addr())
	done := make(chan struct{})
	defer close(done)
	go s.clipboardPoll(done)
	encoderDone := make(chan struct{})
	go s.encoderLoop(encoderDone)
	defer func() {
		// messageLoop clears the conn deadline after each handled message.
		// When it exits on a handler error, nothing bounds a write that is
		// still in flight. encoderLoop (defined elsewhere) presumably writes
		// framebuffer updates to conn. If it is stuck in such a write because
		// the client stopped reading, waiting on encoderDone would hang
		// teardown indefinitely. Expiring the deadline first makes any pending
		// or later I/O on conn fail fast; the session is ending anyway.
		_ = s.conn.SetDeadline(time.Now())
		close(s.encodeCh)
		<-encoderDone
	}()
	if err := s.messageLoop(); err != nil && err != io.EOF {
		s.log.Warnf("client %s disconnected: %v", s.addr(), err)
	} else {
		s.log.Infof("client disconnected: %s", s.addr())
	}
}
// handshake performs the RFB opening sequence: version exchange, security
// negotiation and ClientInit, finishing by sending ServerInit. The client's
// version string and ClientInit byte are consumed but their values are not
// inspected.
func (s *session) handshake() error {
	// ProtocolVersion: announce ours, then consume the client's 12-byte reply.
	if _, err := io.WriteString(s.conn, rfbProtocolVersion); err != nil {
		return fmt.Errorf("send version: %w", err)
	}
	var version [12]byte
	if _, err := io.ReadFull(s.conn, version[:]); err != nil {
		return fmt.Errorf("read client version: %w", err)
	}

	// Security: offer our types, then validate the one the client picked.
	if err := s.sendSecurityTypes(); err != nil {
		return err
	}
	var oneByte [1]byte
	if _, err := io.ReadFull(s.conn, oneByte[:]); err != nil {
		return fmt.Errorf("read security type: %w", err)
	}
	if err := s.handleSecurity(oneByte[0]); err != nil {
		return err
	}

	// ClientInit is a single byte; its value is ignored.
	if _, err := io.ReadFull(s.conn, oneByte[:]); err != nil {
		return fmt.Errorf("read ClientInit: %w", err)
	}
	return s.sendServerInit()
}
// sendSecurityTypes offers exactly one security type, secNone. The RFB
// password scheme is not used. Authentication and access control are
// layered on top by the dashboard JWT exchange once the RFB handshake has
// completed.
func (s *session) sendSecurityTypes() error {
	offered := []byte{secNone}
	// Wire form: number-of-security-types, then the type bytes.
	msg := append([]byte{byte(len(offered))}, offered...)
	_, err := s.conn.Write(msg)
	return err
}
// handleSecurity accepts only secNone. On success it writes a 4-byte
// big-endian zero SecurityResult (OK). Any other type is rejected with an
// error and nothing is written.
func (s *session) handleSecurity(secType byte) error {
	if secType == secNone {
		_, err := s.conn.Write([]byte{0, 0, 0, 0})
		return err
	}
	return fmt.Errorf("unsupported security type: %d", secType)
}
// sendServerInit writes the ServerInit message in a single Write. The layout
// is width(2) height(2) server-pixel-format name-length(4) name, all
// big-endian. An empty desktopName is advertised as "NetBird VNC".
func (s *session) sendServerInit() error {
	name := s.desktopName
	if name == "" {
		name = "NetBird VNC"
	}
	msg := make([]byte, 0, 2+2+16+4+len(name))
	msg = binary.BigEndian.AppendUint16(msg, uint16(s.serverW))
	msg = binary.BigEndian.AppendUint16(msg, uint16(s.serverH))
	msg = append(msg, serverPixelFormat[:]...)
	msg = binary.BigEndian.AppendUint32(msg, uint32(len(name)))
	msg = append(msg, name...)
	_, err := s.conn.Write(msg)
	return err
}
// messageLoop reads client-to-server messages one at a time and dispatches
// each to its handler. It returns the first read, handler or protocol error.
// An unknown message type is fatal because its length cannot be known. The
// connection deadline is armed before each type byte and lifted only after
// the handler has consumed the payload, so every message must arrive in
// full within readDeadline.
func (s *session) messageLoop() error {
	var typ [1]byte
	for {
		if err := s.conn.SetDeadline(time.Now().Add(readDeadline)); err != nil {
			return fmt.Errorf("set deadline: %w", err)
		}
		if _, err := io.ReadFull(s.conn, typ[:]); err != nil {
			return err
		}

		var handleErr error
		switch t := typ[0]; t {
		case clientSetPixelFormat:
			handleErr = s.handleSetPixelFormat()
		case clientSetEncodings:
			handleErr = s.handleSetEncodings()
		case clientFramebufferUpdateRequest:
			handleErr = s.handleFBUpdateRequest()
		case clientKeyEvent:
			handleErr = s.handleKeyEvent()
		case clientPointerEvent:
			handleErr = s.handlePointerEvent()
		case clientCutText:
			handleErr = s.handleCutText()
		case clientQEMUMessage:
			handleErr = s.handleQEMUMessage()
		case clientNetbirdTypeText:
			handleErr = s.handleTypeText()
		default:
			return fmt.Errorf("unknown client message type: %d", t)
		}

		// The handler has consumed the whole payload; lift the deadline.
		_ = s.conn.SetDeadline(time.Time{})
		if handleErr != nil {
			return handleErr
		}
	}
}
// handleSetPixelFormat reads the SetPixelFormat payload, which is 3 bytes of
// padding followed by a 16-byte pixel format. It stores the parsed format
// under encMu so the encoder picks it up on its next frame.
func (s *session) handleSetPixelFormat() error {
	var msg [3 + 16]byte
	if _, err := io.ReadFull(s.conn, msg[:]); err != nil {
		return fmt.Errorf("read SetPixelFormat: %w", err)
	}
	pf, err := parsePixelFormat(msg[3:])
	if err != nil {
		return err
	}
	s.encMu.Lock()
	defer s.encMu.Unlock()
	s.pf = pf
	return nil
}
// handleSetEncodings reads a SetEncodings message (after its type byte) and
// replaces the negotiated encoding state with the list it carries.
//
// Each SetEncodings is the client's complete, current list. Every flag and
// level that applyEncoding can set is therefore reset before the new list is
// applied, so anything an earlier message advertised but this one omits
// stops being used. The tight encoder state is rebuilt when tight is enabled
// and the requested levels differ from the current ones. The
// ExtendedClipboard caps message is sent at most once per session, after
// encMu is released.
func (s *session) handleSetEncodings() error {
	var header [3]byte // 1 padding + 2 number-of-encodings
	if _, err := io.ReadFull(s.conn, header[:]); err != nil {
		return fmt.Errorf("read SetEncodings header: %w", err)
	}
	numEnc := binary.BigEndian.Uint16(header[1:3])
	// RFB clients advertise a handful of real encodings plus pseudo-encodings.
	// Cap to keep a malicious client from forcing a 256 KiB allocation per
	// SetEncodings message.
	const maxEncodings = 64
	if numEnc > maxEncodings {
		return fmt.Errorf("SetEncodings: too many encodings (%d)", numEnc)
	}
	buf := make([]byte, int(numEnc)*4)
	if _, err := io.ReadFull(s.conn, buf); err != nil {
		return fmt.Errorf("read SetEncodings list: %w", err)
	}
	var encs []string
	s.encMu.Lock()
	// Start from a clean slate so the new list fully replaces the old one.
	s.resetNegotiatedEncodings()
	for i := range int(numEnc) {
		enc := int32(binary.BigEndian.Uint32(buf[i*4 : i*4+4]))
		if name := s.applyEncoding(enc); name != "" {
			encs = append(encs, name)
		}
	}
	if s.useTight && (s.tight == nil ||
		s.tight.qualityLevel != s.clientJPEGQuality ||
		s.tight.compressLevel != s.clientZlibLevel) {
		s.tight = newTightStateWithLevels(s.clientJPEGQuality, s.clientZlibLevel)
	}
	sendExtClipCaps := s.clientSupportsExtClipboard && !s.extClipCapsSent
	if sendExtClipCaps {
		s.extClipCapsSent = true
	}
	s.encMu.Unlock()
	if len(encs) > 0 {
		s.log.Debugf("client supports encodings: %s", strings.Join(encs, ", "))
	}
	if sendExtClipCaps {
		if err := s.writeExtClipMessage(buildExtClipCaps()); err != nil {
			return fmt.Errorf("send ext clipboard caps: %w", err)
		}
	}
	return nil
}

// resetNegotiatedEncodings clears every capability flag and level that
// applyEncoding can set, restoring the "no preference" (-1) quality and
// compression levels. It keeps tight and copyRectDet, which are reused if
// the client re-enables those encodings. It keeps extClipCapsSent, because
// the caps message is one-shot per session. Caller holds s.encMu.
func (s *session) resetNegotiatedEncodings() {
	s.useTight = false
	s.useCopyRect = false
	s.clientSupportsDesktopSize = false
	s.clientSupportsExtendedDesktopSize = false
	s.clientSupportsDesktopName = false
	s.clientSupportsLastRect = false
	s.clientSupportsQEMUKey = false
	s.clientSupportsExtClipboard = false
	s.clientJPEGQuality = -1
	s.clientZlibLevel = -1
}
// applyEncoding folds one entry of a SetEncodings list into the session
// state. It returns a short label for the debug log, or "" when the value is
// not one this server understands. Exact encodings and pseudo-encodings are
// matched first. Values in the QualityLevel or CompressLevel ranges are
// matched after that, and their offset from the range minimum is stored as
// the client's level. Caller must hold s.encMu.
func (s *session) applyEncoding(enc int32) string {
	switch {
	case enc == encCopyRect:
		s.useCopyRect = true
		if s.copyRectDet == nil {
			s.copyRectDet = newCopyRectDetector(tileSize)
		}
		return "copyrect"
	case enc == pseudoEncDesktopSize:
		s.clientSupportsDesktopSize = true
		return "desktop-size"
	case enc == pseudoEncExtendedDesktopSize:
		s.clientSupportsExtendedDesktopSize = true
		return "ext-desktop-size"
	case enc == pseudoEncDesktopName:
		s.clientSupportsDesktopName = true
		return "desktop-name"
	case enc == pseudoEncLastRect:
		s.clientSupportsLastRect = true
		return "last-rect"
	case enc == pseudoEncQEMUExtendedKeyEvent:
		s.clientSupportsQEMUKey = true
		return "qemu-key"
	case enc == pseudoEncExtendedClipboard:
		s.clientSupportsExtClipboard = true
		return "ext-clipboard"
	case enc == encTight:
		s.useTight = true
		return "tight"
	case enc >= pseudoEncQualityLevelMin && enc <= pseudoEncQualityLevelMax:
		s.clientJPEGQuality = int(enc - pseudoEncQualityLevelMin)
		return fmt.Sprintf("quality=%d", s.clientJPEGQuality)
	case enc >= pseudoEncCompressLevelMin && enc <= pseudoEncCompressLevelMax:
		s.clientZlibLevel = int(enc - pseudoEncCompressLevelMin)
		return fmt.Sprintf("compress=%d", s.clientZlibLevel)
	default:
		return ""
	}
}
// handleFBUpdateRequest reads a FramebufferUpdateRequest and hands it to the
// encoder goroutine. It never blocks on capture/encode, so the input
// dispatch loop stays responsive even while a previous frame is still being
// encoded.
//
// The payload after the type byte is incremental(1) x(2) y(2) width(2)
// height(2). Only the incremental flag is forwarded; the region is ignored.
// Any non-zero incremental byte counts as true, the same rule
// handleQEMUMessage applies to its down flag.
//
// encodeCh holds at most one request. A newer request supersedes a pending
// one, but the two are merged so that a pending full (non-incremental)
// refresh is never downgraded to an incremental one. Otherwise the client
// could ask for a full frame and never receive it.
func (s *session) handleFBUpdateRequest() error {
	var req [9]byte
	if _, err := io.ReadFull(s.conn, req[:]); err != nil {
		return fmt.Errorf("read FBUpdateRequest: %w", err)
	}
	r := fbRequest{incremental: req[0] != 0}

	// Fast path: nothing pending.
	select {
	case s.encodeCh <- r:
		return nil
	default:
	}

	// A request is pending. Take it (unless the encoder just did) and merge:
	// the result is incremental only if both requests were.
	select {
	case pending := <-s.encodeCh:
		r.incremental = r.incremental && pending.incremental
	default:
	}
	// messageLoop is the only sender, so the channel now has room; the
	// default arm is purely defensive.
	select {
	case s.encodeCh <- r:
	default:
	}
	return nil
}
// SendDesktopName records name as the session's desktop name and, if the
// client advertised the DesktopName pseudo-encoding, pushes it immediately
// as a one-rectangle FramebufferUpdate. The server uses it to keep the
// dashboard title in sync with the active session (e.g. username changes
// after login on a virtual session). Clients without support get nil and no
// message.
//
// The update header and body go out in a single Write while holding writeMu,
// so they cannot interleave with other writers and cost one syscall.
//
// NOTE(review): desktopName is assigned without a lock while sendServerInit
// reads it; confirm SendDesktopName cannot run during the handshake.
func (s *session) SendDesktopName(name string) error {
	s.desktopName = name

	s.encMu.RLock()
	supported := s.clientSupportsDesktopName
	s.encMu.RUnlock()
	if !supported {
		return nil
	}

	// encodeDesktopNameBody (defined elsewhere) presumably emits the
	// rectangle header and the name payload.
	body := encodeDesktopNameBody(name)
	// FramebufferUpdate header: message-type(1) padding(1) number-of-rects(2).
	msg := make([]byte, 4, 4+len(body))
	msg[0] = serverFramebufferUpdate
	binary.BigEndian.PutUint16(msg[2:4], 1)
	msg = append(msg, body...)

	s.writeMu.Lock()
	defer s.writeMu.Unlock()
	_, err := s.conn.Write(msg)
	return err
}
// handleKeyEvent reads a KeyEvent payload, down-flag(1) padding(2)
// keysym(4), and forwards it to the input injector. Any non-zero down-flag
// means "pressed", the same rule handleQEMUMessage applies to its down
// field, rather than only the exact value 1.
func (s *session) handleKeyEvent() error {
	var data [7]byte
	if _, err := io.ReadFull(s.conn, data[:]); err != nil {
		return fmt.Errorf("read KeyEvent: %w", err)
	}
	down := data[0] != 0
	keysym := binary.BigEndian.Uint32(data[3:7])
	s.injector.InjectKey(keysym, down)
	return nil
}
// handleQEMUMessage parses one QEMU vendor client message. Only subtype
// qemuSubtypeExtendedKeyEvent is supported. Its payload after the subtype
// byte is down-flag(2) keysym(4) keycode(4), and it is forwarded to
// InjectKeyScancode.
//
// QEMU subtypes do not share a common length (audio messages, for example,
// are shorter than the 12-byte extended key event). The subtype is therefore
// read on its own first. An unsupported subtype is reported as an error
// instead of being skipped. Guessing its length would consume bytes
// belonging to the following messages and desynchronise the stream.
func (s *session) handleQEMUMessage() error {
	var subtype [1]byte
	if _, err := io.ReadFull(s.conn, subtype[:]); err != nil {
		return fmt.Errorf("read QEMU message: %w", err)
	}
	if subtype[0] != qemuSubtypeExtendedKeyEvent {
		return fmt.Errorf("unsupported QEMU message subtype %d", subtype[0])
	}
	var data [10]byte // down(2) + keysym(4) + keycode(4)
	if _, err := io.ReadFull(s.conn, data[:]); err != nil {
		return fmt.Errorf("read QEMU message: %w", err)
	}
	down := binary.BigEndian.Uint16(data[0:2]) != 0
	keysym := binary.BigEndian.Uint32(data[2:6])
	scancode := binary.BigEndian.Uint32(data[6:10])
	s.injector.InjectKeyScancode(scancode, keysym, down)
	return nil
}
// handlePointerEvent reads a PointerEvent payload, button-mask(1) x(2) y(2),
// and forwards it together with the framebuffer size to the input injector.
func (s *session) handlePointerEvent() error {
	var msg [5]byte
	if _, err := io.ReadFull(s.conn, msg[:]); err != nil {
		return fmt.Errorf("read PointerEvent: %w", err)
	}
	s.injector.InjectPointer(
		msg[0],
		int(binary.BigEndian.Uint16(msg[1:])),
		int(binary.BigEndian.Uint16(msg[3:])),
		s.serverW,
		s.serverH,
	)
	return nil
}