Files
netbird/client/vnc/server/server.go

691 lines
20 KiB
Go

package server
import (
	"context"
	"crypto/rand"
	"crypto/subtle"
	"encoding/binary"
	"encoding/hex"
	"errors"
	"fmt"
	"image"
	"io"
	"net"
	"net/netip"
	"strings"
	"sync"
	"time"

	gojwt "github.com/golang-jwt/jwt/v5"
	log "github.com/sirupsen/logrus"
	"golang.zx2c4.com/wireguard/tun/netstack"

	sshauth "github.com/netbirdio/netbird/client/ssh/auth"
	nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
)
// Session modes a client can request in the NetBird connection header.
// The values are part of the wire format, so they are spelled out
// explicitly rather than derived with iota.
const (
	// ModeAttach captures the current physical display.
	ModeAttach byte = 0
	// ModeSession starts a virtual session as the requested user.
	ModeSession byte = 1
)
// Reject codes prefix the RFB security-failure reason sent to clients, in
// the form "CODE: human message". They are stable, so dashboard and noVNC
// integrations can branch on the code without parsing the free-text part.
const (
	// Problems with the JWT itself.
	RejectCodeJWTMissing = "AUTH_JWT_MISSING"
	RejectCodeJWTExpired = "AUTH_JWT_EXPIRED"
	RejectCodeJWTInvalid = "AUTH_JWT_INVALID"

	// Authorization and authentication-configuration problems.
	RejectCodeAuthForbidden = "AUTH_FORBIDDEN"
	RejectCodeAuthConfig    = "AUTH_CONFIG"

	// Server-side and request problems.
	RejectCodeSessionError  = "SESSION_ERROR"
	RejectCodeCapturerError = "CAPTURER_ERROR"
	RejectCodeUnsupportedOS = "UNSUPPORTED"
	RejectCodeBadRequest    = "BAD_REQUEST"
)
// EnvVNCDisableDownscale names the environment variable that turns off any
// platform-specific framebuffer downscaling (for example Retina 2:1).
// Setting it to 1 or true makes the server send the native resolution.
const EnvVNCDisableDownscale = "NB_VNC_DISABLE_DOWNSCALE"
// ScreenCapturer is the frame source the VNC server reads the desktop from.
type ScreenCapturer interface {
	// Capture grabs the current desktop contents as an RGBA image.
	Capture() (*image.RGBA, error)
	// Width reports the current screen width in pixels.
	Width() int
	// Height reports the current screen height in pixels.
	Height() int
}
// InputInjector forwards keyboard, pointer and clipboard events from VNC
// clients to the operating system.
type InputInjector interface {
	// InjectKey presses (down == true) or releases the key identified by the
	// X11 KeySym keysym.
	InjectKey(keysym uint32, down bool)
	// InjectPointer moves the pointer to (x, y) with the given button mask;
	// serverW and serverH give the framebuffer size the coordinates refer to.
	InjectPointer(buttonMask uint8, x, y, serverW, serverH int)
	// SetClipboard replaces the system clipboard with text.
	SetClipboard(text string)
	// GetClipboard reports the current system clipboard text.
	GetClipboard() string
}
// JWTConfig carries the settings used to validate JWTs presented by VNC
// clients.
type JWTConfig struct {
	Issuer       string
	KeysLocation string
	// MaxTokenAge is the maximum token age in seconds; values <= 0 fall back
	// to the package default.
	MaxTokenAge int64
	Audiences   []string
}
// connectionHeader is the optional NetBird preamble a client sends ahead of
// the RFB handshake. It selects the session mode and carries the
// authentication material.
type connectionHeader struct {
	mode      byte   // ModeAttach or ModeSession
	username  string // OS user requested for session mode
	jwt       string // raw JWT, empty when none was sent
	sessionID uint32 // Windows session ID (0 = console/auto)
}
// Server is the embedded VNC server that listens on the NetBird WireGuard
// interface. It runs in one of two operating modes:
//
//   - Direct mode: captures the screen and serves VNC sessions in-process.
//     Used when running inside a user session with desktop access.
//   - Service mode: proxies VNC connections to an agent process spawned in
//     the active console session. Used when running as a Windows service in
//     Session 0.
//
// In direct mode each connection chooses a session mode through the
// connection header:
//
//   - Attach: capture the current physical display.
//   - Session: start a virtual Xvfb display as the requested user.
type Server struct {
	// Display access and credentials handed to New.
	capturer ScreenCapturer
	injector InputInjector
	password string

	// Behaviour switches set through SetServiceMode / SetDisableAuth.
	serviceMode bool
	disableAuth bool

	localAddr netip.Addr   // NetBird WireGuard IP this server is bound to
	network   netip.Prefix // NetBird overlay network
	log       *log.Entry

	recordingDir    string // when set, VNC sessions are recorded to this directory
	recordingEncKey string // base64-encoded X25519 public key for encrypting recordings

	// mu is taken by Start, Stop, SetJWTConfig, SetNetstackNet, UpdateVNCAuth
	// and the JWT validator setup while they touch the fields that follow.
	mu           sync.Mutex
	listener     net.Listener
	ctx          context.Context
	cancel       context.CancelFunc
	vmgr         virtualSessionManager
	jwtConfig    *JWTConfig
	jwtValidator *nbjwt.Validator
	jwtExtractor *nbjwt.ClaimsExtractor
	authorizer   *sshauth.Authorizer
	netstackNet  *netstack.Net
	agentToken   []byte // raw token bytes for agent-mode auth
}
// vncSession is a virtual display session. It supplies the capturer and
// injector that a session-mode connection works against.
type vncSession interface {
	Capturer() ScreenCapturer
	Injector() InputInjector
	// Display identifies the virtual display (handleConnection logs it).
	Display() string
	// ClientConnect and ClientDisconnect bracket every client connection
	// served from this session.
	ClientConnect()
	ClientDisconnect()
}
// virtualSessionManager hands out per-user virtual display sessions. It is
// implemented by sessionManager on Linux; handleConnection rejects session
// mode when the server has none.
type virtualSessionManager interface {
	// GetOrCreate returns the virtual session for username.
	GetOrCreate(username string) (vncSession, error)
	// StopAll is called by Server.Stop to shut the sessions down.
	StopAll()
}
// New builds a VNC server around the given screen capturer and input
// injector. password is handed to every RFB session the server creates.
// The server starts in direct mode with a fresh authorizer from the ssh
// auth package; call Start to begin listening.
func New(capturer ScreenCapturer, injector InputInjector, password string) *Server {
	srv := &Server{
		log:        log.WithField("component", "vnc-server"),
		authorizer: sshauth.NewAuthorizer(),
		capturer:   capturer,
		injector:   injector,
		password:   password,
	}
	return srv
}
// SetServiceMode selects service mode (true), in which connections are
// proxied to an agent process for Windows service operation, or direct
// mode (false). Start reads this flag, so it applies to the next Start.
func (s *Server) SetServiceMode(enabled bool) {
	s.serviceMode = enabled
}
// SetJWTConfig installs the JWT validation settings for VNC connections.
// A nil config turns JWT off (public mode). Any cached validator and
// claims extractor are discarded so the next authentication rebuilds them
// from the new config.
func (s *Server) SetJWTConfig(config *JWTConfig) {
	s.mu.Lock()
	s.jwtConfig = config
	s.jwtValidator, s.jwtExtractor = nil, nil
	s.mu.Unlock()
}
// SetDisableAuth turns authentication off entirely when disable is true:
// connections are then served without any JWT check.
func (s *Server) SetDisableAuth(disable bool) {
	s.disableAuth = disable
}
// SetAgentToken installs a hex-encoded token that every incoming connection
// must send before any VNC data. Agent mode uses it so that only the
// trusted service process can connect. An empty hexToken leaves the current
// setting unchanged.
//
// If hexToken is not valid hex the server fails closed: a random token that
// nobody knows is installed, so every connection is rejected. Leaving the
// token empty instead would make handleConnection skip the agent check and
// accept any peer.
func (s *Server) SetAgentToken(hexToken string) {
	if hexToken == "" {
		return
	}
	token, err := hex.DecodeString(hexToken)
	if err != nil {
		s.log.Errorf("invalid agent token, rejecting all connections: %v", err)
		token = make([]byte, 32)
		if _, rerr := rand.Read(token); rerr != nil {
			// Practically unreachable. Fall back to the undecodable input,
			// which (presumably) is as confidential as the intended token.
			token = []byte(hexToken)
		}
	}
	s.agentToken = token
}
// SetNetstackNet routes the listener through a userspace netstack. When n
// is non-nil, Start listens via the netstack instead of opening a real OS
// socket.
func (s *Server) SetNetstackNet(n *netstack.Net) {
	s.mu.Lock()
	s.netstackNet = n
	s.mu.Unlock()
}
// SetRecordingDir turns on session recording into dir. An empty dir leaves
// recording off.
func (s *Server) SetRecordingDir(dir string) {
	s.recordingDir = dir
}
// SetRecordingEncryptionKey stores the base64-encoded X25519 public key
// that is handed to the session recorder for encrypting recordings.
func (s *Server) SetRecordingEncryptionKey(key string) {
	s.recordingEncKey = key
}
// UpdateVNCAuth applies a new fine-grained authorization config. The cached
// JWT validator and extractor are dropped first, because the extractor is
// built from the authorizer's user-ID claim and must be rebuilt to pick up
// a change.
func (s *Server) UpdateVNCAuth(config *sshauth.Config) {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.jwtValidator, s.jwtExtractor = nil, nil
	s.authorizer.Update(config)
}
// Start begins listening for VNC connections on addr and launches the
// accept loop in the background. network is the NetBird overlay prefix used
// to validate connection sources. When a netstack network was configured
// via SetNetstackNet, the listener lives in the userspace netstack instead
// of an OS socket. Start fails if the server is already running.
func (s *Server) Start(ctx context.Context, addr netip.AddrPort, network netip.Prefix) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.listener != nil {
		return errors.New("server already running")
	}

	// Open the listener before creating any server state, so a failed listen
	// leaves nothing behind. In particular, no cancellable child context is
	// left registered on the parent.
	var (
		listener   net.Listener
		listenDesc string
	)
	if s.netstackNet != nil {
		ln, err := s.netstackNet.ListenTCPAddrPort(addr)
		if err != nil {
			return fmt.Errorf("listen on netstack %s: %w", addr, err)
		}
		listener, listenDesc = ln, fmt.Sprintf("netstack %s", addr)
	} else {
		ln, err := net.ListenTCP("tcp", net.TCPAddrFromAddrPort(addr))
		if err != nil {
			return fmt.Errorf("listen on %s: %w", addr, err)
		}
		listener, listenDesc = ln, addr.String()
	}

	s.ctx, s.cancel = context.WithCancel(ctx)
	s.vmgr = s.platformSessionManager()
	s.localAddr = addr.Addr()
	s.network = network
	s.listener = listener

	if s.serviceMode {
		s.platformInit()
		go s.serviceAcceptLoop()
	} else {
		go s.acceptLoop()
	}

	s.log.Infof("started on %s (service_mode=%v)", listenDesc, s.serviceMode)
	return nil
}
// Stop shuts the server down. In order, it:
//
//   - cancels the server context;
//   - stops any virtual sessions;
//   - closes the capturer if it has a Close method;
//   - closes the listener.
//
// It is safe to call when the server was never started. Only a listener
// close error is returned.
func (s *Server) Stop() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	if cancel := s.cancel; cancel != nil {
		cancel()
		s.cancel = nil
	}
	if mgr := s.vmgr; mgr != nil {
		mgr.StopAll()
	}
	if closer, ok := s.capturer.(interface{ Close() }); ok {
		closer.Close()
	}

	ln := s.listener
	s.listener = nil
	if ln != nil {
		if err := ln.Close(); err != nil {
			return fmt.Errorf("close VNC listener: %w", err)
		}
	}

	s.log.Info("stopped")
	return nil
}
// acceptLoop accepts VNC connections and serves each in its own goroutine
// (direct / user-session mode). It returns when the server context is
// cancelled or the listener is closed.
func (s *Server) acceptLoop() {
	// Snapshot the listener and context under the lock. Stop clears
	// s.listener concurrently, so re-reading the field on every iteration
	// would race with it and could dereference nil. Start launches this
	// goroutine while holding mu, so the snapshot sees its fields.
	s.mu.Lock()
	ln, ctx := s.listener, s.ctx
	s.mu.Unlock()
	if ln == nil || ctx == nil {
		return
	}

	var backoff time.Duration
	for {
		conn, err := ln.Accept()
		if err != nil {
			if ctx.Err() != nil || errors.Is(err, net.ErrClosed) {
				return
			}
			s.log.Debugf("accept VNC connection: %v", err)

			// Back off on repeated failures (e.g. file-descriptor exhaustion)
			// instead of spinning on the error.
			if backoff == 0 {
				backoff = 5 * time.Millisecond
			} else {
				backoff = min(2*backoff, time.Second)
			}
			select {
			case <-ctx.Done():
				return
			case <-time.After(backoff):
			}
			continue
		}
		backoff = 0
		go s.handleConnection(conn)
	}
}
// validateCapturer waits for capturer to report a non-zero screen size.
//
// It returns nil immediately if the capturer is already ready. Otherwise it
// calls Wake (when the capturer implements it) to cut short any retry
// backoff, then polls every 100ms for up to 5s. Waiting stops early with an
// error if the server context is cancelled.
func (s *Server) validateCapturer(capturer ScreenCapturer) error {
	ready := func() bool { return capturer.Width() > 0 && capturer.Height() > 0 }
	if ready() {
		return nil
	}

	// Poke any retry loop that supports it so it doesn't wait out its full
	// backoff (e.g. macOS waiting for Screen Recording permission).
	if w, ok := capturer.(interface{ Wake() }); ok {
		w.Wake()
	}

	// A nil channel never becomes ready, so this is inert when the server
	// has no context (never started).
	ctx := s.ctx
	var stopping <-chan struct{}
	if ctx != nil {
		stopping = ctx.Done()
	}

	ticker := time.NewTicker(100 * time.Millisecond)
	defer ticker.Stop()
	for range 50 { // 50 x 100ms = 5s
		select {
		case <-stopping:
			return fmt.Errorf("wait for display: %w", ctx.Err())
		case <-ticker.C:
		}
		if ready() {
			return nil
		}
	}
	return errors.New("no display available (check X11 on Linux or Screen Recording permission on macOS)")
}
// isAllowedSource reports whether a peer at addr may connect. The rules
// match the SSH server's connectionValidator:
//
//   - non-TCP addresses and unparsable IPs are rejected;
//   - loopback peers are accepted when the server itself is bound to
//     loopback;
//   - the server's own WireGuard IP is rejected (prevents local privilege
//     escalation);
//   - when the overlay prefix is known, peers outside it are rejected.
//
// Every rejection is logged.
func (s *Server) isAllowedSource(addr net.Addr) bool {
	tcpAddr, isTCP := addr.(*net.TCPAddr)
	if !isTCP {
		s.log.Warnf("connection rejected: non-TCP address %s", addr)
		return false
	}

	ip, valid := netip.AddrFromSlice(tcpAddr.IP)
	if !valid {
		s.log.Warnf("connection rejected: invalid remote IP %s", tcpAddr.IP)
		return false
	}
	ip = ip.Unmap()

	switch {
	case ip.IsLoopback() && s.localAddr.IsLoopback():
		return true
	case ip == s.localAddr:
		s.log.Warnf("connection rejected from own IP %s", ip)
		return false
	case s.network.IsValid() && !s.network.Contains(ip):
		s.log.Warnf("connection rejected from non-NetBird IP %s", ip)
		return false
	default:
		return true
	}
}
// handleConnection serves one direct-mode VNC client from accept to close.
//
// Steps, in order:
//
//   - reject sources outside the NetBird network;
//   - verify the agent token when one is configured;
//   - read the optional NetBird session header;
//   - authenticate the JWT unless auth is disabled;
//   - pick the display for the requested mode;
//   - wait for the capturer;
//   - optionally start a recording;
//   - run the RFB session.
//
// Failures after the header is read are reported to the client through
// rejectConnection with a stable reject code.
func (s *Server) handleConnection(conn net.Conn) {
	connLog := s.log.WithField("remote", conn.RemoteAddr().String())

	if !s.isAllowedSource(conn.RemoteAddr()) {
		conn.Close()
		return
	}

	// In agent mode the trusted service process must present the shared
	// token before anything else; compare in constant time.
	if len(s.agentToken) > 0 {
		buf := make([]byte, len(s.agentToken))
		if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
			connLog.Debugf("set agent token deadline: %v", err)
			conn.Close()
			return
		}
		if _, err := io.ReadFull(conn, buf); err != nil {
			connLog.Warnf("agent auth: read token: %v", err)
			conn.Close()
			return
		}
		conn.SetReadDeadline(time.Time{}) //nolint:errcheck
		if subtle.ConstantTimeCompare(buf, s.agentToken) != 1 {
			connLog.Warn("agent auth: invalid token, rejecting")
			conn.Close()
			return
		}
	}

	header, err := readConnectionHeader(conn)
	if err != nil {
		connLog.Warnf("read connection header: %v", err)
		conn.Close()
		return
	}

	var jwtUser string
	if !s.disableAuth {
		// SetJWTConfig replaces jwtConfig under mu; read it under the same
		// lock so reconfiguration cannot race with incoming connections.
		s.mu.Lock()
		hasIdP := s.jwtConfig != nil
		s.mu.Unlock()
		if !hasIdP {
			rejectConnection(conn, codeMessage(RejectCodeAuthConfig, "auth enabled but no identity provider configured"))
			connLog.Warn("auth rejected: no identity provider configured")
			return
		}
		if jwtUser, err = s.authenticateJWT(header); err != nil {
			rejectConnection(conn, codeMessage(jwtErrorCode(err), err.Error()))
			connLog.Warnf("auth rejected: %v", err)
			return
		}
		connLog = connLog.WithField("jwt_user", jwtUser)
	}

	var (
		capturer ScreenCapturer
		injector InputInjector
	)
	switch header.mode {
	case ModeSession:
		if s.vmgr == nil {
			rejectConnection(conn, codeMessage(RejectCodeUnsupportedOS, "virtual sessions not supported on this platform"))
			connLog.Warn("session rejected: not supported on this platform")
			return
		}
		if header.username == "" {
			rejectConnection(conn, codeMessage(RejectCodeBadRequest, "session mode requires a username"))
			connLog.Warn("session rejected: no username provided")
			return
		}
		vs, err := s.vmgr.GetOrCreate(header.username)
		if err != nil {
			rejectConnection(conn, codeMessage(RejectCodeSessionError, fmt.Sprintf("create virtual session: %v", err)))
			connLog.Warnf("create virtual session for %s: %v", header.username, err)
			return
		}
		capturer = vs.Capturer()
		injector = vs.Injector()
		vs.ClientConnect()
		defer vs.ClientDisconnect()
		connLog = connLog.WithField("vnc_user", header.username)
		connLog.Infof("session mode: user=%s display=%s", header.username, vs.Display())
	default:
		// Attach mode (and any unrecognised mode): share the physical display.
		capturer = s.capturer
		injector = s.injector
		if cc, ok := capturer.(interface{ ClientConnect() }); ok {
			cc.ClientConnect()
		}
		defer func() {
			if cd, ok := capturer.(interface{ ClientDisconnect() }); ok {
				cd.ClientDisconnect()
			}
		}()
	}

	// A nil capturer would panic inside validateCapturer. Since this runs in
	// its own goroutine, that would take down the whole process.
	if capturer == nil {
		rejectConnection(conn, codeMessage(RejectCodeCapturerError, "screen capturer: not available"))
		connLog.Warn("capturer not available")
		return
	}
	if err := s.validateCapturer(capturer); err != nil {
		rejectConnection(conn, codeMessage(RejectCodeCapturerError, fmt.Sprintf("screen capturer: %v", err)))
		connLog.Warnf("capturer not ready: %v", err)
		return
	}

	var rec *vncRecorder
	if s.recordingDir != "" {
		mode := "attach"
		if header.mode == ModeSession {
			mode = "session"
		}
		rec, err = newVNCRecorder(s.recordingDir, capturer.Width(), capturer.Height(), &RecordingMeta{
			User:       header.username,
			RemoteAddr: conn.RemoteAddr().String(),
			JWTUser:    jwtUser,
			Mode:       mode,
		}, s.recordingEncKey, connLog)
		if err != nil {
			connLog.Warnf("start VNC recording: %v", err)
		}
	}

	sess := &session{
		conn:     conn,
		capturer: capturer,
		injector: injector,
		serverW:  capturer.Width(),
		serverH:  capturer.Height(),
		password: s.password,
		log:      connLog,
		recorder: rec,
	}
	sess.serve()
}
// codeMessage joins a stable reject code and a human-readable message as
// "CODE: message". Dashboards split on the first ": " to recover the code
// without having to parse the free-text part.
func codeMessage(code, msg string) string {
	const separator = ": "
	return code + separator + msg
}
// jwtErrorCode maps an authentication error (as produced by authenticateJWT)
// to a stable reject code.
//
// Expiry is detected via errors.Is. Other codes are chosen by matching the
// specific phrases authenticateJWT emits, so text from wrapped upstream
// errors cannot be misread as an authorization failure. For example, an
// HTTP "401 Unauthorized" while fetching keys contains the bare substring
// "authorize". A nil error, or anything unrecognised, maps to
// RejectCodeJWTInvalid.
func jwtErrorCode(err error) string {
	switch {
	case err == nil:
		return RejectCodeJWTInvalid
	case errors.Is(err, nbjwt.ErrTokenExpired):
		return RejectCodeJWTExpired
	}

	msg := err.Error()
	switch {
	case strings.Contains(msg, "JWT required but not provided"):
		return RejectCodeJWTMissing
	case strings.Contains(msg, "initialize JWT validator"):
		// The validator could not be built: a server configuration problem,
		// not a bad token.
		return RejectCodeAuthConfig
	case strings.Contains(msg, "authorize session for "),
		strings.Contains(msg, "not authorized"):
		return RejectCodeAuthForbidden
	default:
		return RejectCodeJWTInvalid
	}
}
// rejectConnection performs just enough of an RFB 3.8 handshake to deliver
// a failure reason. Clients then display the message instead of a generic
// "unexpected disconnect". It:
//
//   - sends the server version;
//   - waits up to 2s for the client's 12-byte version reply;
//   - sends zero security types followed by the length-prefixed reason.
//
// All I/O is best effort (errors are ignored), and conn is always closed.
func rejectConnection(conn net.Conn, reason string) {
	defer conn.Close()

	_, _ = io.WriteString(conn, "RFB 003.008\n")

	_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
	var clientVersion [12]byte
	_, _ = io.ReadFull(conn, clientVersion[:])
	_ = conn.SetReadDeadline(time.Time{})

	// Failure message layout: [0 security types: 1 byte]
	// [reason length: 4 bytes BE] [reason].
	out := make([]byte, 5, 5+len(reason))
	binary.BigEndian.PutUint32(out[1:5], uint32(len(reason)))
	out = append(out, reason...)
	_, _ = conn.Write(out)
}
const defaultJWTMaxTokenAge = 600 // seconds (10 minutes); used when JWTConfig.MaxTokenAge <= 0
// authenticateJWT validates the JWT carried in the connection header and
// checks that its user may use VNC. It returns the JWT user ID on success.
//
// Session mode requires authorization for the requested OS username. Attach
// mode (and any unrecognised mode) only requires the user to be authorized
// for the wildcard OS user "*".
//
// jwtErrorCode matches the returned error texts to pick a reject code, so
// keep them stable.
func (s *Server) authenticateJWT(header *connectionHeader) (string, error) {
	if header.jwt == "" {
		return "", errors.New("JWT required but not provided")
	}

	s.mu.Lock()
	if err := s.ensureJWTValidator(); err != nil {
		s.mu.Unlock()
		return "", fmt.Errorf("initialize JWT validator: %w", err)
	}
	validator, extractor := s.jwtValidator, s.jwtExtractor
	s.mu.Unlock()

	// Validation runs outside the lock so a slow validator does not block
	// reconfiguration or other connections.
	token, err := validator.ValidateAndParse(context.Background(), header.jwt)
	if err != nil {
		return "", fmt.Errorf("validate JWT: %w", err)
	}

	// checkTokenAge reads s.jwtConfig, which SetJWTConfig replaces under mu.
	s.mu.Lock()
	err = s.checkTokenAge(token)
	s.mu.Unlock()
	if err != nil {
		return "", err
	}

	userAuth, err := extractor.ToUserAuth(token)
	if err != nil {
		return "", fmt.Errorf("extract user from JWT: %w", err)
	}
	if userAuth.UserId == "" {
		return "", errors.New("JWT has no user ID")
	}

	switch header.mode {
	case ModeSession:
		// Session mode: check user + OS username mapping.
		if _, err := s.authorizer.Authorize(userAuth.UserId, header.username); err != nil {
			return "", fmt.Errorf("authorize session for %s: %w", header.username, err)
		}
	default:
		// Attach mode: just check user is in the authorized list (wildcard OS user).
		if _, err := s.authorizer.Authorize(userAuth.UserId, "*"); err != nil {
			return "", fmt.Errorf("user not authorized for VNC: %w", err)
		}
	}
	return userAuth.UserId, nil
}
// ensureJWTValidator lazily builds the JWT validator and claims extractor
// from s.jwtConfig and the authorizer's user-ID claim. It does nothing when
// both already exist; SetJWTConfig and UpdateVNCAuth clear them to force a
// rebuild. It must be called with s.mu held.
//
// It fails when no config is set or the config lists no audience. The
// extractor is keyed to the first audience, and indexing an empty list
// would panic the connection goroutine and crash the process.
func (s *Server) ensureJWTValidator() error {
	if s.jwtValidator != nil && s.jwtExtractor != nil {
		return nil
	}

	cfg := s.jwtConfig
	if cfg == nil {
		return errors.New("no JWT config")
	}
	if len(cfg.Audiences) == 0 {
		return errors.New("no JWT audience configured")
	}

	s.jwtValidator = nbjwt.NewValidator(
		cfg.Issuer,
		cfg.Audiences,
		cfg.KeysLocation,
		false,
	)

	opts := []nbjwt.ClaimsExtractorOption{nbjwt.WithAudience(cfg.Audiences[0])}
	if claim := s.authorizer.GetUserIDClaim(); claim != "" {
		opts = append(opts, nbjwt.WithUserIDClaim(claim))
	}
	s.jwtExtractor = nbjwt.NewClaimsExtractor(opts...)
	return nil
}
// checkTokenAge passes the effective maximum token age to
// nbjwt.CheckTokenAge. The limit is JWTConfig.MaxTokenAge seconds when
// positive, and defaultJWTMaxTokenAge seconds otherwise. It reads
// s.jwtConfig, so callers should hold s.mu if the config can change
// concurrently.
func (s *Server) checkTokenAge(token *gojwt.Token) error {
	limit := time.Duration(defaultJWTMaxTokenAge) * time.Second
	if cfg := s.jwtConfig; cfg != nil && cfg.MaxTokenAge > 0 {
		limit = time.Duration(int(cfg.MaxTokenAge)) * time.Second
	}
	return nbjwt.CheckTokenAge(token, limit)
}
// readConnectionHeader reads the optional NetBird session header that our
// WASM proxy sends immediately after connecting, before the RFB handshake.
// Wire format (integers are big-endian):
//
//	mode          1 byte
//	username_len  2 bytes, then username (at most 256 bytes)
//	jwt_len       2 bytes, then jwt (fewer than 8192 bytes)   [optional]
//	session_id    4 bytes, Windows session ID, 0 = console/auto [optional]
//
// The first 3 bytes are read with a 2s timeout. Standard VNC clients send
// nothing first (the server speaks first in RFB), so they time out and get
// the default attach mode. The remaining fields share a 5s deadline, and a
// missing optional trailing field is treated as absent.
//
// Oversized usernames or JWTs are rejected with an error. Skipping them
// would leave their bytes in the stream, where they would be misparsed as
// the following fields and then as RFB data.
func readConnectionHeader(conn net.Conn) (*connectionHeader, error) {
	if err := conn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
		return nil, fmt.Errorf("set deadline: %w", err)
	}
	defer conn.SetReadDeadline(time.Time{}) //nolint:errcheck

	var fixed [3]byte
	if _, err := io.ReadFull(conn, fixed[:]); err != nil {
		// Timeout or error: assume no header, use attach mode.
		return &connectionHeader{mode: ModeAttach}, nil
	}

	// Allow more time for the variable-length fields.
	if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
		return nil, fmt.Errorf("set deadline: %w", err)
	}

	hdr := &connectionHeader{mode: fixed[0]}

	usernameLen := binary.BigEndian.Uint16(fixed[1:3])
	if usernameLen > 256 {
		return nil, fmt.Errorf("username too long: %d", usernameLen)
	}
	if usernameLen > 0 {
		buf := make([]byte, usernameLen)
		if _, err := io.ReadFull(conn, buf); err != nil {
			return nil, fmt.Errorf("read username: %w", err)
		}
		hdr.username = string(buf)
	}

	// Optional JWT: a missing length field means no token was sent.
	var jwtLenBuf [2]byte
	if _, err := io.ReadFull(conn, jwtLenBuf[:]); err == nil {
		jwtLen := binary.BigEndian.Uint16(jwtLenBuf[:])
		if jwtLen >= 8192 {
			return nil, fmt.Errorf("JWT too long: %d", jwtLen)
		}
		if jwtLen > 0 {
			buf := make([]byte, jwtLen)
			if _, err := io.ReadFull(conn, buf); err != nil {
				return nil, fmt.Errorf("read JWT: %w", err)
			}
			hdr.jwt = string(buf)
		}
	}

	// Optional Windows session ID. Missing = 0 (console/auto).
	var sidBuf [4]byte
	if _, err := io.ReadFull(conn, sidBuf[:]); err == nil {
		hdr.sessionID = binary.BigEndian.Uint32(sidBuf[:])
	}

	return hdr, nil
}