mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-26 18:49:56 +00:00
Compare commits
8 Commits
embedded-v
...
windows-dn
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cca46f070b | ||
|
|
5f8b88471f | ||
|
|
f42b8aed90 | ||
|
|
0415137acd | ||
|
|
7fd16666e3 | ||
|
|
0571eeaba0 | ||
|
|
6a201d12b5 | ||
|
|
4810e79a00 |
1
.github/pull_request_template.md
vendored
1
.github/pull_request_template.md
vendored
@@ -12,7 +12,6 @@
|
||||
- [ ] Is a feature enhancement
|
||||
- [ ] It is a refactor
|
||||
- [ ] Created tests that fail without the change (if possible)
|
||||
- [ ] This change does **not** modify the public API, gRPC protocols, functionality behavior, CLI / service flags, or introduce a new feature — **OR** I have discussed it with the NetBird team beforehand (link the issue / Slack thread in the description). See [CONTRIBUTING.md](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTING.md#discuss-changes-with-the-netbird-team-first).
|
||||
|
||||
> By submitting this pull request, you confirm that you have read and agree to the terms of the [Contributor License Agreement](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTOR_LICENSE_AGREEMENT.md).
|
||||
|
||||
|
||||
4
.github/workflows/wasm-build-validation.yml
vendored
4
.github/workflows/wasm-build-validation.yml
vendored
@@ -61,8 +61,8 @@ jobs:
|
||||
|
||||
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
|
||||
|
||||
if [ ${SIZE} -gt 62914560 ]; then
|
||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 60MB limit!"
|
||||
if [ ${SIZE} -gt 58720256 ]; then
|
||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 56MB limit!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
@@ -92,6 +92,9 @@ linters:
|
||||
- linters:
|
||||
- unused
|
||||
path: client/firewall/iptables/rule\.go
|
||||
- linters:
|
||||
- unused
|
||||
path: client/internal/dns/dnsfw/(types|syscall|zsyscall)_windows.*\.go
|
||||
- linters:
|
||||
- gosec
|
||||
- mirror
|
||||
|
||||
@@ -15,7 +15,6 @@ If you haven't already, join our slack workspace [here](https://docs.netbird.io/
|
||||
- [Contributing to NetBird](#contributing-to-netbird)
|
||||
- [Contents](#contents)
|
||||
- [Code of conduct](#code-of-conduct)
|
||||
- [Discuss changes with the NetBird team first](#discuss-changes-with-the-netbird-team-first)
|
||||
- [Directory structure](#directory-structure)
|
||||
- [Development setup](#development-setup)
|
||||
- [Requirements](#requirements)
|
||||
@@ -34,14 +33,6 @@ Conduct which can be found in the file [CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md).
|
||||
By participating, you are expected to uphold this code. Please report
|
||||
unacceptable behavior to community@netbird.io.
|
||||
|
||||
## Discuss changes with the NetBird team first
|
||||
|
||||
Changes to the **public API**, **gRPC protocols**, **functionality behavior**, **CLI / service flags**, or **new features** should be discussed with the NetBird team before you start the work. These surfaces are part of NetBird's contract with operators, self-hosters, and downstream integrators, and changes to them have compatibility, security, and release-planning implications that benefit from an early conversation.
|
||||
|
||||
Open an issue or reach out on [Slack](https://docs.netbird.io/slack-url) to talk through what you have in mind. We'll help shape the change, flag any constraints we know about, and confirm the direction so the PR review can focus on implementation rather than design.
|
||||
|
||||
Typical bug fixes, internal refactors, documentation updates, and tests do not need pre-discussion — open the PR directly.
|
||||
|
||||
## Directory structure
|
||||
|
||||
The NetBird project monorepo is organized to maintain most of its individual dependencies code within their directories, except for a few auxiliary or shared packages.
|
||||
|
||||
106
client/cmd/up.go
106
client/cmd/up.go
@@ -361,12 +361,6 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
|
||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||
req.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
if cmd.Flag(serverVNCAllowedFlag).Changed {
|
||||
req.ServerVNCAllowed = &serverVNCAllowed
|
||||
}
|
||||
if cmd.Flag(disableVNCApprovalFlag).Changed {
|
||||
req.DisableVNCApproval = &disableVNCApproval
|
||||
}
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
req.EnableSSHRoot = &enableSSHRoot
|
||||
}
|
||||
@@ -473,14 +467,30 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
|
||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||
ic.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
if cmd.Flag(serverVNCAllowedFlag).Changed {
|
||||
ic.ServerVNCAllowed = &serverVNCAllowed
|
||||
}
|
||||
if cmd.Flag(disableVNCApprovalFlag).Changed {
|
||||
ic.DisableVNCApproval = &disableVNCApproval
|
||||
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
ic.EnableSSHRoot = &enableSSHRoot
|
||||
}
|
||||
|
||||
applySSHFlagsToConfig(cmd, &ic)
|
||||
if cmd.Flag(enableSSHSFTPFlag).Changed {
|
||||
ic.EnableSSHSFTP = &enableSSHSFTP
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
|
||||
ic.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
|
||||
ic.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
|
||||
}
|
||||
|
||||
if cmd.Flag(disableSSHAuthFlag).Changed {
|
||||
ic.DisableSSHAuth = &disableSSHAuth
|
||||
}
|
||||
|
||||
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||
ic.SSHJWTCacheTTL = &sshJWTCacheTTL
|
||||
}
|
||||
|
||||
if cmd.Flag(interfaceNameFlag).Changed {
|
||||
if err := parseInterfaceName(interfaceName); err != nil {
|
||||
@@ -556,49 +566,6 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
|
||||
return &ic, nil
|
||||
}
|
||||
|
||||
func applySSHFlagsToConfig(cmd *cobra.Command, ic *profilemanager.ConfigInput) {
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
ic.EnableSSHRoot = &enableSSHRoot
|
||||
}
|
||||
if cmd.Flag(enableSSHSFTPFlag).Changed {
|
||||
ic.EnableSSHSFTP = &enableSSHSFTP
|
||||
}
|
||||
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
|
||||
ic.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
|
||||
}
|
||||
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
|
||||
ic.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
|
||||
}
|
||||
if cmd.Flag(disableSSHAuthFlag).Changed {
|
||||
ic.DisableSSHAuth = &disableSSHAuth
|
||||
}
|
||||
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||
ic.SSHJWTCacheTTL = &sshJWTCacheTTL
|
||||
}
|
||||
}
|
||||
|
||||
func applySSHFlagsToLogin(cmd *cobra.Command, req *proto.LoginRequest) {
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
req.EnableSSHRoot = &enableSSHRoot
|
||||
}
|
||||
if cmd.Flag(enableSSHSFTPFlag).Changed {
|
||||
req.EnableSSHSFTP = &enableSSHSFTP
|
||||
}
|
||||
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
|
||||
req.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
|
||||
}
|
||||
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
|
||||
req.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
|
||||
}
|
||||
if cmd.Flag(disableSSHAuthFlag).Changed {
|
||||
req.DisableSSHAuth = &disableSSHAuth
|
||||
}
|
||||
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||
ttl := int32(sshJWTCacheTTL)
|
||||
req.SshJWTCacheTTL = &ttl
|
||||
}
|
||||
}
|
||||
|
||||
func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte, cmd *cobra.Command) (*proto.LoginRequest, error) {
|
||||
loginRequest := proto.LoginRequest{
|
||||
SetupKey: providedSetupKey,
|
||||
@@ -628,14 +595,31 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
|
||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||
loginRequest.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
if cmd.Flag(serverVNCAllowedFlag).Changed {
|
||||
loginRequest.ServerVNCAllowed = &serverVNCAllowed
|
||||
}
|
||||
if cmd.Flag(disableVNCApprovalFlag).Changed {
|
||||
loginRequest.DisableVNCApproval = &disableVNCApproval
|
||||
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
loginRequest.EnableSSHRoot = &enableSSHRoot
|
||||
}
|
||||
|
||||
applySSHFlagsToLogin(cmd, &loginRequest)
|
||||
if cmd.Flag(enableSSHSFTPFlag).Changed {
|
||||
loginRequest.EnableSSHSFTP = &enableSSHSFTP
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
|
||||
loginRequest.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
|
||||
loginRequest.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
|
||||
}
|
||||
|
||||
if cmd.Flag(disableSSHAuthFlag).Changed {
|
||||
loginRequest.DisableSSHAuth = &disableSSHAuth
|
||||
}
|
||||
|
||||
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
|
||||
loginRequest.SshJWTCacheTTL = &sshJWTCacheTTL32
|
||||
}
|
||||
|
||||
if cmd.Flag(disableAutoConnectFlag).Changed {
|
||||
loginRequest.DisableAutoConnect = &autoConnectDisabled
|
||||
|
||||
@@ -1,100 +0,0 @@
|
||||
//go:build windows || (darwin && !ios)
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
)
|
||||
|
||||
var (
|
||||
vncAgentSocket string
|
||||
vncAgentTargetUID uint32
|
||||
)
|
||||
|
||||
func init() {
|
||||
vncAgentCmd.Flags().StringVar(&vncAgentSocket, "socket", "", "Unix-domain socket path the agent listens on (required)")
|
||||
vncAgentCmd.Flags().Uint32Var(&vncAgentTargetUID, "target-uid", 0, "uid the agent should drop privileges to before listening (darwin only; 0 = stay as current uid)")
|
||||
rootCmd.AddCommand(vncAgentCmd)
|
||||
}
|
||||
|
||||
// vncAgentCmd runs a VNC server inside the user's interactive session,
|
||||
// listening on a Unix-domain socket. The NetBird service spawns it: on
|
||||
// Windows via CreateProcessAsUser into the console session, on macOS via
|
||||
// launchctl asuser into the Aqua session.
|
||||
var vncAgentCmd = &cobra.Command{
|
||||
Use: "vnc-agent",
|
||||
Short: "Run VNC capture agent (internal, spawned by service)",
|
||||
Hidden: true,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
log.SetReportCaller(true)
|
||||
log.SetFormatter(&log.JSONFormatter{})
|
||||
log.SetOutput(os.Stderr)
|
||||
|
||||
if vncAgentSocket == "" {
|
||||
return fmt.Errorf("--socket is required")
|
||||
}
|
||||
|
||||
token := os.Getenv("NB_VNC_AGENT_TOKEN")
|
||||
if token == "" {
|
||||
return fmt.Errorf("NB_VNC_AGENT_TOKEN not set; agent requires a token from the service")
|
||||
}
|
||||
// Purge the token from env so it doesn't leak via /proc/<pid>/environ.
|
||||
if err := os.Unsetenv("NB_VNC_AGENT_TOKEN"); err != nil {
|
||||
log.Debugf("unset NB_VNC_AGENT_TOKEN: %v", err)
|
||||
}
|
||||
|
||||
// Drop root privileges to the target console user BEFORE creating
|
||||
// the listening socket: keeps a post-auth bug in the encoder /
|
||||
// input / capture paths confined to the user's own privileges
|
||||
// rather than escalating to host root, and makes the daemon's
|
||||
// LOCAL_PEERCRED check see the right uid. No-op on Windows
|
||||
// (both processes run as SYSTEM) and when --target-uid is 0.
|
||||
if vncAgentTargetUID != 0 {
|
||||
if err := dropAgentPrivileges(vncAgentTargetUID); err != nil {
|
||||
return fmt.Errorf("drop privileges to uid %d: %w", vncAgentTargetUID, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := os.Remove(vncAgentSocket); err != nil && !os.IsNotExist(err) {
|
||||
log.Debugf("remove stale socket %s: %v", vncAgentSocket, err)
|
||||
}
|
||||
ln, err := net.Listen("unix", vncAgentSocket)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen on %s: %w", vncAgentSocket, err)
|
||||
}
|
||||
if err := os.Chmod(vncAgentSocket, 0o600); err != nil {
|
||||
log.Debugf("chmod %s: %v", vncAgentSocket, err)
|
||||
}
|
||||
|
||||
capturer, injector, err := newAgentResources()
|
||||
if err != nil {
|
||||
_ = ln.Close()
|
||||
return err
|
||||
}
|
||||
srv := vncserver.New(vncserver.Config{
|
||||
Capturer: capturer,
|
||||
Injector: injector,
|
||||
DisableAuth: true,
|
||||
AgentTokenHex: token,
|
||||
Listener: ln,
|
||||
})
|
||||
|
||||
if err := srv.Start(cmd.Context(), netip.AddrPort{}, netip.Prefix{}); err != nil {
|
||||
return fmt.Errorf("start vnc server: %w", err)
|
||||
}
|
||||
log.Infof("vnc-agent listening on %s, ready", vncAgentSocket)
|
||||
|
||||
<-cmd.Context().Done()
|
||||
log.Info("vnc-agent context cancelled, shutting down")
|
||||
return srv.Stop()
|
||||
},
|
||||
SilenceUsage: true,
|
||||
}
|
||||
@@ -1,18 +0,0 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
)
|
||||
|
||||
func newAgentResources() (vncserver.ScreenCapturer, vncserver.InputInjector, error) {
|
||||
capturer := vncserver.NewMacPoller()
|
||||
injector, err := vncserver.NewMacInputInjector()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("macOS input injector: %w", err)
|
||||
}
|
||||
return capturer, injector, nil
|
||||
}
|
||||
@@ -1,50 +0,0 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// dropAgentPrivileges drops the vnc-agent process from root (its
|
||||
// launchctl-asuser-inherited starting uid) to the target console user
|
||||
// before any other initialisation runs. Without this the agent runs as
|
||||
// root for the lifetime of the session; any post-auth memory-safety
|
||||
// issue in the capture/input/encode paths would then be a root-level
|
||||
// RCE on the host instead of a user-level one. Also makes the daemon's
|
||||
// LOCAL_PEERCRED check correctly identify the agent as the console user,
|
||||
// not as root.
|
||||
//
|
||||
// Returns an error when the agent is running as a non-root uid that
|
||||
// differs from targetUID: non-root can only setuid to itself, so a
|
||||
// mismatch here means the spawn went to the wrong session.
|
||||
func dropAgentPrivileges(targetUID uint32) error {
|
||||
if targetUID == 0 {
|
||||
return fmt.Errorf("refusing to keep agent running as root (target uid 0)")
|
||||
}
|
||||
cur := uint32(os.Getuid())
|
||||
if cur == targetUID {
|
||||
return nil
|
||||
}
|
||||
if cur != 0 {
|
||||
return fmt.Errorf("agent uid %d does not match expected %d and we lack root to fix it", cur, targetUID)
|
||||
}
|
||||
// Drop supplementary groups first: setgid alone doesn't touch the
|
||||
// auxiliary group list, leaving root's groups attached would let the
|
||||
// dropped process write to root-only group-writable files.
|
||||
if err := syscall.Setgroups([]int{}); err != nil {
|
||||
return fmt.Errorf("setgroups([]): %w", err)
|
||||
}
|
||||
if err := syscall.Setgid(int(targetUID)); err != nil {
|
||||
return fmt.Errorf("setgid(%d): %w", targetUID, err)
|
||||
}
|
||||
if err := syscall.Setuid(int(targetUID)); err != nil {
|
||||
return fmt.Errorf("setuid(%d): %w", targetUID, err)
|
||||
}
|
||||
if uint32(os.Getuid()) != targetUID || uint32(os.Geteuid()) != targetUID {
|
||||
return fmt.Errorf("setuid verification: uid=%d euid=%d, expected %d", os.Getuid(), os.Geteuid(), targetUID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,55 +0,0 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestDropAgentPrivileges_RefusesRootTarget locks in the contract that
|
||||
// dropAgentPrivileges must never be a no-op when asked to keep the
|
||||
// agent as root (target uid 0). A future caller that passes 0 by
|
||||
// mistake would otherwise leave the post-auth attack surface running
|
||||
// with full root privileges.
|
||||
func TestDropAgentPrivileges_RefusesRootTarget(t *testing.T) {
|
||||
err := dropAgentPrivileges(0)
|
||||
if err == nil {
|
||||
t.Fatal("expected refusal for target uid 0, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "root") {
|
||||
t.Fatalf("error should mention root, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDropAgentPrivileges_NoOpWhenAlreadyTarget covers the dev path
|
||||
// where the agent is launched by hand as the target user (no root
|
||||
// available, no setuid needed). The helper must succeed silently
|
||||
// instead of trying (and failing) a setuid to its current uid.
|
||||
func TestDropAgentPrivileges_NoOpWhenAlreadyTarget(t *testing.T) {
|
||||
// Skip when running as root: the early-return path we want to
|
||||
// cover only fires when current uid == target uid.
|
||||
uid := currentUIDForTest()
|
||||
if uid == 0 {
|
||||
t.Skip("test must not run as root; cannot exercise the no-op early-return")
|
||||
}
|
||||
if err := dropAgentPrivileges(uid); err != nil {
|
||||
t.Fatalf("expected no-op when current uid == target, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDropAgentPrivileges_RefusesMismatchedNonRoot guards the "non-root
|
||||
// caller tries to setuid to a different uid" path: setuid would fail
|
||||
// with EPERM anyway, but the helper should surface a clear error
|
||||
// before issuing the syscall so a misconfigured spawn (wrong --target-uid
|
||||
// flag) is debuggable.
|
||||
func TestDropAgentPrivileges_RefusesMismatchedNonRoot(t *testing.T) {
|
||||
uid := currentUIDForTest()
|
||||
if uid == 0 {
|
||||
t.Skip("test must not run as root; covered case requires non-root caller")
|
||||
}
|
||||
err := dropAgentPrivileges(uid + 1)
|
||||
if err == nil {
|
||||
t.Fatal("expected refusal when non-root caller asks to setuid elsewhere")
|
||||
}
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package cmd
|
||||
|
||||
import "os"
|
||||
|
||||
// currentUIDForTest exposes os.Getuid for the darwin dropprivs tests
|
||||
// without leaking an os import into the test file itself.
|
||||
func currentUIDForTest() uint32 {
|
||||
return uint32(os.Getuid())
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package cmd
|
||||
|
||||
// dropAgentPrivileges is a no-op on Windows: the agent and the daemon
|
||||
// both run as SYSTEM (the daemon spawns the agent into the interactive
|
||||
// session via CreateProcessAsUser with an impersonation token, but the
|
||||
// resulting process still runs under SYSTEM, not under the user's
|
||||
// account). The Windows path relies on the C:\Windows\Temp socket
|
||||
// location (admin/SYSTEM-write-only) and the per-spawn token for
|
||||
// integrity instead.
|
||||
func dropAgentPrivileges(_ uint32) error {
|
||||
return nil
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
)
|
||||
|
||||
func newAgentResources() (vncserver.ScreenCapturer, vncserver.InputInjector, error) {
|
||||
sessionID := vncserver.GetCurrentSessionID()
|
||||
log.Infof("VNC agent running in Windows session %d", sessionID)
|
||||
return vncserver.NewDesktopCapturer(), vncserver.NewWindowsInputInjector(), nil
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
package cmd
|
||||
|
||||
const (
|
||||
serverVNCAllowedFlag = "allow-server-vnc"
|
||||
disableVNCApprovalFlag = "disable-vnc-approval"
|
||||
)
|
||||
|
||||
var (
|
||||
serverVNCAllowed bool
|
||||
disableVNCApproval bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
upCmd.PersistentFlags().BoolVar(&serverVNCAllowed, serverVNCAllowedFlag, false, "Allow embedded VNC server on peer")
|
||||
upCmd.PersistentFlags().BoolVar(&disableVNCApproval, disableVNCApprovalFlag, false, "Disable per-connection user approval prompts for the embedded VNC server")
|
||||
}
|
||||
@@ -6,30 +6,19 @@ import (
|
||||
"runtime"
|
||||
)
|
||||
|
||||
var (
|
||||
// StateDir holds persistent state (config, profiles, install metadata).
|
||||
StateDir string
|
||||
// RuntimeDir holds ephemeral artifacts that should not survive reboot,
|
||||
// such as Unix sockets for daemon and per-session IPC. Empty on
|
||||
// platforms without a conventional /var/run-style location.
|
||||
RuntimeDir string
|
||||
)
|
||||
var StateDir string
|
||||
|
||||
func init() {
|
||||
StateDir = os.Getenv("NB_STATE_DIR")
|
||||
if StateDir != "" {
|
||||
return
|
||||
}
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
StateDir = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird")
|
||||
case "darwin", "linux":
|
||||
StateDir = "/var/lib/netbird"
|
||||
RuntimeDir = "/var/run/netbird"
|
||||
case "freebsd", "openbsd", "netbsd", "dragonfly":
|
||||
StateDir = "/var/db/netbird"
|
||||
RuntimeDir = "/var/run/netbird"
|
||||
}
|
||||
if v := os.Getenv("NB_STATE_DIR"); v != "" {
|
||||
StateDir = v
|
||||
}
|
||||
if v := os.Getenv("NB_RUNTIME_DIR"); v != "" {
|
||||
RuntimeDir = v
|
||||
}
|
||||
}
|
||||
|
||||
@@ -260,23 +260,15 @@ WriteRegStr ${REG_ROOT} "${UNINSTALL_PATH}" "Publisher" "${COMP_NAME}"
|
||||
|
||||
WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
|
||||
|
||||
; Drop Run, App Paths and Uninstall entries left in the 32-bit registry view
|
||||
; or HKCU by legacy installers.
|
||||
DetailPrint "Cleaning legacy 32-bit / HKCU entries..."
|
||||
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
SetRegView 32
|
||||
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
DeleteRegKey HKLM "${REG_APP_PATH}"
|
||||
DeleteRegKey HKLM "${UI_REG_APP_PATH}"
|
||||
DeleteRegKey HKLM "${UNINSTALL_PATH}"
|
||||
SetRegView 64
|
||||
|
||||
; Create autostart registry entry based on checkbox
|
||||
DetailPrint "Autostart enabled: $AutostartEnabled"
|
||||
${If} $AutostartEnabled == "1"
|
||||
WriteRegStr HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" '"$INSTDIR\${UI_APP_EXE}.exe"'
|
||||
DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe"
|
||||
${Else}
|
||||
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
|
||||
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
DetailPrint "Autostart not enabled by user"
|
||||
${EndIf}
|
||||
|
||||
@@ -307,16 +299,11 @@ ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
|
||||
DetailPrint "Terminating Netbird UI process..."
|
||||
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
|
||||
|
||||
; Remove autostart entries from every view a previous installer may have used.
|
||||
; Remove autostart registry entry
|
||||
DetailPrint "Removing autostart registry entry if exists..."
|
||||
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
|
||||
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
SetRegView 32
|
||||
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
DeleteRegKey HKLM "${REG_APP_PATH}"
|
||||
DeleteRegKey HKLM "${UI_REG_APP_PATH}"
|
||||
DeleteRegKey HKLM "${UNINSTALL_PATH}"
|
||||
SetRegView 64
|
||||
|
||||
; Handle data deletion based on checkbox
|
||||
DetailPrint "Checking if user requested data deletion..."
|
||||
|
||||
@@ -1,219 +0,0 @@
|
||||
// Package approval brokers per-attempt user-accept prompts for inbound
|
||||
// remote access (VNC today, SSH and others in the future). A caller pushes
|
||||
// a Prompt; the broker emits a SystemEvent on the daemon→UI stream and
|
||||
// blocks until the UI calls the daemon's RespondApproval RPC, the per-
|
||||
// request timeout fires, or no subscriber is connected. The latter case
|
||||
// fails closed so a backgrounded UI cannot silently bypass the gate.
|
||||
package approval
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
// Metadata keys the broker reserves on the emitted SystemEvent. Callers
|
||||
// should not set these themselves; values in Prompt.Metadata that collide
|
||||
// are overwritten by the broker.
|
||||
const (
|
||||
MetaRequestID = "request_id"
|
||||
MetaKind = "kind"
|
||||
MetaExpiresAt = "expires_at"
|
||||
)
|
||||
|
||||
// ShortKeyFingerprint formats a hex-encoded Noise_IK static pubkey as a
|
||||
// short, eyeball-able fingerprint to display in the approval dialog.
|
||||
// The dashboard-supplied display name attached to a SessionPubKey isn't
|
||||
// cryptographically asserted by the connecting client, so the prompt
|
||||
// must also show something that IS: the key fingerprint, a hash of
|
||||
// the static public key the client just proved possession of during the
|
||||
// Noise handshake. Returns the empty string when the input is too short
|
||||
// to plausibly be a hex pubkey, so the row is omitted rather than
|
||||
// rendered as a misleading partial.
|
||||
//
|
||||
// Output format: 16 hex chars grouped as XXXX-XXXX-XXXX-XXXX (64 bits of
|
||||
// fingerprint, resistant to random-prefix collisions and easy for a human
|
||||
// to compare with an out-of-band reference).
|
||||
func ShortKeyFingerprint(hexKey string) string {
|
||||
if len(hexKey) < 8 {
|
||||
return ""
|
||||
}
|
||||
src := hexKey
|
||||
if len(src) > 16 {
|
||||
src = src[:16]
|
||||
}
|
||||
var out []byte
|
||||
for i, c := range src {
|
||||
if i > 0 && i%4 == 0 {
|
||||
out = append(out, '-')
|
||||
}
|
||||
out = append(out, byte(c))
|
||||
}
|
||||
return string(out)
|
||||
}
|
||||
|
||||
// Kind values for the well-known prompt subjects. New subsystems should
|
||||
// add a constant here so the UI can dispatch on a known string.
|
||||
const (
|
||||
KindVNC = "vnc"
|
||||
KindSSH = "ssh"
|
||||
)
|
||||
|
||||
// DefaultTimeout is the wall-clock window the user has to accept or deny a
|
||||
// pending approval before the broker fails closed and returns ErrTimeout.
|
||||
// Kept well under typical VNC client and dashboard connection timeouts so
|
||||
// the RFB rejection actually reaches the browser instead of racing the
|
||||
// browser's own "connection timed out" message.
|
||||
const DefaultTimeout = 15 * time.Second
|
||||
|
||||
// timeoutValue returns the active timeout. It's a var so tests in this
|
||||
// package can shorten the wait without exposing a setter on the public
|
||||
// API. Production code always sees DefaultTimeout.
|
||||
var timeoutValue = func() time.Duration { return DefaultTimeout }
|
||||
|
||||
// ErrNoSubscriber indicates no UI is connected to consume the prompt.
|
||||
// The caller must reject the underlying connection (fail-closed).
|
||||
var ErrNoSubscriber = errors.New("no UI subscriber connected for approval")
|
||||
|
||||
// ErrTimeout indicates the user did not respond within DefaultTimeout.
|
||||
var ErrTimeout = errors.New("approval timed out")
|
||||
|
||||
// ErrDenied indicates the user explicitly denied the connection.
|
||||
var ErrDenied = errors.New("approval denied")
|
||||
|
||||
// EventPublisher is the subset of peer.Status used to emit prompts.
|
||||
type EventPublisher interface {
|
||||
PublishEvent(
|
||||
severity proto.SystemEvent_Severity,
|
||||
category proto.SystemEvent_Category,
|
||||
msg string,
|
||||
userMsg string,
|
||||
metadata map[string]string,
|
||||
)
|
||||
HasEventSubscribers() bool
|
||||
}
|
||||
|
||||
// Prompt describes the pending request shown to the user. Kind selects
|
||||
// the UI dispatch path (e.g. "vnc", "ssh"). Subject is the human-readable
|
||||
// one-liner the UI may show as a title or notification body. Metadata is
|
||||
// passed through verbatim and is the subsystem-specific payload (peer
|
||||
// name, source IP, mode, etc.).
|
||||
type Prompt struct {
|
||||
Kind string
|
||||
Subject string
|
||||
Metadata map[string]string
|
||||
}
|
||||
|
||||
// Decision carries the user's response to an approval prompt. ViewOnly is
|
||||
// only meaningful when Accept is true; it lets the host grant the
|
||||
// connection but signal the requester that input control is withheld.
|
||||
type Decision struct {
|
||||
Accept bool
|
||||
ViewOnly bool
|
||||
}
|
||||
|
||||
// Broker holds in-flight approval requests keyed by request ID.
|
||||
type Broker struct {
|
||||
pub EventPublisher
|
||||
|
||||
mu sync.Mutex
|
||||
pending map[string]chan Decision
|
||||
}
|
||||
|
||||
// New returns a broker that publishes prompts via pub.
|
||||
func New(pub EventPublisher) *Broker {
|
||||
return &Broker{
|
||||
pub: pub,
|
||||
pending: make(map[string]chan Decision),
|
||||
}
|
||||
}
|
||||
|
||||
// Request emits a SystemEvent for p and blocks until the UI calls Respond,
|
||||
// ctx is cancelled, or DefaultTimeout elapses. Returns a Decision when
|
||||
// the user replied; ErrDenied / ErrTimeout / ErrNoSubscriber / ctx.Err
|
||||
// otherwise. Callers must treat any non-nil error as a deny.
|
||||
func (b *Broker) Request(ctx context.Context, p Prompt) (Decision, error) {
|
||||
var zero Decision
|
||||
if b == nil || b.pub == nil {
|
||||
return zero, fmt.Errorf("approval broker not configured")
|
||||
}
|
||||
if !b.pub.HasEventSubscribers() {
|
||||
return zero, ErrNoSubscriber
|
||||
}
|
||||
|
||||
id := uuid.NewString()
|
||||
resp := make(chan Decision, 1)
|
||||
|
||||
b.mu.Lock()
|
||||
b.pending[id] = resp
|
||||
b.mu.Unlock()
|
||||
|
||||
defer b.dropPending(id)
|
||||
|
||||
timeout := timeoutValue()
|
||||
expiresAt := time.Now().Add(timeout)
|
||||
meta := make(map[string]string, len(p.Metadata)+3)
|
||||
for k, v := range p.Metadata {
|
||||
meta[k] = v
|
||||
}
|
||||
meta[MetaRequestID] = id
|
||||
meta[MetaKind] = p.Kind
|
||||
meta[MetaExpiresAt] = expiresAt.UTC().Format(time.RFC3339)
|
||||
|
||||
subject := p.Subject
|
||||
if subject == "" {
|
||||
subject = fmt.Sprintf("%s connection requires approval", p.Kind)
|
||||
}
|
||||
b.pub.PublishEvent(proto.SystemEvent_INFO, proto.SystemEvent_APPROVAL, subject, subject, meta)
|
||||
log.Debugf("approval request %s (%s) emitted: %s", id, p.Kind, subject)
|
||||
|
||||
timer := time.NewTimer(timeout)
|
||||
defer timer.Stop()
|
||||
|
||||
select {
|
||||
case d := <-resp:
|
||||
if !d.Accept {
|
||||
return zero, ErrDenied
|
||||
}
|
||||
return d, nil
|
||||
case <-timer.C:
|
||||
return zero, ErrTimeout
|
||||
case <-ctx.Done():
|
||||
return zero, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// Respond delivers the user's decision for id. Returns true when a pending
|
||||
// request matched and was woken, false when id was unknown or already done.
|
||||
func (b *Broker) Respond(id string, d Decision) bool {
|
||||
if b == nil {
|
||||
return false
|
||||
}
|
||||
b.mu.Lock()
|
||||
ch, ok := b.pending[id]
|
||||
if ok {
|
||||
delete(b.pending, id)
|
||||
}
|
||||
b.mu.Unlock()
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
select {
|
||||
case ch <- d:
|
||||
default:
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (b *Broker) dropPending(id string) {
|
||||
b.mu.Lock()
|
||||
delete(b.pending, id)
|
||||
b.mu.Unlock()
|
||||
}
|
||||
@@ -1,434 +0,0 @@
|
||||
package approval
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
// fakePublisher records published events and reports whether subscribers
|
||||
// are connected. The subscribers flag is the security-critical signal:
|
||||
// when false the broker must refuse to emit and the gate must fail closed.
|
||||
type fakePublisher struct {
|
||||
mu sync.Mutex
|
||||
subscribers bool
|
||||
events []*proto.SystemEvent
|
||||
}
|
||||
|
||||
func (p *fakePublisher) PublishEvent(
|
||||
severity proto.SystemEvent_Severity,
|
||||
category proto.SystemEvent_Category,
|
||||
msg string,
|
||||
userMsg string,
|
||||
metadata map[string]string,
|
||||
) {
|
||||
p.mu.Lock()
|
||||
p.events = append(p.events, &proto.SystemEvent{
|
||||
Severity: severity,
|
||||
Category: category,
|
||||
Message: msg,
|
||||
UserMessage: userMsg,
|
||||
Metadata: metadata,
|
||||
})
|
||||
p.mu.Unlock()
|
||||
}
|
||||
|
||||
func (p *fakePublisher) HasEventSubscribers() bool {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
return p.subscribers
|
||||
}
|
||||
|
||||
func (p *fakePublisher) lastEvent(t *testing.T) *proto.SystemEvent {
|
||||
t.Helper()
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
require.NotEmpty(t, p.events, "publisher saw no events")
|
||||
return p.events[len(p.events)-1]
|
||||
}
|
||||
|
||||
func (p *fakePublisher) eventCount() int {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
return len(p.events)
|
||||
}
|
||||
|
||||
// TestRequestNoSubscriberFailsClosed is the core fail-closed invariant:
|
||||
// when the UI is not subscribed, the broker must refuse without emitting
|
||||
// an event or arming a waiter. A regression here is a silent bypass.
|
||||
func TestRequestNoSubscriberFailsClosed(t *testing.T) {
|
||||
pub := &fakePublisher{subscribers: false}
|
||||
b := New(pub)
|
||||
|
||||
_, err := b.Request(context.Background(), Prompt{Kind: KindVNC, Subject: "test"})
|
||||
assert.ErrorIs(t, err, ErrNoSubscriber)
|
||||
assert.Equal(t, 0, pub.eventCount(), "no event must be emitted when fail-closed")
|
||||
|
||||
b.mu.Lock()
|
||||
pending := len(b.pending)
|
||||
b.mu.Unlock()
|
||||
assert.Equal(t, 0, pending, "no waiter must be registered on fail-closed")
|
||||
}
|
||||
|
||||
// TestRequestTimeoutDenies verifies that a request without a UI response
|
||||
// returns ErrTimeout (deny) rather than nil (silent accept). Uses a short
|
||||
// per-test broker timeout via Respond after the fact to keep the test fast.
|
||||
func TestRequestTimeoutDenies(t *testing.T) {
|
||||
// Replace DefaultTimeout for the lifetime of this test.
|
||||
orig := DefaultTimeout
|
||||
defaultTimeout(t, 60*time.Millisecond)
|
||||
defer defaultTimeout(t, orig)
|
||||
|
||||
pub := &fakePublisher{subscribers: true}
|
||||
b := New(pub)
|
||||
|
||||
start := time.Now()
|
||||
_, err := b.Request(context.Background(), Prompt{Kind: KindVNC, Subject: "test"})
|
||||
assert.ErrorIs(t, err, ErrTimeout, "missing user response must yield ErrTimeout, not nil")
|
||||
assert.GreaterOrEqual(t, time.Since(start), 50*time.Millisecond, "timeout fired prematurely")
|
||||
}
|
||||
|
||||
// TestRequestDenied returns ErrDenied when the UI responds with false.
|
||||
func TestRequestDenied(t *testing.T) {
|
||||
pub := &fakePublisher{subscribers: true}
|
||||
b := New(pub)
|
||||
|
||||
var requestID string
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- requestErr(b, context.Background(), Prompt{Kind: KindVNC, Subject: "test"})
|
||||
}()
|
||||
|
||||
requestID = waitForRequestID(t, pub)
|
||||
require.True(t, b.Respond(requestID, Decision{Accept: false}))
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
assert.ErrorIs(t, err, ErrDenied)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Request did not return after Respond(false)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRequestAccepted is the happy path. Failure here doesn't bypass the
|
||||
// gate but breaks the feature.
|
||||
func TestRequestAccepted(t *testing.T) {
|
||||
pub := &fakePublisher{subscribers: true}
|
||||
b := New(pub)
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- requestErr(b, context.Background(), Prompt{Kind: KindVNC, Subject: "test"})
|
||||
}()
|
||||
|
||||
id := waitForRequestID(t, pub)
|
||||
require.True(t, b.Respond(id, Decision{Accept: true}))
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
assert.NoError(t, err)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Request did not return after Respond(true)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRequestCtxCancelDenies verifies that an upstream cancel (e.g. the
|
||||
// engine shutting down mid-prompt) returns the cancel error rather than
|
||||
// nil. A nil here would be a silent bypass on shutdown races.
|
||||
func TestRequestCtxCancelDenies(t *testing.T) {
|
||||
pub := &fakePublisher{subscribers: true}
|
||||
b := New(pub)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- requestErr(b, ctx, Prompt{Kind: KindVNC, Subject: "test"})
|
||||
}()
|
||||
|
||||
// Wait until the prompt is in flight so cancel races a live waiter.
|
||||
_ = waitForRequestID(t, pub)
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
assert.ErrorIs(t, err, context.Canceled)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Request did not return after ctx cancel")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRespondUnknownIsNoop ensures a stray RespondApproval RPC cannot
|
||||
// affect or accidentally accept any in-flight request whose id it doesn't
|
||||
// match. Also confirms it doesn't panic.
|
||||
func TestRespondUnknownIsNoop(t *testing.T) {
|
||||
pub := &fakePublisher{subscribers: true}
|
||||
b := New(pub)
|
||||
|
||||
// No in-flight prompts: Respond returns false.
|
||||
assert.False(t, b.Respond("does-not-exist", Decision{Accept: true}))
|
||||
|
||||
// With an in-flight prompt, a wrong id still returns false and the
|
||||
// prompt remains armed (eventually timing out as a deny).
|
||||
defaultTimeout(t, 60*time.Millisecond)
|
||||
defer defaultTimeout(t, DefaultTimeout)
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- requestErr(b, context.Background(), Prompt{Kind: KindVNC})
|
||||
}()
|
||||
realID := waitForRequestID(t, pub)
|
||||
assert.False(t, b.Respond("totally-bogus", Decision{Accept: true}), "unknown id must not match")
|
||||
assert.NotEqual(t, "totally-bogus", realID)
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
assert.ErrorIs(t, err, ErrTimeout, "armed prompt must still time out, not accept")
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("prompt did not resolve")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRespondAfterTimeoutNoop confirms a late accept response can't
|
||||
// retroactively flip a denied (timed-out) request. The dropPending defer
|
||||
// in Request must have removed the entry by the time Respond races in.
|
||||
func TestRespondAfterTimeoutNoop(t *testing.T) {
|
||||
defaultTimeout(t, 30*time.Millisecond)
|
||||
defer defaultTimeout(t, DefaultTimeout)
|
||||
|
||||
pub := &fakePublisher{subscribers: true}
|
||||
b := New(pub)
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- requestErr(b, context.Background(), Prompt{Kind: KindVNC})
|
||||
}()
|
||||
id := waitForRequestID(t, pub)
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
require.ErrorIs(t, err, ErrTimeout)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("prompt did not time out")
|
||||
}
|
||||
|
||||
assert.False(t, b.Respond(id, Decision{Accept: true}), "late respond must be no-op")
|
||||
}
|
||||
|
||||
// TestRespondDoubleNoop ensures a duplicate ack from the UI doesn't leak
|
||||
// past the matched waiter or panic on a closed/full channel.
|
||||
func TestRespondDoubleNoop(t *testing.T) {
|
||||
pub := &fakePublisher{subscribers: true}
|
||||
b := New(pub)
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- requestErr(b, context.Background(), Prompt{Kind: KindVNC})
|
||||
}()
|
||||
id := waitForRequestID(t, pub)
|
||||
require.True(t, b.Respond(id, Decision{Accept: true}))
|
||||
assert.False(t, b.Respond(id, Decision{Accept: false}), "second response must be no-op")
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
assert.NoError(t, err)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("prompt did not resolve")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNilBrokerRequestErrors guards the engine pre-init path where the
|
||||
// broker may not yet exist (or its publisher is nil): Request must
|
||||
// error, never silently accept.
|
||||
func TestNilBrokerRequestErrors(t *testing.T) {
|
||||
var b *Broker
|
||||
_, err := b.Request(context.Background(), Prompt{Kind: KindVNC})
|
||||
assert.Error(t, err, "nil broker must error, never silently accept")
|
||||
|
||||
b2 := New(nil)
|
||||
_, err = b2.Request(context.Background(), Prompt{Kind: KindVNC})
|
||||
assert.Error(t, err, "broker with nil publisher must error, never silently accept")
|
||||
}
|
||||
|
||||
// TestPromptMetadataInjected confirms the broker stamps request_id, kind,
|
||||
// and expires_at on the emitted event. The UI relies on these keys; if
|
||||
// they are dropped, the user cannot route the prompt and the response
|
||||
// path breaks (which fails closed via timeout).
|
||||
func TestPromptMetadataInjected(t *testing.T) {
|
||||
pub := &fakePublisher{subscribers: true}
|
||||
b := New(pub)
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- requestErr(b, context.Background(), Prompt{
|
||||
Kind: KindVNC,
|
||||
Subject: "VNC connection from peerA",
|
||||
Metadata: map[string]string{"peer_name": "peerA"},
|
||||
})
|
||||
}()
|
||||
|
||||
id := waitForRequestID(t, pub)
|
||||
ev := pub.lastEvent(t)
|
||||
|
||||
assert.Equal(t, proto.SystemEvent_APPROVAL, ev.Category)
|
||||
assert.Equal(t, KindVNC, ev.Metadata[MetaKind])
|
||||
assert.Equal(t, id, ev.Metadata[MetaRequestID])
|
||||
assert.NotEmpty(t, ev.Metadata[MetaExpiresAt])
|
||||
assert.Equal(t, "peerA", ev.Metadata["peer_name"], "caller metadata must pass through")
|
||||
|
||||
require.True(t, b.Respond(id, Decision{Accept: true}))
|
||||
<-done
|
||||
}
|
||||
|
||||
// TestConcurrentRequests verifies that two concurrent prompts are tracked
|
||||
// independently. A bug that aliases ids would let one Respond unblock
|
||||
// the wrong waiter (a silent accept across prompts).
|
||||
func TestConcurrentRequests(t *testing.T) {
|
||||
pub := &fakePublisher{subscribers: true}
|
||||
b := New(pub)
|
||||
|
||||
const n = 20
|
||||
results := make(chan error, n)
|
||||
for i := 0; i < n; i++ {
|
||||
go func() {
|
||||
results <- requestErr(b, context.Background(), Prompt{Kind: KindVNC})
|
||||
}()
|
||||
}
|
||||
|
||||
ids := waitForNRequestIDs(t, pub, n)
|
||||
require.Len(t, ids, n)
|
||||
|
||||
// Deny exactly half, accept the rest. Track outcome per id so we can
|
||||
// match each Request's return value against the response we sent.
|
||||
denySet := make(map[string]bool, n)
|
||||
for i, id := range ids {
|
||||
deny := i%2 == 0
|
||||
denySet[id] = deny
|
||||
require.True(t, b.Respond(id, Decision{Accept: !deny}))
|
||||
}
|
||||
|
||||
// Collect all returns and check no nil errors slipped past a deny.
|
||||
var accepted, denied atomic.Int32
|
||||
for i := 0; i < n; i++ {
|
||||
select {
|
||||
case err := <-results:
|
||||
if err == nil {
|
||||
accepted.Add(1)
|
||||
} else {
|
||||
assert.ErrorIs(t, err, ErrDenied)
|
||||
denied.Add(1)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatalf("only got %d/%d responses", i, n)
|
||||
}
|
||||
}
|
||||
assert.Equal(t, int32(n/2), denied.Load())
|
||||
assert.Equal(t, int32(n/2), accepted.Load())
|
||||
}
|
||||
|
||||
// waitForRequestID blocks until the publisher sees its next event and
|
||||
// returns the request_id stamped on it.
|
||||
func waitForRequestID(t *testing.T, pub *fakePublisher) string {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
pub.mu.Lock()
|
||||
count := len(pub.events)
|
||||
var id string
|
||||
if count > 0 {
|
||||
id = pub.events[count-1].Metadata[MetaRequestID]
|
||||
}
|
||||
pub.mu.Unlock()
|
||||
if id != "" {
|
||||
return id
|
||||
}
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
}
|
||||
t.Fatal("timeout waiting for emitted event")
|
||||
return ""
|
||||
}
|
||||
|
||||
func waitForNRequestIDs(t *testing.T, pub *fakePublisher, n int) []string {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
pub.mu.Lock()
|
||||
count := len(pub.events)
|
||||
pub.mu.Unlock()
|
||||
if count >= n {
|
||||
break
|
||||
}
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
}
|
||||
pub.mu.Lock()
|
||||
defer pub.mu.Unlock()
|
||||
out := make([]string, 0, len(pub.events))
|
||||
seen := make(map[string]struct{}, len(pub.events))
|
||||
for _, ev := range pub.events {
|
||||
id := ev.Metadata[MetaRequestID]
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
if _, dup := seen[id]; dup {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
out = append(out, id)
|
||||
}
|
||||
if len(out) < n {
|
||||
t.Fatalf("only got %d/%d request ids", len(out), n)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// defaultTimeout swaps the broker's per-request wall-clock window so the
|
||||
// timeout tests run quickly. Restores the prior value on the next call.
|
||||
func defaultTimeout(t *testing.T, d time.Duration) {
|
||||
t.Helper()
|
||||
if d <= 0 {
|
||||
t.Fatal("defaultTimeout must be > 0")
|
||||
}
|
||||
timeoutValue = func() time.Duration { return d }
|
||||
}
|
||||
|
||||
// requestErr wraps Broker.Request to drop the Decision when tests only
|
||||
// care about the error path. Keeps the goroutine bodies tight.
|
||||
func requestErr(b *Broker, ctx context.Context, p Prompt) error {
|
||||
_, err := b.Request(ctx, p)
|
||||
return err
|
||||
}
|
||||
|
||||
// TestRequestViewOnly checks the view-only outcome flows through Request's
|
||||
// Decision return without being silently swallowed.
|
||||
func TestRequestViewOnly(t *testing.T) {
|
||||
pub := &fakePublisher{subscribers: true}
|
||||
b := New(pub)
|
||||
|
||||
type result struct {
|
||||
d Decision
|
||||
err error
|
||||
}
|
||||
done := make(chan result, 1)
|
||||
go func() {
|
||||
d, err := b.Request(context.Background(), Prompt{Kind: KindVNC})
|
||||
done <- result{d, err}
|
||||
}()
|
||||
|
||||
id := waitForRequestID(t, pub)
|
||||
require.True(t, b.Respond(id, Decision{Accept: true, ViewOnly: true}))
|
||||
|
||||
select {
|
||||
case r := <-done:
|
||||
assert.NoError(t, r.err)
|
||||
assert.True(t, r.d.Accept)
|
||||
assert.True(t, r.d.ViewOnly, "ViewOnly must survive the round-trip")
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("view-only request did not resolve")
|
||||
}
|
||||
}
|
||||
@@ -1,62 +0,0 @@
|
||||
package approval
|
||||
|
||||
import "testing"
|
||||
|
||||
// TestShortKeyFingerprint locks in the format the VNC approval prompt
|
||||
// shows to the user. The fingerprint is the user's only cryptographic
|
||||
// anchor against a malicious management server that pushes a spoofed
|
||||
// display name, so accidental changes to its format would silently
|
||||
// undermine that defence.
|
||||
func TestShortKeyFingerprint(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "full_32_byte_pubkey",
|
||||
in: "0123456789abcdeffedcba9876543210ffeeddccbbaa99887766554433221100",
|
||||
want: "0123-4567-89ab-cdef",
|
||||
},
|
||||
{
|
||||
name: "exactly_16_chars",
|
||||
in: "0123456789abcdef",
|
||||
want: "0123-4567-89ab-cdef",
|
||||
},
|
||||
{
|
||||
name: "borderline_8_chars",
|
||||
in: "01234567",
|
||||
want: "0123-4567",
|
||||
},
|
||||
{
|
||||
name: "too_short_returns_empty",
|
||||
in: "0123",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "empty_returns_empty",
|
||||
in: "",
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := ShortKeyFingerprint(tc.in)
|
||||
if got != tc.want {
|
||||
t.Fatalf("ShortKeyFingerprint(%q) = %q, want %q", tc.in, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestShortKeyFingerprint_DistinctKeysDistinctOutputs guards against a
|
||||
// formatting bug that would collapse different prefixes onto the same
|
||||
// displayed fingerprint and let an attacker substitute their pubkey for
|
||||
// a victim's while keeping the prompt visually identical.
|
||||
func TestShortKeyFingerprint_DistinctKeysDistinctOutputs(t *testing.T) {
|
||||
a := ShortKeyFingerprint("0123456789abcdef" + "rest_of_pubkey_ignored")
|
||||
b := ShortKeyFingerprint("0123456789abcde0" + "rest_of_pubkey_ignored")
|
||||
if a == b {
|
||||
t.Fatalf("expected distinct outputs for distinct prefixes, both = %q", a)
|
||||
}
|
||||
}
|
||||
@@ -315,7 +315,6 @@ func (a *Auth) setSystemInfoFlags(info *system.Info) {
|
||||
a.config.RosenpassEnabled,
|
||||
a.config.RosenpassPermissive,
|
||||
a.config.ServerSSHAllowed,
|
||||
a.config.ServerVNCAllowed,
|
||||
a.config.DisableClientRoutes,
|
||||
a.config.DisableServerRoutes,
|
||||
a.config.DisableDNS,
|
||||
|
||||
@@ -562,8 +562,6 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
|
||||
RosenpassEnabled: config.RosenpassEnabled,
|
||||
RosenpassPermissive: config.RosenpassPermissive,
|
||||
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
|
||||
ServerVNCAllowed: config.ServerVNCAllowed != nil && *config.ServerVNCAllowed,
|
||||
DisableVNCApproval: config.DisableVNCApproval,
|
||||
EnableSSHRoot: config.EnableSSHRoot,
|
||||
EnableSSHSFTP: config.EnableSSHSFTP,
|
||||
EnableSSHLocalPortForwarding: config.EnableSSHLocalPortForwarding,
|
||||
@@ -646,7 +644,6 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
|
||||
config.RosenpassEnabled,
|
||||
config.RosenpassPermissive,
|
||||
config.ServerSSHAllowed,
|
||||
config.ServerVNCAllowed,
|
||||
config.DisableClientRoutes,
|
||||
config.DisableServerRoutes,
|
||||
config.DisableDNS,
|
||||
|
||||
@@ -636,12 +636,6 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
|
||||
if g.internalConfig.SSHJWTCacheTTL != nil {
|
||||
configContent.WriteString(fmt.Sprintf("SSHJWTCacheTTL: %d\n", *g.internalConfig.SSHJWTCacheTTL))
|
||||
}
|
||||
if g.internalConfig.ServerVNCAllowed != nil {
|
||||
configContent.WriteString(fmt.Sprintf("ServerVNCAllowed: %v\n", *g.internalConfig.ServerVNCAllowed))
|
||||
}
|
||||
if g.internalConfig.DisableVNCApproval != nil {
|
||||
configContent.WriteString(fmt.Sprintf("DisableVNCApproval: %v\n", *g.internalConfig.DisableVNCApproval))
|
||||
}
|
||||
|
||||
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
|
||||
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))
|
||||
|
||||
@@ -862,8 +862,6 @@ func TestAddConfig_AllFieldsCovered(t *testing.T) {
|
||||
RosenpassEnabled: true,
|
||||
RosenpassPermissive: true,
|
||||
ServerSSHAllowed: &bTrue,
|
||||
ServerVNCAllowed: &bTrue,
|
||||
DisableVNCApproval: &bTrue,
|
||||
EnableSSHRoot: &bTrue,
|
||||
EnableSSHSFTP: &bTrue,
|
||||
EnableSSHLocalPortForwarding: &bTrue,
|
||||
|
||||
63
client/internal/dns/dnsfw/config.go
Normal file
63
client/internal/dns/dnsfw/config.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package dnsfw
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// EnvDisable disables the DNS firewall entirely when set to a truthy value.
|
||||
EnvDisable = "NB_DISABLE_DNS_FIREWALL"
|
||||
// EnvPorts overrides the comma-separated list of remote ports to block.
|
||||
// Empty disables the firewall.
|
||||
EnvPorts = "NB_DNS_FIREWALL_PORTS"
|
||||
// EnvStrict enables strict mode: permit DNS only to the virtual DNS IP
|
||||
// and the netbird daemon. Default mode also permits anything on the
|
||||
// netbird tunnel interface, which is safer if NRPT is silently ignored
|
||||
// by Windows but lets apps reach custom DNS servers via the tunnel.
|
||||
EnvStrict = "NB_DNS_FIREWALL_STRICT"
|
||||
)
|
||||
|
||||
// defaultBlockedPorts are the well-known DNS ports we block for non-netbird
|
||||
// processes: 53 (plain DNS) and 853 (DNS-over-TLS).
|
||||
var defaultBlockedPorts = []uint16{53, 853}
|
||||
|
||||
// blockedPorts returns the effective port list, honoring env overrides.
|
||||
// A nil return means the firewall should not be installed.
|
||||
func blockedPorts() []uint16 {
|
||||
if disabled, _ := strconv.ParseBool(os.Getenv(EnvDisable)); disabled {
|
||||
log.Infof("dns firewall disabled via %s", EnvDisable)
|
||||
return nil
|
||||
}
|
||||
|
||||
override, ok := os.LookupEnv(EnvPorts)
|
||||
if !ok {
|
||||
return defaultBlockedPorts
|
||||
}
|
||||
|
||||
var ports []uint16
|
||||
for _, raw := range strings.Split(override, ",") {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
continue
|
||||
}
|
||||
port, err := strconv.ParseUint(raw, 10, 16)
|
||||
if err != nil {
|
||||
log.Warnf("dns firewall: ignoring invalid port %q in %s: %v", raw, EnvPorts, err)
|
||||
continue
|
||||
}
|
||||
if port == 0 {
|
||||
log.Warnf("dns firewall: ignoring port 0 in %s", EnvPorts)
|
||||
continue
|
||||
}
|
||||
ports = append(ports, uint16(port))
|
||||
}
|
||||
if len(ports) == 0 {
|
||||
log.Infof("dns firewall disabled: %s yielded no valid ports", EnvPorts)
|
||||
return nil
|
||||
}
|
||||
return ports
|
||||
}
|
||||
39
client/internal/dns/dnsfw/config_test.go
Normal file
39
client/internal/dns/dnsfw/config_test.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package dnsfw
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBlockedPorts(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
disable string
|
||||
ports string
|
||||
setPorts bool
|
||||
want []uint16
|
||||
}{
|
||||
{name: "default", want: defaultBlockedPorts},
|
||||
{name: "disabled", disable: "true", want: nil},
|
||||
{name: "disabled false keeps default", disable: "false", want: defaultBlockedPorts},
|
||||
{name: "override single port", ports: "53", setPorts: true, want: []uint16{53}},
|
||||
{name: "override multi", ports: "53, 853 ,5353", setPorts: true, want: []uint16{53, 853, 5353}},
|
||||
{name: "override empty disables", ports: "", setPorts: true, want: nil},
|
||||
{name: "override invalid skipped", ports: "53,not-a-port,853", setPorts: true, want: []uint16{53, 853}},
|
||||
{name: "override zero skipped", ports: "53,0,853", setPorts: true, want: []uint16{53, 853}},
|
||||
{name: "override only invalid disables", ports: "abc", setPorts: true, want: nil},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Setenv(EnvDisable, tc.disable)
|
||||
if tc.setPorts {
|
||||
t.Setenv(EnvPorts, tc.ports)
|
||||
}
|
||||
got := blockedPorts()
|
||||
if !reflect.DeepEqual(got, tc.want) {
|
||||
t.Fatalf("blockedPorts() = %v, want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
16
client/internal/dns/dnsfw/dnsfw.go
Normal file
16
client/internal/dns/dnsfw/dnsfw.go
Normal file
@@ -0,0 +1,16 @@
|
||||
// Package dnsfw blocks DNS traffic from non-netbird processes when netbird is
|
||||
// managing the host's DNS, so that resolvers running on apps or libraries
|
||||
// outside netbird cannot bypass the configured DNS path.
|
||||
//
|
||||
// Implementation is Windows-only (uses WFP). On other platforms New returns
|
||||
// a no-op manager.
|
||||
package dnsfw
|
||||
|
||||
import "net/netip"
|
||||
|
||||
// Manager controls the per-tunnel DNS firewall. Both methods must be safe
|
||||
// to call multiple times.
|
||||
type Manager interface {
|
||||
Enable(ifaceGUID string, virtualDNSIP netip.Addr) error
|
||||
Disable() error
|
||||
}
|
||||
15
client/internal/dns/dnsfw/dnsfw_other.go
Normal file
15
client/internal/dns/dnsfw/dnsfw_other.go
Normal file
@@ -0,0 +1,15 @@
|
||||
//go:build !windows
|
||||
|
||||
package dnsfw
|
||||
|
||||
import "net/netip"
|
||||
|
||||
type noopManager struct{}
|
||||
|
||||
func (noopManager) Enable(string, netip.Addr) error { return nil }
|
||||
func (noopManager) Disable() error { return nil }
|
||||
|
||||
// New returns a no-op manager on non-Windows platforms.
|
||||
func New() Manager {
|
||||
return noopManager{}
|
||||
}
|
||||
144
client/internal/dns/dnsfw/dnsfw_windows.go
Normal file
144
client/internal/dns/dnsfw/dnsfw_windows.go
Normal file
@@ -0,0 +1,144 @@
|
||||
//go:build windows
|
||||
|
||||
package dnsfw
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var (
|
||||
modIphlpapi = windows.NewLazyDLL("iphlpapi.dll")
|
||||
procConvertInterfaceGuidToLuid = modIphlpapi.NewProc("ConvertInterfaceGuidToLuid")
|
||||
)
|
||||
|
||||
type windowsManager struct {
|
||||
mu sync.Mutex
|
||||
// session is the WFP engine handle. Zero when disabled.
|
||||
session uintptr
|
||||
}
|
||||
|
||||
// Enable installs the dns firewall. Strict mode propagates failures;
|
||||
// non-strict mode logs and returns nil so partial protection is preserved.
|
||||
func (m *windowsManager) Enable(ifaceGUID string, virtualDNSIP netip.Addr) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
ports := blockedPorts()
|
||||
if len(ports) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if m.session != 0 {
|
||||
if err := m.disableLocked(); err != nil {
|
||||
return fmt.Errorf("reset existing dns firewall session: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
strict := strictMode()
|
||||
|
||||
luid, err := luidFromGUID(ifaceGUID)
|
||||
if err != nil {
|
||||
return m.failOrLog(strict, fmt.Errorf("resolve tun luid from guid %s: %w", ifaceGUID, err))
|
||||
}
|
||||
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return m.failOrLog(strict, fmt.Errorf("resolve daemon executable path: %w", err))
|
||||
}
|
||||
|
||||
cfg := installConfig{
|
||||
tunLUID: luid,
|
||||
daemonExe: exe,
|
||||
blockedPorts: ports,
|
||||
strict: strict,
|
||||
virtualDNSIP: virtualDNSIP,
|
||||
}
|
||||
// session==0 signals a hard failure; non-zero with non-nil err is a partial install.
|
||||
session, installErr := installFilters(cfg)
|
||||
if session == 0 {
|
||||
return m.failOrLog(strict, fmt.Errorf("install dns firewall filters: %w", installErr))
|
||||
}
|
||||
|
||||
if installErr != nil && strict {
|
||||
_ = closeSession(session)
|
||||
return fmt.Errorf("strict dns firewall: partial install: %w", installErr)
|
||||
}
|
||||
|
||||
m.session = session
|
||||
log.Infof("dns firewall installed: iface=%s daemon=%s ports=%v strict=%v virtual_dns=%s",
|
||||
ifaceGUID, exe, ports, strict, virtualDNSIP)
|
||||
if installErr != nil {
|
||||
log.Warnf("dns firewall partially installed (some filters failed): %v", installErr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *windowsManager) Disable() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return m.disableLocked()
|
||||
}
|
||||
|
||||
func (m *windowsManager) disableLocked() error {
|
||||
if m.session == 0 {
|
||||
return nil
|
||||
}
|
||||
session := m.session
|
||||
m.session = 0
|
||||
if err := closeSession(session); err != nil {
|
||||
return fmt.Errorf("close wfp session: %w", err)
|
||||
}
|
||||
log.Info("dns firewall removed")
|
||||
return nil
|
||||
}
|
||||
|
||||
// failOrLog returns err unchanged in strict mode. In non-strict mode the
|
||||
// error is logged and nil is returned.
|
||||
func (m *windowsManager) failOrLog(strict bool, err error) error {
|
||||
if strict {
|
||||
return err
|
||||
}
|
||||
log.Errorf("dns firewall: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
// New returns a Windows DNS firewall manager backed by WFP.
|
||||
func New() Manager {
|
||||
return &windowsManager{}
|
||||
}
|
||||
|
||||
// strictMode reports whether strict mode is enabled via env.
|
||||
func strictMode() bool {
|
||||
v, _ := strconv.ParseBool(os.Getenv(EnvStrict))
|
||||
return v
|
||||
}
|
||||
|
||||
// luidFromGUID converts a Windows interface GUID string to its LUID.
|
||||
func luidFromGUID(ifaceGUID string) (luid uint64, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("panic in luidFromGUID: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
guid, err := windows.GUIDFromString(ifaceGUID)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parse guid: %w", err)
|
||||
}
|
||||
rc, _, _ := procConvertInterfaceGuidToLuid.Call(
|
||||
uintptr(unsafe.Pointer(&guid)),
|
||||
uintptr(unsafe.Pointer(&luid)),
|
||||
)
|
||||
if rc != 0 {
|
||||
return 0, fmt.Errorf("ConvertInterfaceGuidToLuid returned %d", rc)
|
||||
}
|
||||
return luid, nil
|
||||
}
|
||||
72
client/internal/dns/dnsfw/dnsfw_windows_test.go
Normal file
72
client/internal/dns/dnsfw/dnsfw_windows_test.go
Normal file
@@ -0,0 +1,72 @@
|
||||
//go:build windows
|
||||
|
||||
package dnsfw
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStrictMode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
val string
|
||||
set bool
|
||||
want bool
|
||||
}{
|
||||
{name: "unset", want: false},
|
||||
{name: "true", val: "true", set: true, want: true},
|
||||
{name: "1", val: "1", set: true, want: true},
|
||||
{name: "false", val: "false", set: true, want: false},
|
||||
{name: "invalid is false", val: "garbage", set: true, want: false},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Setenv(EnvStrict, tc.val)
|
||||
if !tc.set {
|
||||
os.Unsetenv(EnvStrict)
|
||||
}
|
||||
if got := strictMode(); got != tc.want {
|
||||
t.Fatalf("strictMode() = %v, want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWindowsManagerDisableIdempotent(t *testing.T) {
|
||||
m := &windowsManager{}
|
||||
if err := m.Disable(); err != nil {
|
||||
t.Fatalf("first Disable on fresh manager: %v", err)
|
||||
}
|
||||
if err := m.Disable(); err != nil {
|
||||
t.Fatalf("second Disable on fresh manager: %v", err)
|
||||
}
|
||||
if m.session != 0 {
|
||||
t.Fatalf("session should remain zero, got %d", m.session)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWindowsManagerEnableNoOpWhenDisabledByEnv(t *testing.T) {
|
||||
t.Setenv(EnvDisable, "true")
|
||||
|
||||
m := &windowsManager{}
|
||||
if err := m.Enable("00000000-0000-0000-0000-000000000000", netip.Addr{}); err != nil {
|
||||
t.Fatalf("Enable should be a no-op when firewall disabled by env: %v", err)
|
||||
}
|
||||
if m.session != 0 {
|
||||
t.Fatalf("session must remain zero when env disables firewall, got %d", m.session)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWindowsManagerEnableNoOpWhenPortsEmpty(t *testing.T) {
|
||||
t.Setenv(EnvPorts, "")
|
||||
|
||||
m := &windowsManager{}
|
||||
if err := m.Enable("00000000-0000-0000-0000-000000000000", netip.Addr{}); err != nil {
|
||||
t.Fatalf("Enable should be a no-op when ports list is empty: %v", err)
|
||||
}
|
||||
if m.session != 0 {
|
||||
t.Fatalf("session must remain zero when ports list is empty, got %d", m.session)
|
||||
}
|
||||
}
|
||||
53
client/internal/dns/dnsfw/helpers_windows.go
Normal file
53
client/internal/dns/dnsfw/helpers_windows.go
Normal file
@@ -0,0 +1,53 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*
|
||||
* Adapted from wireguard-windows tunnel/firewall/helpers.go.
|
||||
*/
|
||||
|
||||
package dnsfw
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func createWtFwpmDisplayData0(name, description string) (*wtFwpmDisplayData0, error) {
|
||||
namePtr, err := windows.UTF16PtrFromString(name)
|
||||
if err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
|
||||
descriptionPtr, err := windows.UTF16PtrFromString(description)
|
||||
if err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
|
||||
return &wtFwpmDisplayData0{
|
||||
name: namePtr,
|
||||
description: descriptionPtr,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func filterWeight(weight uint8) wtFwpValue0 {
|
||||
return wtFwpValue0{
|
||||
_type: cFWP_UINT8,
|
||||
value: uintptr(weight),
|
||||
}
|
||||
}
|
||||
|
||||
func wrapErr(err error) error {
|
||||
var errno syscall.Errno
|
||||
if !errors.As(err, &errno) {
|
||||
return err
|
||||
}
|
||||
_, file, line, ok := runtime.Caller(1)
|
||||
if !ok {
|
||||
return fmt.Errorf("wfp error at unknown location: %w", err)
|
||||
}
|
||||
return fmt.Errorf("wfp error at %s:%d: %w", file, line, err)
|
||||
}
|
||||
249
client/internal/dns/dnsfw/rules_windows.go
Normal file
249
client/internal/dns/dnsfw/rules_windows.go
Normal file
@@ -0,0 +1,249 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2026 NetBird GmbH. All Rights Reserved.
|
||||
*
|
||||
* Filter installers adapted from wireguard-windows tunnel/firewall/rules.go.
|
||||
* The block-DNS approach (port 53 + UDP/TCP) matches what wireguard-windows
|
||||
* uses for its kill-switch DNS leak protection. We extend it with a
|
||||
* configurable port set so we also cover :853 (DoT) and any future ports.
|
||||
*/
|
||||
|
||||
package dnsfw
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"unsafe"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"golang.org/x/sys/windows"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
)
|
||||
|
||||
// Filters install at outbound ALE_AUTH_CONNECT layers only; inbound replies
|
||||
// follow the authorized outbound flow.
|
||||
|
||||
// permitTunInterface installs a permit filter for any traffic whose local
|
||||
// interface is the netbird tunnel.
|
||||
func permitTunInterface(session uintptr, base *baseObjects, weight uint8, ifLUID uint64) error {
|
||||
cond := wtFwpmFilterCondition0{
|
||||
fieldKey: cFWPM_CONDITION_IP_LOCAL_INTERFACE,
|
||||
matchType: cFWP_MATCH_EQUAL,
|
||||
conditionValue: wtFwpConditionValue0{
|
||||
_type: cFWP_UINT64,
|
||||
value: uintptr(unsafe.Pointer(&ifLUID)),
|
||||
},
|
||||
}
|
||||
|
||||
filter := wtFwpmFilter0{
|
||||
providerKey: &base.provider,
|
||||
subLayerKey: base.filters,
|
||||
weight: filterWeight(weight),
|
||||
numFilterConditions: 1,
|
||||
filterCondition: (*wtFwpmFilterCondition0)(unsafe.Pointer(&cond)),
|
||||
action: wtFwpmAction0{_type: cFWP_ACTION_PERMIT},
|
||||
}
|
||||
|
||||
return addOutboundFilters(session, &filter, "Permit netbird tunnel")
|
||||
}
|
||||
|
||||
// permitDaemonByAppID installs a permit filter matching the netbird daemon
|
||||
// executable by App-ID. App-ID alone is sufficient because netbird.exe is a
|
||||
// dedicated binary.
|
||||
func permitDaemonByAppID(session uintptr, base *baseObjects, daemonExe string, weight uint8) error {
|
||||
appID, err := daemonAppID(daemonExe)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer fwpmFreeMemory0(unsafe.Pointer(&appID))
|
||||
|
||||
cond := wtFwpmFilterCondition0{
|
||||
fieldKey: cFWPM_CONDITION_ALE_APP_ID,
|
||||
matchType: cFWP_MATCH_EQUAL,
|
||||
conditionValue: wtFwpConditionValue0{
|
||||
_type: cFWP_BYTE_BLOB_TYPE,
|
||||
value: uintptr(unsafe.Pointer(appID)),
|
||||
},
|
||||
}
|
||||
|
||||
filter := wtFwpmFilter0{
|
||||
providerKey: &base.provider,
|
||||
subLayerKey: base.filters,
|
||||
weight: filterWeight(weight),
|
||||
numFilterConditions: 1,
|
||||
filterCondition: (*wtFwpmFilterCondition0)(unsafe.Pointer(&cond)),
|
||||
action: wtFwpmAction0{_type: cFWP_ACTION_PERMIT},
|
||||
}
|
||||
|
||||
return addOutboundFilters(session, &filter, "Permit netbird daemon")
|
||||
}
|
||||
|
||||
// permitVirtualDNSIP installs a permit filter for DNS-port traffic destined
|
||||
// for the in-tunnel virtual DNS IP. Used in strict mode in lieu of
|
||||
// permitTunInterface.
|
||||
func permitVirtualDNSIP(session uintptr, base *baseObjects, ip netip.Addr, ports []uint16, weight uint8) error {
|
||||
var merr *multierror.Error
|
||||
for _, port := range ports {
|
||||
if err := permitDNSToHost(session, base, ip, port, weight); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("permit %s:%d: %w", ip, port, err))
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func permitDNSToHost(session uintptr, base *baseObjects, ip netip.Addr, port uint16, weight uint8) error {
|
||||
if !ip.IsValid() {
|
||||
return fmt.Errorf("invalid address")
|
||||
}
|
||||
|
||||
var addrCond wtFwpmFilterCondition0
|
||||
var layer windows.GUID
|
||||
// v6 backing must outlive fwpmFilterAdd0; keep it on this stack frame.
|
||||
var v6 wtFwpByteArray16
|
||||
|
||||
if ip.Is4() {
|
||||
v4 := ip.As4()
|
||||
addrCond = wtFwpmFilterCondition0{
|
||||
fieldKey: cFWPM_CONDITION_IP_REMOTE_ADDRESS,
|
||||
matchType: cFWP_MATCH_EQUAL,
|
||||
conditionValue: wtFwpConditionValue0{
|
||||
_type: cFWP_UINT32,
|
||||
value: uintptr(binary.BigEndian.Uint32(v4[:])),
|
||||
},
|
||||
}
|
||||
layer = cFWPM_LAYER_ALE_AUTH_CONNECT_V4
|
||||
} else {
|
||||
v6 = wtFwpByteArray16{byteArray16: ip.As16()}
|
||||
addrCond = wtFwpmFilterCondition0{
|
||||
fieldKey: cFWPM_CONDITION_IP_REMOTE_ADDRESS,
|
||||
matchType: cFWP_MATCH_EQUAL,
|
||||
conditionValue: wtFwpConditionValue0{
|
||||
_type: cFWP_BYTE_ARRAY16_TYPE,
|
||||
value: uintptr(unsafe.Pointer(&v6)),
|
||||
},
|
||||
}
|
||||
layer = cFWPM_LAYER_ALE_AUTH_CONNECT_V6
|
||||
}
|
||||
|
||||
conditions := [2]wtFwpmFilterCondition0{
|
||||
addrCond,
|
||||
{
|
||||
fieldKey: cFWPM_CONDITION_IP_REMOTE_PORT,
|
||||
matchType: cFWP_MATCH_EQUAL,
|
||||
conditionValue: wtFwpConditionValue0{
|
||||
_type: cFWP_UINT16,
|
||||
value: uintptr(port),
|
||||
},
|
||||
},
|
||||
}
|
||||
filter := wtFwpmFilter0{
|
||||
providerKey: &base.provider,
|
||||
subLayerKey: base.filters,
|
||||
weight: filterWeight(weight),
|
||||
numFilterConditions: uint32(len(conditions)),
|
||||
filterCondition: (*wtFwpmFilterCondition0)(unsafe.Pointer(&conditions[0])),
|
||||
action: wtFwpmAction0{_type: cFWP_ACTION_PERMIT},
|
||||
}
|
||||
|
||||
display, err := createWtFwpmDisplayData0(fmt.Sprintf("Permit DNS to %s:%d", ip, port), "")
|
||||
if err != nil {
|
||||
return wrapErr(err)
|
||||
}
|
||||
filter.displayData = *display
|
||||
filter.layerKey = layer
|
||||
|
||||
var filterID uint64
|
||||
if err := fwpmFilterAdd0(session, &filter, 0, &filterID); err != nil {
|
||||
return wrapErr(err)
|
||||
}
|
||||
_ = v6
|
||||
return nil
|
||||
}
|
||||
|
||||
// blockDNSPorts installs a deny filter for outbound traffic to each of the
|
||||
// given remote ports over UDP or TCP. Per-port and per-layer failures are
|
||||
// accumulated; partial coverage is preferred over zero coverage.
|
||||
func blockDNSPorts(session uintptr, base *baseObjects, ports []uint16, weight uint8) error {
|
||||
var merr *multierror.Error
|
||||
for _, port := range ports {
|
||||
if err := blockDNSPort(session, base, port, weight); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("block port %d: %w", port, err))
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func blockDNSPort(session uintptr, base *baseObjects, port uint16, weight uint8) error {
|
||||
conditions := [3]wtFwpmFilterCondition0{
|
||||
{
|
||||
fieldKey: cFWPM_CONDITION_IP_REMOTE_PORT,
|
||||
matchType: cFWP_MATCH_EQUAL,
|
||||
conditionValue: wtFwpConditionValue0{
|
||||
_type: cFWP_UINT16,
|
||||
value: uintptr(port),
|
||||
},
|
||||
},
|
||||
{
|
||||
fieldKey: cFWPM_CONDITION_IP_PROTOCOL,
|
||||
matchType: cFWP_MATCH_EQUAL,
|
||||
conditionValue: wtFwpConditionValue0{
|
||||
_type: cFWP_UINT8,
|
||||
value: uintptr(cIPPROTO_UDP),
|
||||
},
|
||||
},
|
||||
// Repeat the IP_PROTOCOL condition for logical OR with TCP.
|
||||
{
|
||||
fieldKey: cFWPM_CONDITION_IP_PROTOCOL,
|
||||
matchType: cFWP_MATCH_EQUAL,
|
||||
conditionValue: wtFwpConditionValue0{
|
||||
_type: cFWP_UINT8,
|
||||
value: uintptr(cIPPROTO_TCP),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
filter := wtFwpmFilter0{
|
||||
providerKey: &base.provider,
|
||||
subLayerKey: base.filters,
|
||||
weight: filterWeight(weight),
|
||||
numFilterConditions: uint32(len(conditions)),
|
||||
filterCondition: (*wtFwpmFilterCondition0)(unsafe.Pointer(&conditions[0])),
|
||||
action: wtFwpmAction0{_type: cFWP_ACTION_BLOCK},
|
||||
}
|
||||
|
||||
return addOutboundFilters(session, &filter, fmt.Sprintf("Block DNS port %d", port))
|
||||
}
|
||||
|
||||
// addOutboundFilters installs the same filter on the v4 and v6 outbound ALE
|
||||
// connect layers. v4 and v6 are installed independently: failure on one
|
||||
// layer does not abort the other, and the accumulated errors are returned.
|
||||
// Partial coverage is preferred over zero coverage.
|
||||
func addOutboundFilters(session uintptr, filter *wtFwpmFilter0, name string) error {
|
||||
layers := [...]struct {
|
||||
layer windows.GUID
|
||||
label string
|
||||
}{
|
||||
{cFWPM_LAYER_ALE_AUTH_CONNECT_V4, name + " (IPv4)"},
|
||||
{cFWPM_LAYER_ALE_AUTH_CONNECT_V6, name + " (IPv6)"},
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
for _, l := range layers {
|
||||
display, err := createWtFwpmDisplayData0(l.label, "")
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("%s: %w", l.label, wrapErr(err)))
|
||||
continue
|
||||
}
|
||||
filter.displayData = *display
|
||||
filter.layerKey = l.layer
|
||||
|
||||
var filterID uint64
|
||||
if err := fwpmFilterAdd0(session, filter, 0, &filterID); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("%s: %w", l.label, wrapErr(err)))
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
177
client/internal/dns/dnsfw/session_windows.go
Normal file
177
client/internal/dns/dnsfw/session_windows.go
Normal file
@@ -0,0 +1,177 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2026 NetBird GmbH. All Rights Reserved.
|
||||
*
|
||||
* Session lifecycle and the high-level Install/Close entry points adapted
|
||||
* from wireguard-windows tunnel/firewall.
|
||||
*/
|
||||
|
||||
package dnsfw
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"unsafe"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"golang.org/x/sys/windows"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
)
|
||||
|
||||
// installConfig is the input to installFilters.
|
||||
type installConfig struct {
|
||||
tunLUID uint64
|
||||
daemonExe string
|
||||
blockedPorts []uint16
|
||||
// strict, when true, narrows the carve-out from "anything on tun" to
|
||||
// "DNS only to virtualDNSIP". virtualDNSIP must be valid in this case.
|
||||
strict bool
|
||||
virtualDNSIP netip.Addr
|
||||
}
|
||||
|
||||
// baseObjects holds the GUIDs of the WFP provider and sublayer registered
|
||||
// for our session. Both are randomly generated per session.
|
||||
type baseObjects struct {
|
||||
provider windows.GUID
|
||||
filters windows.GUID
|
||||
}
|
||||
|
||||
// installFilters opens a dynamic WFP session and installs the netbird DNS
|
||||
// firewall filters. Returns a zero session on hard failure (session create,
|
||||
// base objects); a non-zero session with a non-nil error is a partial install
|
||||
// (some per-filter installs failed) and is safe to close.
|
||||
func installFilters(cfg installConfig) (session uintptr, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// Dynamic session: kernel will clean up on process exit even
|
||||
// if we leave the handle dangling here.
|
||||
err = fmt.Errorf("panic in installFilters: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
if len(cfg.blockedPorts) == 0 {
|
||||
return 0, errors.New("dns firewall: no blocked ports configured")
|
||||
}
|
||||
if cfg.strict && !cfg.virtualDNSIP.IsValid() {
|
||||
return 0, errors.New("dns firewall: strict mode requires a valid virtual DNS IP")
|
||||
}
|
||||
|
||||
session, err = createSession()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
base, err := registerBaseObjects(session)
|
||||
if err != nil {
|
||||
_ = fwpmEngineClose0(session)
|
||||
return 0, fmt.Errorf("register base objects: %w", err)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
if cfg.strict {
|
||||
if err := permitVirtualDNSIP(session, base, cfg.virtualDNSIP, cfg.blockedPorts, 15); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("permit virtual dns: %w", err))
|
||||
}
|
||||
} else {
|
||||
if err := permitTunInterface(session, base, 15, cfg.tunLUID); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("permit tun interface: %w", err))
|
||||
}
|
||||
}
|
||||
if err := permitDaemonByAppID(session, base, cfg.daemonExe, 14); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("permit netbird daemon: %w", err))
|
||||
}
|
||||
if err := blockDNSPorts(session, base, cfg.blockedPorts, 10); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("block dns ports: %w", err))
|
||||
}
|
||||
|
||||
return session, nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// closeSession tears down a WFP session previously opened by installFilters.
|
||||
// All filters owned by the session are removed.
|
||||
func closeSession(session uintptr) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("panic in closeSession: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
if session == 0 {
|
||||
return nil
|
||||
}
|
||||
if err := fwpmEngineClose0(session); err != nil {
|
||||
return wrapErr(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func createSession() (uintptr, error) {
|
||||
displayData, err := createWtFwpmDisplayData0("NetBird DNS firewall", "NetBird DNS firewall dynamic session")
|
||||
if err != nil {
|
||||
return 0, wrapErr(err)
|
||||
}
|
||||
session := wtFwpmSession0{
|
||||
displayData: *displayData,
|
||||
flags: cFWPM_SESSION_FLAG_DYNAMIC,
|
||||
txnWaitTimeoutInMSec: windows.INFINITE,
|
||||
}
|
||||
var handle uintptr
|
||||
if err := fwpmEngineOpen0(nil, cRPC_C_AUTHN_WINNT, nil, &session, unsafe.Pointer(&handle)); err != nil {
|
||||
return 0, wrapErr(err)
|
||||
}
|
||||
return handle, nil
|
||||
}
|
||||
|
||||
func registerBaseObjects(session uintptr) (*baseObjects, error) {
|
||||
bo := &baseObjects{}
|
||||
var err error
|
||||
if bo.provider, err = windows.GenerateGUID(); err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
if bo.filters, err = windows.GenerateGUID(); err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
|
||||
displayData, err := createWtFwpmDisplayData0("NetBird DNS firewall", "NetBird DNS firewall provider")
|
||||
if err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
provider := wtFwpmProvider0{
|
||||
providerKey: bo.provider,
|
||||
displayData: *displayData,
|
||||
}
|
||||
if err := fwpmProviderAdd0(session, &provider, 0); err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
|
||||
subDisplay, err := createWtFwpmDisplayData0("NetBird DNS firewall filters", "Permit and block filters")
|
||||
if err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
sublayer := wtFwpmSublayer0{
|
||||
subLayerKey: bo.filters,
|
||||
displayData: *subDisplay,
|
||||
providerKey: &bo.provider,
|
||||
weight: ^uint16(0),
|
||||
}
|
||||
if err := fwpmSubLayerAdd0(session, &sublayer, 0); err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
return bo, nil
|
||||
}
|
||||
|
||||
// daemonAppID returns the WFP App-ID byte blob for the given executable path.
|
||||
func daemonAppID(path string) (*wtFwpByteBlob, error) {
|
||||
pathPtr, err := windows.UTF16PtrFromString(path)
|
||||
if err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
var appID *wtFwpByteBlob
|
||||
if err := fwpmGetAppIdFromFileName0(pathPtr, unsafe.Pointer(&appID)); err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
return appID, nil
|
||||
}
|
||||
38
client/internal/dns/dnsfw/syscall_windows.go
Normal file
38
client/internal/dns/dnsfw/syscall_windows.go
Normal file
@@ -0,0 +1,38 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*
|
||||
* Adapted from wireguard-windows tunnel/firewall/syscall_windows.go.
|
||||
*/
|
||||
|
||||
package dnsfw
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmengineopen0
|
||||
//sys fwpmEngineOpen0(serverName *uint16, authnService wtRpcCAuthN, authIdentity *uintptr, session *wtFwpmSession0, engineHandle unsafe.Pointer) (err error) [failretval!=0] = fwpuclnt.FwpmEngineOpen0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmengineclose0
|
||||
//sys fwpmEngineClose0(engineHandle uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmEngineClose0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmsublayeradd0
|
||||
//sys fwpmSubLayerAdd0(engineHandle uintptr, subLayer *wtFwpmSublayer0, sd uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmSubLayerAdd0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmgetappidfromfilename0
|
||||
//sys fwpmGetAppIdFromFileName0(fileName *uint16, appID unsafe.Pointer) (err error) [failretval!=0] = fwpuclnt.FwpmGetAppIdFromFileName0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmfreememory0
|
||||
//sys fwpmFreeMemory0(p unsafe.Pointer) = fwpuclnt.FwpmFreeMemory0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmfilteradd0
|
||||
//sys fwpmFilterAdd0(engineHandle uintptr, filter *wtFwpmFilter0, sd uintptr, id *uint64) (err error) [failretval!=0] = fwpuclnt.FwpmFilterAdd0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/Fwpmu/nf-fwpmu-fwpmtransactionbegin0
|
||||
//sys fwpmTransactionBegin0(engineHandle uintptr, flags uint32) (err error) [failretval!=0] = fwpuclnt.FwpmTransactionBegin0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmtransactioncommit0
|
||||
//sys fwpmTransactionCommit0(engineHandle uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmTransactionCommit0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmtransactionabort0
|
||||
//sys fwpmTransactionAbort0(engineHandle uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmTransactionAbort0
|
||||
|
||||
// https://docs.microsoft.com/en-us/windows/desktop/api/fwpmu/nf-fwpmu-fwpmprovideradd0
|
||||
//sys fwpmProviderAdd0(engineHandle uintptr, provider *wtFwpmProvider0, sd uintptr) (err error) [failretval!=0] = fwpuclnt.FwpmProviderAdd0
|
||||
414
client/internal/dns/dnsfw/types_windows.go
Normal file
414
client/internal/dns/dnsfw/types_windows.go
Normal file
@@ -0,0 +1,414 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*
|
||||
* Adapted from wireguard-windows tunnel/firewall/types_windows.go.
|
||||
*/
|
||||
|
||||
package dnsfw
|
||||
|
||||
import "golang.org/x/sys/windows"
|
||||
|
||||
const (
|
||||
anysizeArray = 1 // ANYSIZE_ARRAY defined in winnt.h
|
||||
|
||||
wtFwpBitmapArray64_Size = 8
|
||||
|
||||
wtFwpByteArray16_Size = 16
|
||||
|
||||
wtFwpByteArray6_Size = 6
|
||||
|
||||
wtFwpmAction0_Size = 20
|
||||
wtFwpmAction0_filterType_Offset = 4
|
||||
|
||||
wtFwpV4AddrAndMask_Size = 8
|
||||
wtFwpV4AddrAndMask_mask_Offset = 4
|
||||
|
||||
wtFwpV6AddrAndMask_Size = 17
|
||||
wtFwpV6AddrAndMask_prefixLength_Offset = 16
|
||||
)
|
||||
|
||||
type wtFwpActionFlag uint32
|
||||
|
||||
const (
|
||||
cFWP_ACTION_FLAG_TERMINATING wtFwpActionFlag = 0x00001000
|
||||
cFWP_ACTION_FLAG_NON_TERMINATING wtFwpActionFlag = 0x00002000
|
||||
cFWP_ACTION_FLAG_CALLOUT wtFwpActionFlag = 0x00004000
|
||||
)
|
||||
|
||||
// FWP_ACTION_TYPE defined in fwptypes.h
|
||||
type wtFwpActionType uint32
|
||||
|
||||
const (
|
||||
cFWP_ACTION_BLOCK wtFwpActionType = wtFwpActionType(0x00000001 | cFWP_ACTION_FLAG_TERMINATING)
|
||||
cFWP_ACTION_PERMIT wtFwpActionType = wtFwpActionType(0x00000002 | cFWP_ACTION_FLAG_TERMINATING)
|
||||
cFWP_ACTION_CALLOUT_TERMINATING wtFwpActionType = wtFwpActionType(0x00000003 | cFWP_ACTION_FLAG_CALLOUT | cFWP_ACTION_FLAG_TERMINATING)
|
||||
cFWP_ACTION_CALLOUT_INSPECTION wtFwpActionType = wtFwpActionType(0x00000004 | cFWP_ACTION_FLAG_CALLOUT | cFWP_ACTION_FLAG_NON_TERMINATING)
|
||||
cFWP_ACTION_CALLOUT_UNKNOWN wtFwpActionType = wtFwpActionType(0x00000005 | cFWP_ACTION_FLAG_CALLOUT)
|
||||
cFWP_ACTION_CONTINUE wtFwpActionType = wtFwpActionType(0x00000006 | cFWP_ACTION_FLAG_NON_TERMINATING)
|
||||
cFWP_ACTION_NONE wtFwpActionType = 0x00000007
|
||||
cFWP_ACTION_NONE_NO_MATCH wtFwpActionType = 0x00000008
|
||||
cFWP_ACTION_BITMAP_INDEX_SET wtFwpActionType = 0x00000009
|
||||
)
|
||||
|
||||
// FWP_BYTE_BLOB defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_byte_blob_)
|
||||
type wtFwpByteBlob struct {
|
||||
size uint32
|
||||
data *uint8
|
||||
}
|
||||
|
||||
// FWP_MATCH_TYPE defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ne-fwptypes-fwp_match_type_)
|
||||
type wtFwpMatchType uint32
|
||||
|
||||
const (
|
||||
cFWP_MATCH_EQUAL wtFwpMatchType = 0
|
||||
cFWP_MATCH_GREATER wtFwpMatchType = cFWP_MATCH_EQUAL + 1
|
||||
cFWP_MATCH_LESS wtFwpMatchType = cFWP_MATCH_GREATER + 1
|
||||
cFWP_MATCH_GREATER_OR_EQUAL wtFwpMatchType = cFWP_MATCH_LESS + 1
|
||||
cFWP_MATCH_LESS_OR_EQUAL wtFwpMatchType = cFWP_MATCH_GREATER_OR_EQUAL + 1
|
||||
cFWP_MATCH_RANGE wtFwpMatchType = cFWP_MATCH_LESS_OR_EQUAL + 1
|
||||
cFWP_MATCH_FLAGS_ALL_SET wtFwpMatchType = cFWP_MATCH_RANGE + 1
|
||||
cFWP_MATCH_FLAGS_ANY_SET wtFwpMatchType = cFWP_MATCH_FLAGS_ALL_SET + 1
|
||||
cFWP_MATCH_FLAGS_NONE_SET wtFwpMatchType = cFWP_MATCH_FLAGS_ANY_SET + 1
|
||||
cFWP_MATCH_EQUAL_CASE_INSENSITIVE wtFwpMatchType = cFWP_MATCH_FLAGS_NONE_SET + 1
|
||||
cFWP_MATCH_NOT_EQUAL wtFwpMatchType = cFWP_MATCH_EQUAL_CASE_INSENSITIVE + 1
|
||||
cFWP_MATCH_PREFIX wtFwpMatchType = cFWP_MATCH_NOT_EQUAL + 1
|
||||
cFWP_MATCH_NOT_PREFIX wtFwpMatchType = cFWP_MATCH_PREFIX + 1
|
||||
cFWP_MATCH_TYPE_MAX wtFwpMatchType = cFWP_MATCH_NOT_PREFIX + 1
|
||||
)
|
||||
|
||||
// FWPM_ACTION0 defined in fwpmtypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_action0_)
|
||||
type wtFwpmAction0 struct {
|
||||
_type wtFwpActionType
|
||||
filterType windows.GUID // Windows type: GUID
|
||||
}
|
||||
|
||||
// Defined in fwpmu.h. 4cd62a49-59c3-4969-b7f3-bda5d32890a4
|
||||
var cFWPM_CONDITION_IP_LOCAL_INTERFACE = windows.GUID{
|
||||
Data1: 0x4cd62a49,
|
||||
Data2: 0x59c3,
|
||||
Data3: 0x4969,
|
||||
Data4: [8]byte{0xb7, 0xf3, 0xbd, 0xa5, 0xd3, 0x28, 0x90, 0xa4},
|
||||
}
|
||||
|
||||
// Defined in fwpmu.h. b235ae9a-1d64-49b8-a44c-5ff3d9095045
|
||||
var cFWPM_CONDITION_IP_REMOTE_ADDRESS = windows.GUID{
|
||||
Data1: 0xb235ae9a,
|
||||
Data2: 0x1d64,
|
||||
Data3: 0x49b8,
|
||||
Data4: [8]byte{0xa4, 0x4c, 0x5f, 0xf3, 0xd9, 0x09, 0x50, 0x45},
|
||||
}
|
||||
|
||||
// Defined in fwpmu.h. 3971ef2b-623e-4f9a-8cb1-6e79b806b9a7
|
||||
var cFWPM_CONDITION_IP_PROTOCOL = windows.GUID{
|
||||
Data1: 0x3971ef2b,
|
||||
Data2: 0x623e,
|
||||
Data3: 0x4f9a,
|
||||
Data4: [8]byte{0x8c, 0xb1, 0x6e, 0x79, 0xb8, 0x06, 0xb9, 0xa7},
|
||||
}
|
||||
|
||||
// Defined in fwpmu.h. 0c1ba1af-5765-453f-af22-a8f791ac775b
|
||||
var cFWPM_CONDITION_IP_LOCAL_PORT = windows.GUID{
|
||||
Data1: 0x0c1ba1af,
|
||||
Data2: 0x5765,
|
||||
Data3: 0x453f,
|
||||
Data4: [8]byte{0xaf, 0x22, 0xa8, 0xf7, 0x91, 0xac, 0x77, 0x5b},
|
||||
}
|
||||
|
||||
// Defined in fwpmu.h. c35a604d-d22b-4e1a-91b4-68f674ee674b
|
||||
var cFWPM_CONDITION_IP_REMOTE_PORT = windows.GUID{
|
||||
Data1: 0xc35a604d,
|
||||
Data2: 0xd22b,
|
||||
Data3: 0x4e1a,
|
||||
Data4: [8]byte{0x91, 0xb4, 0x68, 0xf6, 0x74, 0xee, 0x67, 0x4b},
|
||||
}
|
||||
|
||||
// Defined in fwpmu.h. d78e1e87-8644-4ea5-9437-d809ecefc971
|
||||
var cFWPM_CONDITION_ALE_APP_ID = windows.GUID{
|
||||
Data1: 0xd78e1e87,
|
||||
Data2: 0x8644,
|
||||
Data3: 0x4ea5,
|
||||
Data4: [8]byte{0x94, 0x37, 0xd8, 0x09, 0xec, 0xef, 0xc9, 0x71},
|
||||
}
|
||||
|
||||
// af043a0a-b34d-4f86-979c-c90371af6e66
|
||||
var cFWPM_CONDITION_ALE_USER_ID = windows.GUID{
|
||||
Data1: 0xaf043a0a,
|
||||
Data2: 0xb34d,
|
||||
Data3: 0x4f86,
|
||||
Data4: [8]byte{0x97, 0x9c, 0xc9, 0x03, 0x71, 0xaf, 0x6e, 0x66},
|
||||
}
|
||||
|
||||
// d9ee00de-c1ef-4617-bfe3-ffd8f5a08957
|
||||
var cFWPM_CONDITION_IP_LOCAL_ADDRESS = windows.GUID{
|
||||
Data1: 0xd9ee00de,
|
||||
Data2: 0xc1ef,
|
||||
Data3: 0x4617,
|
||||
Data4: [8]byte{0xbf, 0xe3, 0xff, 0xd8, 0xf5, 0xa0, 0x89, 0x57},
|
||||
}
|
||||
|
||||
var (
|
||||
cFWPM_CONDITION_ICMP_TYPE = cFWPM_CONDITION_IP_LOCAL_PORT
|
||||
cFWPM_CONDITION_ICMP_CODE = cFWPM_CONDITION_IP_REMOTE_PORT
|
||||
)
|
||||
|
||||
// 7bc43cbf-37ba-45f1-b74a-82ff518eeb10
|
||||
var cFWPM_CONDITION_L2_FLAGS = windows.GUID{
|
||||
Data1: 0x7bc43cbf,
|
||||
Data2: 0x37ba,
|
||||
Data3: 0x45f1,
|
||||
Data4: [8]byte{0xb7, 0x4a, 0x82, 0xff, 0x51, 0x8e, 0xeb, 0x10},
|
||||
}
|
||||
|
||||
type wtFwpmL2Flags uint32
|
||||
|
||||
const cFWP_CONDITION_L2_IS_VM2VM wtFwpmL2Flags = 0x00000010
|
||||
|
||||
var cFWPM_CONDITION_FLAGS = windows.GUID{
|
||||
Data1: 0x632ce23b,
|
||||
Data2: 0x5167,
|
||||
Data3: 0x435c,
|
||||
Data4: [8]byte{0x86, 0xd7, 0xe9, 0x03, 0x68, 0x4a, 0xa8, 0x0c},
|
||||
}
|
||||
|
||||
type wtFwpmFlags uint32
|
||||
|
||||
const cFWP_CONDITION_FLAG_IS_LOOPBACK wtFwpmFlags = 0x00000001
|
||||
|
||||
// Defined in fwpmtypes.h
|
||||
type wtFwpmFilterFlags uint32
|
||||
|
||||
const (
|
||||
cFWPM_FILTER_FLAG_NONE wtFwpmFilterFlags = 0x00000000
|
||||
cFWPM_FILTER_FLAG_PERSISTENT wtFwpmFilterFlags = 0x00000001
|
||||
cFWPM_FILTER_FLAG_BOOTTIME wtFwpmFilterFlags = 0x00000002
|
||||
cFWPM_FILTER_FLAG_HAS_PROVIDER_CONTEXT wtFwpmFilterFlags = 0x00000004
|
||||
cFWPM_FILTER_FLAG_CLEAR_ACTION_RIGHT wtFwpmFilterFlags = 0x00000008
|
||||
cFWPM_FILTER_FLAG_PERMIT_IF_CALLOUT_UNREGISTERED wtFwpmFilterFlags = 0x00000010
|
||||
cFWPM_FILTER_FLAG_DISABLED wtFwpmFilterFlags = 0x00000020
|
||||
cFWPM_FILTER_FLAG_INDEXED wtFwpmFilterFlags = 0x00000040
|
||||
cFWPM_FILTER_FLAG_HAS_SECURITY_REALM_PROVIDER_CONTEXT wtFwpmFilterFlags = 0x00000080
|
||||
cFWPM_FILTER_FLAG_SYSTEMOS_ONLY wtFwpmFilterFlags = 0x00000100
|
||||
cFWPM_FILTER_FLAG_GAMEOS_ONLY wtFwpmFilterFlags = 0x00000200
|
||||
cFWPM_FILTER_FLAG_SILENT_MODE wtFwpmFilterFlags = 0x00000400
|
||||
cFWPM_FILTER_FLAG_IPSEC_NO_ACQUIRE_INITIATE wtFwpmFilterFlags = 0x00000800
|
||||
)
|
||||
|
||||
// FWPM_LAYER_ALE_AUTH_CONNECT_V4 (c38d57d1-05a7-4c33-904f-7fbceee60e82) defined in fwpmu.h
|
||||
var cFWPM_LAYER_ALE_AUTH_CONNECT_V4 = windows.GUID{
|
||||
Data1: 0xc38d57d1,
|
||||
Data2: 0x05a7,
|
||||
Data3: 0x4c33,
|
||||
Data4: [8]byte{0x90, 0x4f, 0x7f, 0xbc, 0xee, 0xe6, 0x0e, 0x82},
|
||||
}
|
||||
|
||||
// e1cd9fe7-f4b5-4273-96c0-592e487b8650
|
||||
var cFWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V4 = windows.GUID{
|
||||
Data1: 0xe1cd9fe7,
|
||||
Data2: 0xf4b5,
|
||||
Data3: 0x4273,
|
||||
Data4: [8]byte{0x96, 0xc0, 0x59, 0x2e, 0x48, 0x7b, 0x86, 0x50},
|
||||
}
|
||||
|
||||
// FWPM_LAYER_ALE_AUTH_CONNECT_V6 (4a72393b-319f-44bc-84c3-ba54dcb3b6b4) defined in fwpmu.h
|
||||
var cFWPM_LAYER_ALE_AUTH_CONNECT_V6 = windows.GUID{
|
||||
Data1: 0x4a72393b,
|
||||
Data2: 0x319f,
|
||||
Data3: 0x44bc,
|
||||
Data4: [8]byte{0x84, 0xc3, 0xba, 0x54, 0xdc, 0xb3, 0xb6, 0xb4},
|
||||
}
|
||||
|
||||
// a3b42c97-9f04-4672-b87e-cee9c483257f
|
||||
var cFWPM_LAYER_ALE_AUTH_RECV_ACCEPT_V6 = windows.GUID{
|
||||
Data1: 0xa3b42c97,
|
||||
Data2: 0x9f04,
|
||||
Data3: 0x4672,
|
||||
Data4: [8]byte{0xb8, 0x7e, 0xce, 0xe9, 0xc4, 0x83, 0x25, 0x7f},
|
||||
}
|
||||
|
||||
// 94c44912-9d6f-4ebf-b995-05ab8a088d1b
|
||||
var cFWPM_LAYER_OUTBOUND_MAC_FRAME_NATIVE = windows.GUID{
|
||||
Data1: 0x94c44912,
|
||||
Data2: 0x9d6f,
|
||||
Data3: 0x4ebf,
|
||||
Data4: [8]byte{0xb9, 0x95, 0x05, 0xab, 0x8a, 0x08, 0x8d, 0x1b},
|
||||
}
|
||||
|
||||
// d4220bd3-62ce-4f08-ae88-b56e8526df50
|
||||
var cFWPM_LAYER_INBOUND_MAC_FRAME_NATIVE = windows.GUID{
|
||||
Data1: 0xd4220bd3,
|
||||
Data2: 0x62ce,
|
||||
Data3: 0x4f08,
|
||||
Data4: [8]byte{0xae, 0x88, 0xb5, 0x6e, 0x85, 0x26, 0xdf, 0x50},
|
||||
}
|
||||
|
||||
// FWP_BITMAP_ARRAY64 defined in fwtypes.h
|
||||
type wtFwpBitmapArray64 struct {
|
||||
bitmapArray64 [8]uint8 // Windows type: [8]UINT8
|
||||
}
|
||||
|
||||
// FWP_BYTE_ARRAY6 defined in fwtypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_byte_array6_)
|
||||
type wtFwpByteArray6 struct {
|
||||
byteArray6 [6]uint8 // Windows type: [6]UINT8
|
||||
}
|
||||
|
||||
// FWP_BYTE_ARRAY16 defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_byte_array16_)
|
||||
type wtFwpByteArray16 struct {
|
||||
byteArray16 [16]uint8 // Windows type [16]UINT8
|
||||
}
|
||||
|
||||
// FWP_CONDITION_VALUE0 defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_condition_value0).
|
||||
type wtFwpConditionValue0 wtFwpValue0
|
||||
|
||||
// FWP_DATA_TYPE defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ne-fwptypes-fwp_data_type_)
|
||||
type wtFwpDataType uint
|
||||
|
||||
const (
|
||||
cFWP_EMPTY wtFwpDataType = 0
|
||||
cFWP_UINT8 wtFwpDataType = cFWP_EMPTY + 1
|
||||
cFWP_UINT16 wtFwpDataType = cFWP_UINT8 + 1
|
||||
cFWP_UINT32 wtFwpDataType = cFWP_UINT16 + 1
|
||||
cFWP_UINT64 wtFwpDataType = cFWP_UINT32 + 1
|
||||
cFWP_INT8 wtFwpDataType = cFWP_UINT64 + 1
|
||||
cFWP_INT16 wtFwpDataType = cFWP_INT8 + 1
|
||||
cFWP_INT32 wtFwpDataType = cFWP_INT16 + 1
|
||||
cFWP_INT64 wtFwpDataType = cFWP_INT32 + 1
|
||||
cFWP_FLOAT wtFwpDataType = cFWP_INT64 + 1
|
||||
cFWP_DOUBLE wtFwpDataType = cFWP_FLOAT + 1
|
||||
cFWP_BYTE_ARRAY16_TYPE wtFwpDataType = cFWP_DOUBLE + 1
|
||||
cFWP_BYTE_BLOB_TYPE wtFwpDataType = cFWP_BYTE_ARRAY16_TYPE + 1
|
||||
cFWP_SID wtFwpDataType = cFWP_BYTE_BLOB_TYPE + 1
|
||||
cFWP_SECURITY_DESCRIPTOR_TYPE wtFwpDataType = cFWP_SID + 1
|
||||
cFWP_TOKEN_INFORMATION_TYPE wtFwpDataType = cFWP_SECURITY_DESCRIPTOR_TYPE + 1
|
||||
cFWP_TOKEN_ACCESS_INFORMATION_TYPE wtFwpDataType = cFWP_TOKEN_INFORMATION_TYPE + 1
|
||||
cFWP_UNICODE_STRING_TYPE wtFwpDataType = cFWP_TOKEN_ACCESS_INFORMATION_TYPE + 1
|
||||
cFWP_BYTE_ARRAY6_TYPE wtFwpDataType = cFWP_UNICODE_STRING_TYPE + 1
|
||||
cFWP_BITMAP_INDEX_TYPE wtFwpDataType = cFWP_BYTE_ARRAY6_TYPE + 1
|
||||
cFWP_BITMAP_ARRAY64_TYPE wtFwpDataType = cFWP_BITMAP_INDEX_TYPE + 1
|
||||
cFWP_SINGLE_DATA_TYPE_MAX wtFwpDataType = 0xff
|
||||
cFWP_V4_ADDR_MASK wtFwpDataType = cFWP_SINGLE_DATA_TYPE_MAX + 1
|
||||
cFWP_V6_ADDR_MASK wtFwpDataType = cFWP_V4_ADDR_MASK + 1
|
||||
cFWP_RANGE_TYPE wtFwpDataType = cFWP_V6_ADDR_MASK + 1
|
||||
cFWP_DATA_TYPE_MAX wtFwpDataType = cFWP_RANGE_TYPE + 1
|
||||
)
|
||||
|
||||
// FWP_V4_ADDR_AND_MASK defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_v4_addr_and_mask).
|
||||
type wtFwpV4AddrAndMask struct {
|
||||
addr uint32
|
||||
mask uint32
|
||||
}
|
||||
|
||||
// FWP_V6_ADDR_AND_MASK defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_v6_addr_and_mask).
|
||||
type wtFwpV6AddrAndMask struct {
|
||||
addr [16]uint8
|
||||
prefixLength uint8
|
||||
}
|
||||
|
||||
// FWP_VALUE0 defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwp_value0_)
|
||||
type wtFwpValue0 struct {
|
||||
_type wtFwpDataType
|
||||
value uintptr
|
||||
}
|
||||
|
||||
// FWPM_DISPLAY_DATA0 defined in fwptypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwptypes/ns-fwptypes-fwpm_display_data0).
|
||||
type wtFwpmDisplayData0 struct {
|
||||
name *uint16 // Windows type: *wchar_t
|
||||
description *uint16 // Windows type: *wchar_t
|
||||
}
|
||||
|
||||
// FWPM_FILTER_CONDITION0 defined in fwpmtypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_filter_condition0).
|
||||
type wtFwpmFilterCondition0 struct {
|
||||
fieldKey windows.GUID // Windows type: GUID
|
||||
matchType wtFwpMatchType
|
||||
conditionValue wtFwpConditionValue0
|
||||
}
|
||||
|
||||
// FWPM_PROVIDER0 defined in fwpmtypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_provider0_)
|
||||
type wtFwpProvider0 struct {
|
||||
providerKey windows.GUID // Windows type: GUID
|
||||
displayData wtFwpmDisplayData0
|
||||
flags uint32
|
||||
providerData wtFwpByteBlob
|
||||
serviceName *uint16 // Windows type: *wchar_t
|
||||
}
|
||||
|
||||
type wtFwpmSessionFlagsValue uint32
|
||||
|
||||
const (
|
||||
cFWPM_SESSION_FLAG_DYNAMIC wtFwpmSessionFlagsValue = 0x00000001 // FWPM_SESSION_FLAG_DYNAMIC defined in fwpmtypes.h
|
||||
)
|
||||
|
||||
// FWPM_SESSION0 defined in fwpmtypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_session0).
|
||||
type wtFwpmSession0 struct {
|
||||
sessionKey windows.GUID // Windows type: GUID
|
||||
displayData wtFwpmDisplayData0
|
||||
flags wtFwpmSessionFlagsValue // Windows type UINT32
|
||||
txnWaitTimeoutInMSec uint32
|
||||
processId uint32 // Windows type: DWORD
|
||||
sid *windows.SID
|
||||
username *uint16 // Windows type: *wchar_t
|
||||
kernelMode uint8 // Windows type: BOOL
|
||||
}
|
||||
|
||||
type wtFwpmSublayerFlags uint32
|
||||
|
||||
const (
|
||||
cFWPM_SUBLAYER_FLAG_PERSISTENT wtFwpmSublayerFlags = 0x00000001 // FWPM_SUBLAYER_FLAG_PERSISTENT defined in fwpmtypes.h
|
||||
)
|
||||
|
||||
// FWPM_SUBLAYER0 defined in fwpmtypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_sublayer0_)
|
||||
type wtFwpmSublayer0 struct {
|
||||
subLayerKey windows.GUID // Windows type: GUID
|
||||
displayData wtFwpmDisplayData0
|
||||
flags wtFwpmSublayerFlags
|
||||
providerKey *windows.GUID // Windows type: *GUID
|
||||
providerData wtFwpByteBlob
|
||||
weight uint16
|
||||
}
|
||||
|
||||
// Defined in rpcdce.h
|
||||
type wtRpcCAuthN uint32
|
||||
|
||||
const (
|
||||
cRPC_C_AUTHN_NONE wtRpcCAuthN = 0
|
||||
cRPC_C_AUTHN_WINNT wtRpcCAuthN = 10
|
||||
cRPC_C_AUTHN_DEFAULT wtRpcCAuthN = 0xFFFFFFFF
|
||||
)
|
||||
|
||||
// FWPM_PROVIDER0 defined in fwpmtypes.h
|
||||
// (https://docs.microsoft.com/sv-se/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_provider0).
|
||||
type wtFwpmProvider0 struct {
|
||||
providerKey windows.GUID
|
||||
displayData wtFwpmDisplayData0
|
||||
flags uint32
|
||||
providerData wtFwpByteBlob
|
||||
serviceName *uint16
|
||||
}
|
||||
|
||||
type wtIPProto uint32
|
||||
|
||||
const (
|
||||
cIPPROTO_ICMP wtIPProto = 1
|
||||
cIPPROTO_ICMPV6 wtIPProto = 58
|
||||
cIPPROTO_TCP wtIPProto = 6
|
||||
cIPPROTO_UDP wtIPProto = 17
|
||||
)
|
||||
|
||||
const (
|
||||
cFWP_ACTRL_MATCH_FILTER = 1
|
||||
)
|
||||
92
client/internal/dns/dnsfw/types_windows_32.go
Normal file
92
client/internal/dns/dnsfw/types_windows_32.go
Normal file
@@ -0,0 +1,92 @@
|
||||
//go:build windows && (386 || arm)
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*
|
||||
* Adapted from wireguard-windows tunnel/firewall/types_windows_32.go.
|
||||
*/
|
||||
|
||||
package dnsfw
|
||||
|
||||
import "golang.org/x/sys/windows"
|
||||
|
||||
const (
|
||||
wtFwpByteBlob_Size = 8
|
||||
wtFwpByteBlob_data_Offset = 4
|
||||
|
||||
wtFwpConditionValue0_Size = 8
|
||||
wtFwpConditionValue0_uint8_Offset = 4
|
||||
|
||||
wtFwpmDisplayData0_Size = 8
|
||||
wtFwpmDisplayData0_description_Offset = 4
|
||||
|
||||
wtFwpmFilter0_Size = 152
|
||||
wtFwpmFilter0_displayData_Offset = 16
|
||||
wtFwpmFilter0_flags_Offset = 24
|
||||
wtFwpmFilter0_providerKey_Offset = 28
|
||||
wtFwpmFilter0_providerData_Offset = 32
|
||||
wtFwpmFilter0_layerKey_Offset = 40
|
||||
wtFwpmFilter0_subLayerKey_Offset = 56
|
||||
wtFwpmFilter0_weight_Offset = 72
|
||||
wtFwpmFilter0_numFilterConditions_Offset = 80
|
||||
wtFwpmFilter0_filterCondition_Offset = 84
|
||||
wtFwpmFilter0_action_Offset = 88
|
||||
wtFwpmFilter0_providerContextKey_Offset = 112
|
||||
wtFwpmFilter0_reserved_Offset = 128
|
||||
wtFwpmFilter0_filterID_Offset = 136
|
||||
wtFwpmFilter0_effectiveWeight_Offset = 144
|
||||
|
||||
wtFwpmFilterCondition0_Size = 28
|
||||
wtFwpmFilterCondition0_matchType_Offset = 16
|
||||
wtFwpmFilterCondition0_conditionValue_Offset = 20
|
||||
|
||||
wtFwpmSession0_Size = 48
|
||||
wtFwpmSession0_displayData_Offset = 16
|
||||
wtFwpmSession0_flags_Offset = 24
|
||||
wtFwpmSession0_txnWaitTimeoutInMSec_Offset = 28
|
||||
wtFwpmSession0_processId_Offset = 32
|
||||
wtFwpmSession0_sid_Offset = 36
|
||||
wtFwpmSession0_username_Offset = 40
|
||||
wtFwpmSession0_kernelMode_Offset = 44
|
||||
|
||||
wtFwpmSublayer0_Size = 44
|
||||
wtFwpmSublayer0_displayData_Offset = 16
|
||||
wtFwpmSublayer0_flags_Offset = 24
|
||||
wtFwpmSublayer0_providerKey_Offset = 28
|
||||
wtFwpmSublayer0_providerData_Offset = 32
|
||||
wtFwpmSublayer0_weight_Offset = 40
|
||||
|
||||
wtFwpProvider0_Size = 40
|
||||
wtFwpProvider0_displayData_Offset = 16
|
||||
wtFwpProvider0_flags_Offset = 24
|
||||
wtFwpProvider0_providerData_Offset = 28
|
||||
wtFwpProvider0_serviceName_Offset = 36
|
||||
|
||||
wtFwpTokenInformation_Size = 16
|
||||
|
||||
wtFwpValue0_Size = 8
|
||||
wtFwpValue0_value_Offset = 4
|
||||
)
|
||||
|
||||
// FWPM_FILTER0 defined in fwpmtypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_filter0).
|
||||
type wtFwpmFilter0 struct {
|
||||
filterKey windows.GUID // Windows type: GUID
|
||||
displayData wtFwpmDisplayData0
|
||||
flags wtFwpmFilterFlags
|
||||
providerKey *windows.GUID // Windows type: *GUID
|
||||
providerData wtFwpByteBlob
|
||||
layerKey windows.GUID // Windows type: GUID
|
||||
subLayerKey windows.GUID // Windows type: GUID
|
||||
weight wtFwpValue0
|
||||
numFilterConditions uint32
|
||||
filterCondition *wtFwpmFilterCondition0
|
||||
action wtFwpmAction0
|
||||
offset1 [4]byte // Layout correction field
|
||||
providerContextKey windows.GUID // Windows type: GUID
|
||||
reserved *windows.GUID // Windows type: *GUID
|
||||
offset2 [4]byte // Layout correction field
|
||||
filterID uint64
|
||||
effectiveWeight wtFwpValue0
|
||||
}
|
||||
89
client/internal/dns/dnsfw/types_windows_64.go
Normal file
89
client/internal/dns/dnsfw/types_windows_64.go
Normal file
@@ -0,0 +1,89 @@
|
||||
//go:build windows && (amd64 || arm64)
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*
|
||||
* Adapted from wireguard-windows tunnel/firewall/types_windows_64.go.
|
||||
*/
|
||||
|
||||
package dnsfw
|
||||
|
||||
import "golang.org/x/sys/windows"
|
||||
|
||||
const (
|
||||
wtFwpByteBlob_Size = 16
|
||||
wtFwpByteBlob_data_Offset = 8
|
||||
|
||||
wtFwpConditionValue0_Size = 16
|
||||
wtFwpConditionValue0_uint8_Offset = 8
|
||||
|
||||
wtFwpmDisplayData0_Size = 16
|
||||
wtFwpmDisplayData0_description_Offset = 8
|
||||
|
||||
wtFwpmFilter0_Size = 200
|
||||
wtFwpmFilter0_displayData_Offset = 16
|
||||
wtFwpmFilter0_flags_Offset = 32
|
||||
wtFwpmFilter0_providerKey_Offset = 40
|
||||
wtFwpmFilter0_providerData_Offset = 48
|
||||
wtFwpmFilter0_layerKey_Offset = 64
|
||||
wtFwpmFilter0_subLayerKey_Offset = 80
|
||||
wtFwpmFilter0_weight_Offset = 96
|
||||
wtFwpmFilter0_numFilterConditions_Offset = 112
|
||||
wtFwpmFilter0_filterCondition_Offset = 120
|
||||
wtFwpmFilter0_action_Offset = 128
|
||||
wtFwpmFilter0_providerContextKey_Offset = 152
|
||||
wtFwpmFilter0_reserved_Offset = 168
|
||||
wtFwpmFilter0_filterID_Offset = 176
|
||||
wtFwpmFilter0_effectiveWeight_Offset = 184
|
||||
|
||||
wtFwpmFilterCondition0_Size = 40
|
||||
wtFwpmFilterCondition0_matchType_Offset = 16
|
||||
wtFwpmFilterCondition0_conditionValue_Offset = 24
|
||||
|
||||
wtFwpmSession0_Size = 72
|
||||
wtFwpmSession0_displayData_Offset = 16
|
||||
wtFwpmSession0_flags_Offset = 32
|
||||
wtFwpmSession0_txnWaitTimeoutInMSec_Offset = 36
|
||||
wtFwpmSession0_processId_Offset = 40
|
||||
wtFwpmSession0_sid_Offset = 48
|
||||
wtFwpmSession0_username_Offset = 56
|
||||
wtFwpmSession0_kernelMode_Offset = 64
|
||||
|
||||
wtFwpmSublayer0_Size = 72
|
||||
wtFwpmSublayer0_displayData_Offset = 16
|
||||
wtFwpmSublayer0_flags_Offset = 32
|
||||
wtFwpmSublayer0_providerKey_Offset = 40
|
||||
wtFwpmSublayer0_providerData_Offset = 48
|
||||
wtFwpmSublayer0_weight_Offset = 64
|
||||
|
||||
wtFwpProvider0_Size = 64
|
||||
wtFwpProvider0_displayData_Offset = 16
|
||||
wtFwpProvider0_flags_Offset = 32
|
||||
wtFwpProvider0_providerData_Offset = 40
|
||||
wtFwpProvider0_serviceName_Offset = 56
|
||||
|
||||
wtFwpValue0_Size = 16
|
||||
wtFwpValue0_value_Offset = 8
|
||||
)
|
||||
|
||||
// FWPM_FILTER0 defined in fwpmtypes.h
|
||||
// (https://docs.microsoft.com/en-us/windows/desktop/api/fwpmtypes/ns-fwpmtypes-fwpm_filter0).
|
||||
type wtFwpmFilter0 struct {
|
||||
filterKey windows.GUID // Windows type: GUID
|
||||
displayData wtFwpmDisplayData0
|
||||
flags wtFwpmFilterFlags // Windows type: UINT32
|
||||
providerKey *windows.GUID // Windows type: *GUID
|
||||
providerData wtFwpByteBlob
|
||||
layerKey windows.GUID // Windows type: GUID
|
||||
subLayerKey windows.GUID // Windows type: GUID
|
||||
weight wtFwpValue0
|
||||
numFilterConditions uint32
|
||||
filterCondition *wtFwpmFilterCondition0
|
||||
action wtFwpmAction0
|
||||
offset1 [4]byte // Layout correction field
|
||||
providerContextKey windows.GUID // Windows type: GUID
|
||||
reserved *windows.GUID // Windows type: *GUID
|
||||
filterID uint64
|
||||
effectiveWeight wtFwpValue0
|
||||
}
|
||||
130
client/internal/dns/dnsfw/zsyscall_windows.go
Normal file
130
client/internal/dns/dnsfw/zsyscall_windows.go
Normal file
@@ -0,0 +1,130 @@
|
||||
// Code generated by 'go generate'; DO NOT EDIT.
|
||||
|
||||
package dnsfw
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var _ unsafe.Pointer
|
||||
|
||||
// Do the interface allocations only once for common
|
||||
// Errno values.
|
||||
const (
|
||||
errnoERROR_IO_PENDING = 997
|
||||
)
|
||||
|
||||
var (
|
||||
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
|
||||
errERROR_EINVAL error = syscall.EINVAL
|
||||
)
|
||||
|
||||
// errnoErr returns common boxed Errno values, to prevent
|
||||
// allocations at runtime.
|
||||
func errnoErr(e syscall.Errno) error {
|
||||
switch e {
|
||||
case 0:
|
||||
return errERROR_EINVAL
|
||||
case errnoERROR_IO_PENDING:
|
||||
return errERROR_IO_PENDING
|
||||
}
|
||||
// TODO: add more here, after collecting data on the common
|
||||
// error values see on Windows. (perhaps when running
|
||||
// all.bat?)
|
||||
return e
|
||||
}
|
||||
|
||||
var (
|
||||
modfwpuclnt = windows.NewLazySystemDLL("fwpuclnt.dll")
|
||||
|
||||
procFwpmEngineClose0 = modfwpuclnt.NewProc("FwpmEngineClose0")
|
||||
procFwpmEngineOpen0 = modfwpuclnt.NewProc("FwpmEngineOpen0")
|
||||
procFwpmFilterAdd0 = modfwpuclnt.NewProc("FwpmFilterAdd0")
|
||||
procFwpmFreeMemory0 = modfwpuclnt.NewProc("FwpmFreeMemory0")
|
||||
procFwpmGetAppIdFromFileName0 = modfwpuclnt.NewProc("FwpmGetAppIdFromFileName0")
|
||||
procFwpmProviderAdd0 = modfwpuclnt.NewProc("FwpmProviderAdd0")
|
||||
procFwpmSubLayerAdd0 = modfwpuclnt.NewProc("FwpmSubLayerAdd0")
|
||||
procFwpmTransactionAbort0 = modfwpuclnt.NewProc("FwpmTransactionAbort0")
|
||||
procFwpmTransactionBegin0 = modfwpuclnt.NewProc("FwpmTransactionBegin0")
|
||||
procFwpmTransactionCommit0 = modfwpuclnt.NewProc("FwpmTransactionCommit0")
|
||||
)
|
||||
|
||||
func fwpmEngineClose0(engineHandle uintptr) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procFwpmEngineClose0.Addr(), 1, uintptr(engineHandle), 0, 0)
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmEngineOpen0(serverName *uint16, authnService wtRpcCAuthN, authIdentity *uintptr, session *wtFwpmSession0, engineHandle unsafe.Pointer) (err error) {
|
||||
r1, _, e1 := syscall.Syscall6(procFwpmEngineOpen0.Addr(), 5, uintptr(unsafe.Pointer(serverName)), uintptr(authnService), uintptr(unsafe.Pointer(authIdentity)), uintptr(unsafe.Pointer(session)), uintptr(engineHandle), 0)
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmFilterAdd0(engineHandle uintptr, filter *wtFwpmFilter0, sd uintptr, id *uint64) (err error) {
|
||||
r1, _, e1 := syscall.Syscall6(procFwpmFilterAdd0.Addr(), 4, uintptr(engineHandle), uintptr(unsafe.Pointer(filter)), uintptr(sd), uintptr(unsafe.Pointer(id)), 0, 0)
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmFreeMemory0(p unsafe.Pointer) {
|
||||
syscall.Syscall(procFwpmFreeMemory0.Addr(), 1, uintptr(p), 0, 0)
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmGetAppIdFromFileName0(fileName *uint16, appID unsafe.Pointer) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procFwpmGetAppIdFromFileName0.Addr(), 2, uintptr(unsafe.Pointer(fileName)), uintptr(appID), 0)
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmProviderAdd0(engineHandle uintptr, provider *wtFwpmProvider0, sd uintptr) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procFwpmProviderAdd0.Addr(), 3, uintptr(engineHandle), uintptr(unsafe.Pointer(provider)), uintptr(sd))
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmSubLayerAdd0(engineHandle uintptr, subLayer *wtFwpmSublayer0, sd uintptr) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procFwpmSubLayerAdd0.Addr(), 3, uintptr(engineHandle), uintptr(unsafe.Pointer(subLayer)), uintptr(sd))
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmTransactionAbort0(engineHandle uintptr) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procFwpmTransactionAbort0.Addr(), 1, uintptr(engineHandle), 0, 0)
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmTransactionBegin0(engineHandle uintptr, flags uint32) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procFwpmTransactionBegin0.Addr(), 2, uintptr(engineHandle), uintptr(flags), 0)
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fwpmTransactionCommit0(engineHandle uintptr) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procFwpmTransactionCommit0.Addr(), 1, uintptr(engineHandle), 0, 0)
|
||||
if r1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"golang.org/x/sys/windows/registry"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/dnsfw"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/winregistry"
|
||||
)
|
||||
@@ -74,6 +75,7 @@ type registryConfigurator struct {
|
||||
routingAll bool
|
||||
gpo bool
|
||||
nrptEntryCount int
|
||||
dnsFirewall dnsfw.Manager
|
||||
origNameservers []netip.Addr
|
||||
}
|
||||
|
||||
@@ -94,8 +96,9 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
|
||||
}
|
||||
|
||||
configurator := ®istryConfigurator{
|
||||
guid: guid,
|
||||
gpo: useGPO,
|
||||
guid: guid,
|
||||
gpo: useGPO,
|
||||
dnsFirewall: dnsfw.New(),
|
||||
}
|
||||
|
||||
origNameservers, err := configurator.captureOriginalNameservers()
|
||||
@@ -276,16 +279,8 @@ func (r *registryConfigurator) disableWINSForInterface() error {
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
||||
if config.RouteAll {
|
||||
if err := r.addDNSSetupForAll(config.ServerIP); err != nil {
|
||||
return fmt.Errorf("add dns setup: %w", err)
|
||||
}
|
||||
} else if r.routingAll {
|
||||
if err := r.deleteInterfaceRegistryKeyProperty(interfaceConfigNameServerKey); err != nil {
|
||||
return fmt.Errorf("delete interface registry key property: %w", err)
|
||||
}
|
||||
r.routingAll = false
|
||||
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
|
||||
if err := r.applyRouteAll(config); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.updateState(stateManager)
|
||||
@@ -327,6 +322,35 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) applyRouteAll(config HostDNSConfig) error {
|
||||
if config.RouteAll {
|
||||
if err := r.dnsFirewall.Enable(r.guid, config.ServerIP); err != nil {
|
||||
return fmt.Errorf("dns firewall: %w", err)
|
||||
}
|
||||
if err := r.addDNSSetupForAll(config.ServerIP); err != nil {
|
||||
merr := multierror.Append(nil, fmt.Errorf("add dns setup: %w", err))
|
||||
if dErr := r.dnsFirewall.Disable(); dErr != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("rollback dns firewall: %w", dErr))
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.dnsFirewall.Disable(); err != nil {
|
||||
log.Errorf("disable dns firewall: %v", err)
|
||||
}
|
||||
if !r.routingAll {
|
||||
return nil
|
||||
}
|
||||
if err := r.deleteInterfaceRegistryKeyProperty(interfaceConfigNameServerKey); err != nil {
|
||||
return fmt.Errorf("delete interface registry key property: %w", err)
|
||||
}
|
||||
r.routingAll = false
|
||||
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) updateState(stateManager *statemanager.Manager) {
|
||||
if err := stateManager.UpdateState(&ShutdownState{
|
||||
Guid: r.guid,
|
||||
@@ -513,6 +537,10 @@ func (r *registryConfigurator) restoreHostDNS() error {
|
||||
return fmt.Errorf("remove interface registry key: %w", err)
|
||||
}
|
||||
|
||||
if err := r.dnsFirewall.Disable(); err != nil {
|
||||
log.Errorf("disable dns firewall: %v", err)
|
||||
}
|
||||
|
||||
go r.flushDNSCache()
|
||||
|
||||
return nil
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns/dnsfw"
|
||||
)
|
||||
|
||||
// TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up
|
||||
@@ -34,8 +36,9 @@ func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) {
|
||||
}()
|
||||
|
||||
cfg := ®istryConfigurator{
|
||||
guid: testGUID,
|
||||
gpo: false,
|
||||
guid: testGUID,
|
||||
gpo: false,
|
||||
dnsFirewall: dnsfw.New(),
|
||||
}
|
||||
|
||||
// Create 125 domains which will result in 3 NRPT rules (50+50+25)
|
||||
@@ -134,8 +137,9 @@ func TestNRPTDomainBatching(t *testing.T) {
|
||||
}()
|
||||
|
||||
cfg := ®istryConfigurator{
|
||||
guid: testGUID,
|
||||
gpo: false,
|
||||
guid: testGUID,
|
||||
gpo: false,
|
||||
dnsFirewall: dnsfw.New(),
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
|
||||
@@ -35,7 +35,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/acl"
|
||||
"github.com/netbirdio/netbird/client/internal/approval"
|
||||
"github.com/netbirdio/netbird/client/internal/debug"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||
@@ -124,8 +123,6 @@ type EngineConfig struct {
|
||||
RosenpassPermissive bool
|
||||
|
||||
ServerSSHAllowed bool
|
||||
ServerVNCAllowed bool
|
||||
DisableVNCApproval *bool
|
||||
EnableSSHRoot *bool
|
||||
EnableSSHSFTP *bool
|
||||
EnableSSHLocalPortForwarding *bool
|
||||
@@ -207,9 +204,7 @@ type Engine struct {
|
||||
|
||||
networkMonitor *networkmonitor.NetworkMonitor
|
||||
|
||||
sshServer sshServer
|
||||
vncSrv vncServer
|
||||
approvalBroker *approval.Broker
|
||||
sshServer sshServer
|
||||
|
||||
statusRecorder *peer.Status
|
||||
|
||||
@@ -290,7 +285,6 @@ func NewEngine(
|
||||
TURNs: []*stun.URI{},
|
||||
networkSerial: 0,
|
||||
statusRecorder: services.StatusRecorder,
|
||||
approvalBroker: approval.New(services.StatusRecorder),
|
||||
stateManager: services.StateManager,
|
||||
portForwardManager: portforward.NewManager(),
|
||||
checks: services.Checks,
|
||||
@@ -326,10 +320,6 @@ func (e *Engine) Stop() error {
|
||||
log.Warnf("failed to stop SSH server: %v", err)
|
||||
}
|
||||
|
||||
if err := e.stopVNCServer(); err != nil {
|
||||
log.Warnf("failed to stop VNC server: %v", err)
|
||||
}
|
||||
|
||||
e.cleanupSSHConfig()
|
||||
|
||||
if e.ingressGatewayMgr != nil {
|
||||
@@ -1020,7 +1010,6 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
||||
e.config.RosenpassEnabled,
|
||||
e.config.RosenpassPermissive,
|
||||
&e.config.ServerSSHAllowed,
|
||||
&e.config.ServerVNCAllowed,
|
||||
e.config.DisableClientRoutes,
|
||||
e.config.DisableServerRoutes,
|
||||
e.config.DisableDNS,
|
||||
@@ -1068,10 +1057,6 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
}
|
||||
}
|
||||
|
||||
if err := e.updateVNC(); err != nil {
|
||||
log.Warnf("failed handling VNC server setup: %v", err)
|
||||
}
|
||||
|
||||
state := e.statusRecorder.GetLocalPeerState()
|
||||
state.IP = e.wgInterface.Address().String()
|
||||
state.IPv6 = e.wgInterface.Address().IPv6String()
|
||||
@@ -1197,7 +1182,6 @@ func (e *Engine) receiveManagementEvents() {
|
||||
e.config.RosenpassEnabled,
|
||||
e.config.RosenpassPermissive,
|
||||
&e.config.ServerSSHAllowed,
|
||||
&e.config.ServerVNCAllowed,
|
||||
e.config.DisableClientRoutes,
|
||||
e.config.DisableServerRoutes,
|
||||
e.config.DisableDNS,
|
||||
@@ -1387,11 +1371,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
e.updateSSHServerAuth(networkMap.GetSshAuth())
|
||||
}
|
||||
|
||||
// VNC auth: always sync, including nil so cleared auth on the management
|
||||
// side is applied locally, and so it isn't skipped on the RemotePeersIsEmpty
|
||||
// cleanup path.
|
||||
e.updateVNCServerAuth(networkMap.GetVncAuth())
|
||||
|
||||
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
|
||||
excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, remotePeers)
|
||||
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
|
||||
@@ -1847,7 +1826,6 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
|
||||
e.config.RosenpassEnabled,
|
||||
e.config.RosenpassPermissive,
|
||||
&e.config.ServerSSHAllowed,
|
||||
&e.config.ServerVNCAllowed,
|
||||
e.config.DisableClientRoutes,
|
||||
e.config.DisableServerRoutes,
|
||||
e.config.DisableDNS,
|
||||
@@ -2612,16 +2590,3 @@ func decodeRelayIP(b []byte) netip.Addr {
|
||||
}
|
||||
return ip.Unmap()
|
||||
}
|
||||
|
||||
// RespondApproval relays the user's decision for a pending approval to
|
||||
// the broker. viewOnly is honoured only when accept is true. Returns
|
||||
// true when the request_id matched a live prompt.
|
||||
func (e *Engine) RespondApproval(requestID string, accept, viewOnly bool) bool {
|
||||
if e == nil || e.approvalBroker == nil {
|
||||
return false
|
||||
}
|
||||
return e.approvalBroker.Respond(requestID, approval.Decision{
|
||||
Accept: accept,
|
||||
ViewOnly: accept && viewOnly,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
sshauth "github.com/netbirdio/netbird/shared/sessionauth"
|
||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
||||
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
@@ -237,18 +237,22 @@ func (e *Engine) startSSHServer(jwtConfig *sshserver.JWTConfig) error {
|
||||
return errors.New("wg interface not initialized")
|
||||
}
|
||||
|
||||
wgAddr := e.wgInterface.Address()
|
||||
serverConfig := &sshserver.Config{
|
||||
HostKeyPEM: e.config.SSHKey,
|
||||
JWT: jwtConfig,
|
||||
NetstackNet: e.wgInterface.GetNet(),
|
||||
NetworkValidation: wgAddr,
|
||||
HostKeyPEM: e.config.SSHKey,
|
||||
JWT: jwtConfig,
|
||||
}
|
||||
server := sshserver.New(serverConfig)
|
||||
|
||||
wgAddr := e.wgInterface.Address()
|
||||
server.SetNetworkValidation(wgAddr)
|
||||
|
||||
netbirdIP := wgAddr.IP
|
||||
listenAddr := netip.AddrPortFrom(netbirdIP, sshserver.InternalSSHPort)
|
||||
|
||||
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||
server.SetNetstackNet(netstackNet)
|
||||
}
|
||||
|
||||
e.configureSSHServer(server)
|
||||
|
||||
if err := server.Start(e.ctx, listenAddr); err != nil {
|
||||
|
||||
@@ -1,303 +0,0 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/approval"
|
||||
"github.com/netbirdio/netbird/client/internal/metrics"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/vnc"
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
sshauth "github.com/netbirdio/netbird/shared/sessionauth"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
|
||||
)
|
||||
|
||||
|
||||
type vncServer interface {
|
||||
Start(ctx context.Context, addr netip.AddrPort, network netip.Prefix) error
|
||||
Stop() error
|
||||
ActiveSessions() []vncserver.ActiveSessionInfo
|
||||
}
|
||||
|
||||
func (e *Engine) setupVNCPortRedirection() error {
|
||||
if e.firewall == nil || e.wgInterface == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
localAddr := e.wgInterface.Address().IP
|
||||
if !localAddr.IsValid() {
|
||||
return errors.New("invalid local NetBird address")
|
||||
}
|
||||
|
||||
if err := e.firewall.AddInboundDNAT(localAddr, firewallManager.ProtocolTCP, vnc.ExternalPort, vnc.InternalPort); err != nil {
|
||||
return fmt.Errorf("add VNC port redirection: %w", err)
|
||||
}
|
||||
log.Infof("VNC port redirection: %s:%d -> %s:%d", localAddr, vnc.ExternalPort, localAddr, vnc.InternalPort)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) cleanupVNCPortRedirection() error {
|
||||
if e.firewall == nil || e.wgInterface == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
localAddr := e.wgInterface.Address().IP
|
||||
if !localAddr.IsValid() {
|
||||
return errors.New("invalid local NetBird address")
|
||||
}
|
||||
|
||||
if err := e.firewall.RemoveInboundDNAT(localAddr, firewallManager.ProtocolTCP, vnc.ExternalPort, vnc.InternalPort); err != nil {
|
||||
return fmt.Errorf("remove VNC port redirection: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateVNC handles starting/stopping the VNC server based on the config flag.
|
||||
func (e *Engine) updateVNC() error {
|
||||
if !e.config.ServerVNCAllowed {
|
||||
if e.vncSrv != nil {
|
||||
log.Info("VNC server disabled, stopping")
|
||||
}
|
||||
return e.stopVNCServer()
|
||||
}
|
||||
|
||||
if e.config.BlockInbound {
|
||||
log.Info("VNC server disabled because inbound connections are blocked")
|
||||
return e.stopVNCServer()
|
||||
}
|
||||
|
||||
if e.vncSrv != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return e.startVNCServer()
|
||||
}
|
||||
|
||||
func (e *Engine) startVNCServer() error {
|
||||
if e.wgInterface == nil {
|
||||
return errors.New("wg interface not initialized")
|
||||
}
|
||||
|
||||
capturer, injector, ok := newPlatformVNC()
|
||||
if !ok {
|
||||
log.Debug("VNC server not supported on this platform")
|
||||
return nil
|
||||
}
|
||||
|
||||
netbirdIP := e.wgInterface.Address().IP
|
||||
|
||||
var sessionRecorder func(vncserver.SessionTick)
|
||||
if e.clientMetrics != nil {
|
||||
sessionRecorder = func(t vncserver.SessionTick) {
|
||||
e.clientMetrics.RecordVNCSessionTick(e.ctx, metrics.VNCSessionTick{
|
||||
Period: t.Period,
|
||||
BytesOut: t.BytesOut,
|
||||
Writes: t.Writes,
|
||||
FBUs: t.FBUs,
|
||||
MaxFBUBytes: t.MaxFBUBytes,
|
||||
MaxFBURects: t.MaxFBURects,
|
||||
MaxWriteBytes: t.MaxWriteBytes,
|
||||
WriteNanos: t.WriteNanos,
|
||||
})
|
||||
}
|
||||
}
|
||||
serviceMode := vncNeedsServiceMode()
|
||||
if serviceMode {
|
||||
log.Info("VNC: running as system service, enabling service mode (per-session agent proxy)")
|
||||
}
|
||||
requireApproval := e.config.DisableVNCApproval == nil || !*e.config.DisableVNCApproval
|
||||
srv := vncserver.New(vncserver.Config{
|
||||
Capturer: capturer,
|
||||
Injector: injector,
|
||||
IdentityKey: e.config.WgPrivateKey[:],
|
||||
ServiceMode: serviceMode,
|
||||
SessionRecorder: sessionRecorder,
|
||||
NetstackNet: e.wgInterface.GetNet(),
|
||||
RequireApproval: requireApproval,
|
||||
Approver: &vncApprover{broker: e.approvalBroker, statusRecorder: e.statusRecorder},
|
||||
})
|
||||
|
||||
listenAddr := netip.AddrPortFrom(netbirdIP, vnc.InternalPort)
|
||||
network := e.wgInterface.Address().Network
|
||||
if err := srv.Start(e.ctx, listenAddr, network); err != nil {
|
||||
return fmt.Errorf("start VNC server: %w", err)
|
||||
}
|
||||
|
||||
e.vncSrv = srv
|
||||
|
||||
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||
if registrar, ok := e.firewall.(interface {
|
||||
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||
}); ok {
|
||||
registrar.RegisterNetstackService(nftypes.TCP, vnc.InternalPort)
|
||||
log.Debugf("registered VNC service with netstack for TCP:%d", vnc.InternalPort)
|
||||
}
|
||||
}
|
||||
|
||||
if err := e.setupVNCPortRedirection(); err != nil {
|
||||
log.Warnf("setup VNC port redirection: %v", err)
|
||||
}
|
||||
|
||||
log.Info("VNC server enabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateVNCServerAuth updates VNC fine-grained access control from management.
|
||||
// A nil vncAuth clears all authorized users and session pubkeys so management
|
||||
// can revoke access by omitting the field on the next sync.
|
||||
func (e *Engine) updateVNCServerAuth(vncAuth *mgmProto.VNCAuth) {
|
||||
if e.vncSrv == nil {
|
||||
return
|
||||
}
|
||||
|
||||
vncSrv, ok := e.vncSrv.(*vncserver.Server)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if vncAuth == nil {
|
||||
vncSrv.UpdateVNCAuth(&sshauth.Config{})
|
||||
return
|
||||
}
|
||||
|
||||
protoUsers := vncAuth.GetAuthorizedUsers()
|
||||
authorizedUsers := make([]sshuserhash.UserIDHash, len(protoUsers))
|
||||
for i, hash := range protoUsers {
|
||||
if len(hash) != 16 {
|
||||
log.Warnf("invalid VNC auth hash length %d, expected 16", len(hash))
|
||||
return
|
||||
}
|
||||
authorizedUsers[i] = sshuserhash.UserIDHash(hash)
|
||||
}
|
||||
|
||||
machineUsers := make(map[string][]uint32)
|
||||
for osUser, indexes := range vncAuth.GetMachineUsers() {
|
||||
machineUsers[osUser] = indexes.GetIndexes()
|
||||
}
|
||||
|
||||
sessionPubKeys := make([]sshauth.SessionPubKey, 0, len(vncAuth.GetSessionPubKeys()))
|
||||
for _, e := range vncAuth.GetSessionPubKeys() {
|
||||
pub := e.GetPubKey()
|
||||
if len(pub) != 32 {
|
||||
log.Warnf("VNC session pubkey wrong length %d", len(pub))
|
||||
continue
|
||||
}
|
||||
hash := e.GetUserIdHash()
|
||||
if len(hash) != 16 {
|
||||
log.Warnf("VNC session user id hash wrong length %d", len(hash))
|
||||
continue
|
||||
}
|
||||
sessionPubKeys = append(sessionPubKeys, sshauth.SessionPubKey{
|
||||
PubKey: pub,
|
||||
UserIDHash: sshuserhash.UserIDHash(hash),
|
||||
DisplayName: e.GetDisplayName(),
|
||||
})
|
||||
}
|
||||
|
||||
vncSrv.UpdateVNCAuth(&sshauth.Config{
|
||||
AuthorizedUsers: authorizedUsers,
|
||||
MachineUsers: machineUsers,
|
||||
SessionPubKeys: sessionPubKeys,
|
||||
})
|
||||
}
|
||||
|
||||
// GetVNCServerStatus returns whether the VNC server is running and the list
|
||||
// of active VNC sessions. The pointer is captured under syncMsgMux so a
|
||||
// concurrent updateVNC/stopVNCServer cannot swap it out between the nil
|
||||
// check and the ActiveSessions call.
|
||||
func (e *Engine) GetVNCServerStatus() (enabled bool, sessions []vncserver.ActiveSessionInfo) {
|
||||
e.syncMsgMux.Lock()
|
||||
vncSrv := e.vncSrv
|
||||
e.syncMsgMux.Unlock()
|
||||
if vncSrv == nil {
|
||||
return false, nil
|
||||
}
|
||||
return true, vncSrv.ActiveSessions()
|
||||
}
|
||||
|
||||
func (e *Engine) stopVNCServer() error {
|
||||
if e.vncSrv == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := e.cleanupVNCPortRedirection(); err != nil {
|
||||
log.Warnf("cleanup VNC port redirection: %v", err)
|
||||
}
|
||||
|
||||
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||
if registrar, ok := e.firewall.(interface {
|
||||
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||
}); ok {
|
||||
registrar.UnregisterNetstackService(nftypes.TCP, vnc.InternalPort)
|
||||
}
|
||||
}
|
||||
|
||||
log.Info("stopping VNC server")
|
||||
err := e.vncSrv.Stop()
|
||||
e.vncSrv = nil
|
||||
if err != nil {
|
||||
return fmt.Errorf("stop VNC server: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// vncApprover adapts the generic approval.Broker for the VNC server.
|
||||
type vncApprover struct {
|
||||
broker *approval.Broker
|
||||
statusRecorder *peer.Status
|
||||
}
|
||||
|
||||
func (a *vncApprover) Request(ctx context.Context, info vncserver.ApprovalInfo) (vncserver.ApprovalDecision, error) {
|
||||
// Resolve the source overlay IP to a peer FQDN for the prompt label.
|
||||
if info.PeerName == "" && info.SourceIP != "" && a.statusRecorder != nil {
|
||||
if fqdn, ok := a.statusRecorder.PeerByIP(info.SourceIP); ok {
|
||||
info.PeerName = fqdn
|
||||
}
|
||||
}
|
||||
subject := fmt.Sprintf("VNC connection from %s", displayPeer(info))
|
||||
meta := map[string]string{
|
||||
"peer_name": info.PeerName,
|
||||
"peer_pubkey": info.PeerPubKey,
|
||||
"source_ip": info.SourceIP,
|
||||
"mode": info.Mode,
|
||||
"username": info.Username,
|
||||
"initiator": info.Initiator,
|
||||
}
|
||||
d, err := a.broker.Request(ctx, approval.Prompt{
|
||||
Kind: approval.KindVNC,
|
||||
Subject: subject,
|
||||
Metadata: meta,
|
||||
})
|
||||
if err != nil {
|
||||
return vncserver.ApprovalDecision{}, err
|
||||
}
|
||||
return vncserver.ApprovalDecision{ViewOnly: d.ViewOnly}, nil
|
||||
}
|
||||
|
||||
func displayPeer(info vncserver.ApprovalInfo) string {
|
||||
if info.Initiator != "" {
|
||||
return info.Initiator
|
||||
}
|
||||
if info.PeerName != "" {
|
||||
return info.PeerName
|
||||
}
|
||||
if info.SourceIP != "" {
|
||||
return info.SourceIP
|
||||
}
|
||||
if info.PeerPubKey != "" {
|
||||
return info.PeerPubKey
|
||||
}
|
||||
return "unknown peer"
|
||||
}
|
||||
@@ -1,31 +0,0 @@
|
||||
//go:build freebsd
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
)
|
||||
|
||||
// newConsoleVNC builds the FreeBSD console fallback: vt(4) framebuffer
|
||||
// for capture, /dev/uinput for input. The uinput device requires the
|
||||
// `uinput` kernel module (`kldload uinput`); without it, input init
|
||||
// fails and we drop to a stub injector so the user still gets a
|
||||
// view-only screen mirror.
|
||||
func newConsoleVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, error) {
|
||||
poller := vncserver.NewFBPoller("")
|
||||
w, h := poller.Width(), poller.Height()
|
||||
if w == 0 || h == 0 {
|
||||
poller.Close()
|
||||
return nil, nil, fmt.Errorf("vt framebuffer init failed (vt may not allow mmap on this driver)")
|
||||
}
|
||||
if inj, err := vncserver.NewUInputInjector(w, h); err == nil {
|
||||
return poller, inj, nil
|
||||
} else {
|
||||
log.Infof("VNC console: uinput unavailable (%v); view-only mode. Run `kldload uinput` to enable input.", err)
|
||||
return poller, &vncserver.StubInputInjector{}, nil
|
||||
}
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
)
|
||||
|
||||
// newConsoleVNC builds a framebuffer + uinput VNC backend for boxes
|
||||
// without a running X server. Used as the auto-fallback when
|
||||
// newPlatformVNC can't reach X. Returns an error when /dev/fb0 or
|
||||
// /dev/uinput aren't usable so the caller can drop back to a stub.
|
||||
func newConsoleVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, error) {
|
||||
poller := vncserver.NewFBPoller("")
|
||||
w, h := poller.Width(), poller.Height()
|
||||
if w == 0 || h == 0 {
|
||||
poller.Close()
|
||||
return nil, nil, fmt.Errorf("framebuffer capturer init failed (is /dev/fb0 readable?)")
|
||||
}
|
||||
inj, err := vncserver.NewUInputInjector(w, h)
|
||||
if err != nil {
|
||||
log.Debugf("uinput unavailable, falling back to view-only VNC: %v", err)
|
||||
return poller, &vncserver.StubInputInjector{}, nil
|
||||
}
|
||||
return poller, inj, nil
|
||||
}
|
||||
@@ -1,34 +0,0 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
)
|
||||
|
||||
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, bool) {
|
||||
capturer := vncserver.NewMacPoller()
|
||||
// Prompt for Screen Recording at server-enable time rather than first
|
||||
// client-connect. The native prompt is far easier for users to act on
|
||||
// in the moment they toggled VNC on than later when "the screen looks
|
||||
// like wallpaper" would otherwise be the only clue.
|
||||
vncserver.PrimeScreenCapturePermission()
|
||||
injector, err := vncserver.NewMacInputInjector()
|
||||
if err != nil {
|
||||
log.Debugf("VNC: macOS input injector: %v", err)
|
||||
return capturer, &vncserver.StubInputInjector{}, true
|
||||
}
|
||||
return capturer, injector, true
|
||||
}
|
||||
|
||||
// vncNeedsServiceMode reports whether the running process is a system
|
||||
// LaunchDaemon (root, parented by launchd). Daemons sit in the global
|
||||
// bootstrap namespace and cannot talk to WindowServer; we route capture
|
||||
// through a per-user agent in that case.
|
||||
func vncNeedsServiceMode() bool {
|
||||
return os.Geteuid() == 0 && os.Getppid() == 1
|
||||
}
|
||||
@@ -1,23 +0,0 @@
|
||||
//go:build js || ios || android
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
type vncServer interface{}
|
||||
|
||||
func (e *Engine) updateVNC() error { return nil }
|
||||
|
||||
func (e *Engine) updateVNCServerAuth(auth *mgmProto.VNCAuth) {
|
||||
if auth == nil {
|
||||
return
|
||||
}
|
||||
log.Debugf("ignoring VNC auth push on platform without a VNC server: %d session pubkeys, %d authorized users",
|
||||
len(auth.GetSessionPubKeys()), len(auth.GetAuthorizedUsers()))
|
||||
}
|
||||
|
||||
func (e *Engine) stopVNCServer() error { return nil }
|
||||
@@ -1,13 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package internal
|
||||
|
||||
import vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
|
||||
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, bool) {
|
||||
return vncserver.NewDesktopCapturer(), vncserver.NewWindowsInputInjector(), true
|
||||
}
|
||||
|
||||
func vncNeedsServiceMode() bool {
|
||||
return vncserver.GetCurrentSessionID() == 0
|
||||
}
|
||||
@@ -1,35 +0,0 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
)
|
||||
|
||||
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, bool) {
|
||||
// Prefer X11 when an X server is reachable. NewX11InputInjector probes
|
||||
// DISPLAY (and /proc) eagerly, so a non-nil error here means no X.
|
||||
injector, err := vncserver.NewX11InputInjector("", "", "")
|
||||
if err == nil {
|
||||
return vncserver.NewX11Poller("", ""), injector, true
|
||||
}
|
||||
log.Debugf("VNC: X11 not available: %v", err)
|
||||
|
||||
// Fallback for headless / pre-X states (kernel console, login manager
|
||||
// without X, physical server in recovery): stream the framebuffer and
|
||||
// inject input via /dev/uinput.
|
||||
consoleCap, consoleInj, err := newConsoleVNC()
|
||||
if err == nil {
|
||||
log.Infof("VNC: using framebuffer console capture (%dx%d)", consoleCap.Width(), consoleCap.Height())
|
||||
return consoleCap, consoleInj, true
|
||||
}
|
||||
log.Debugf("VNC: framebuffer console fallback unavailable: %v", err)
|
||||
|
||||
return &vncserver.StubCapturer{}, &vncserver.StubInputInjector{}, false
|
||||
}
|
||||
|
||||
func vncNeedsServiceMode() bool {
|
||||
return false
|
||||
}
|
||||
@@ -120,36 +120,6 @@ func (m *influxDBMetrics) RecordSyncDuration(_ context.Context, agentInfo AgentI
|
||||
m.trimLocked()
|
||||
}
|
||||
|
||||
func (m *influxDBMetrics) RecordVNCSessionTick(_ context.Context, agentInfo AgentInfo, tick VNCSessionTick) {
|
||||
tags := fmt.Sprintf("deployment_type=%s,version=%s,os=%s,arch=%s,peer_id=%s",
|
||||
agentInfo.DeploymentType.String(),
|
||||
agentInfo.Version,
|
||||
agentInfo.OS,
|
||||
agentInfo.Arch,
|
||||
agentInfo.peerID,
|
||||
)
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.samples = append(m.samples, influxSample{
|
||||
measurement: "netbird_vnc_traffic",
|
||||
tags: tags,
|
||||
fields: map[string]float64{
|
||||
"period_seconds": tick.Period.Seconds(),
|
||||
"bytes_out": float64(tick.BytesOut),
|
||||
"writes": float64(tick.Writes),
|
||||
"fbus": float64(tick.FBUs),
|
||||
"max_fbu_bytes": float64(tick.MaxFBUBytes),
|
||||
"max_fbu_rects": float64(tick.MaxFBURects),
|
||||
"max_write_bytes": float64(tick.MaxWriteBytes),
|
||||
"write_time_seconds": float64(tick.WriteNanos) / 1e9,
|
||||
},
|
||||
timestamp: time.Now(),
|
||||
})
|
||||
m.trimLocked()
|
||||
}
|
||||
|
||||
func (m *influxDBMetrics) RecordLoginDuration(_ context.Context, agentInfo AgentInfo, duration time.Duration, success bool) {
|
||||
result := "success"
|
||||
if !success {
|
||||
|
||||
@@ -59,11 +59,6 @@ type metricsImplementation interface {
|
||||
// RecordLoginDuration records how long the login to management took
|
||||
RecordLoginDuration(ctx context.Context, agentInfo AgentInfo, duration time.Duration, success bool)
|
||||
|
||||
// RecordVNCSessionTick records a periodic snapshot of one VNC
|
||||
// session's wire activity. Called once per metricsConn tick interval
|
||||
// (and once at session close), only when the tick saw activity.
|
||||
RecordVNCSessionTick(ctx context.Context, agentInfo AgentInfo, tick VNCSessionTick)
|
||||
|
||||
// Export exports metrics in InfluxDB line protocol format
|
||||
Export(w io.Writer) error
|
||||
|
||||
@@ -83,21 +78,6 @@ type ClientMetrics struct {
|
||||
pushCancel context.CancelFunc
|
||||
}
|
||||
|
||||
// VNCSessionTick is one sampling slice of a VNC session's wire activity.
|
||||
// BytesOut / Writes / FBUs / WriteNanos are deltas observed during this
|
||||
// tick; Max* fields are the high-water marks observed during the tick.
|
||||
// Period is the wall-clock duration the deltas cover.
|
||||
type VNCSessionTick struct {
|
||||
Period time.Duration
|
||||
BytesOut uint64
|
||||
Writes uint64
|
||||
FBUs uint64
|
||||
MaxFBUBytes uint64
|
||||
MaxFBURects uint64
|
||||
MaxWriteBytes uint64
|
||||
WriteNanos uint64
|
||||
}
|
||||
|
||||
// ConnectionStageTimestamps holds timestamps for each connection stage
|
||||
type ConnectionStageTimestamps struct {
|
||||
SignalingReceived time.Time // First signal received from remote peer (both initial and reconnection)
|
||||
@@ -147,17 +127,6 @@ func (c *ClientMetrics) RecordSyncDuration(ctx context.Context, duration time.Du
|
||||
c.impl.RecordSyncDuration(ctx, agentInfo, duration)
|
||||
}
|
||||
|
||||
// RecordVNCSessionTick records a periodic snapshot of one VNC session.
|
||||
func (c *ClientMetrics) RecordVNCSessionTick(ctx context.Context, tick VNCSessionTick) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.mu.RLock()
|
||||
agentInfo := c.agentInfo
|
||||
c.mu.RUnlock()
|
||||
c.impl.RecordVNCSessionTick(ctx, agentInfo, tick)
|
||||
}
|
||||
|
||||
// RecordLoginDuration records how long the login to management server took
|
||||
func (c *ClientMetrics) RecordLoginDuration(ctx context.Context, duration time.Duration, success bool) {
|
||||
if c == nil {
|
||||
|
||||
@@ -73,9 +73,6 @@ func (m *mockMetrics) RecordSyncDuration(_ context.Context, _ AgentInfo, _ time.
|
||||
func (m *mockMetrics) RecordLoginDuration(_ context.Context, _ AgentInfo, _ time.Duration, _ bool) {
|
||||
}
|
||||
|
||||
func (m *mockMetrics) RecordVNCSessionTick(_ context.Context, _ AgentInfo, _ VNCSessionTick) {
|
||||
}
|
||||
|
||||
func (m *mockMetrics) Export(w io.Writer) error {
|
||||
if m.exportData != "" {
|
||||
_, err := w.Write([]byte(m.exportData))
|
||||
|
||||
@@ -1191,15 +1191,6 @@ func (d *Status) SubscribeToEvents() *EventSubscription {
|
||||
}
|
||||
}
|
||||
|
||||
// HasEventSubscribers reports whether any client is currently subscribed
|
||||
// to the daemon's SystemEvent stream. Used by the VNC approval broker to
|
||||
// fail closed when no UI is connected to prompt the user.
|
||||
func (d *Status) HasEventSubscribers() bool {
|
||||
d.eventMux.Lock()
|
||||
defer d.eventMux.Unlock()
|
||||
return len(d.eventStreams) > 0
|
||||
}
|
||||
|
||||
// UnsubscribeFromEvents removes an event subscription
|
||||
func (d *Status) UnsubscribeFromEvents(sub *EventSubscription) {
|
||||
if sub == nil {
|
||||
|
||||
@@ -65,8 +65,6 @@ type ConfigInput struct {
|
||||
StateFilePath string
|
||||
PreSharedKey *string
|
||||
ServerSSHAllowed *bool
|
||||
ServerVNCAllowed *bool
|
||||
DisableVNCApproval *bool
|
||||
EnableSSHRoot *bool
|
||||
EnableSSHSFTP *bool
|
||||
EnableSSHLocalPortForwarding *bool
|
||||
@@ -118,8 +116,6 @@ type Config struct {
|
||||
RosenpassEnabled bool
|
||||
RosenpassPermissive bool
|
||||
ServerSSHAllowed *bool
|
||||
ServerVNCAllowed *bool
|
||||
DisableVNCApproval *bool
|
||||
EnableSSHRoot *bool
|
||||
EnableSSHSFTP *bool
|
||||
EnableSSHLocalPortForwarding *bool
|
||||
@@ -422,33 +418,6 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.ServerVNCAllowed != nil {
|
||||
if config.ServerVNCAllowed == nil || *input.ServerVNCAllowed != *config.ServerVNCAllowed {
|
||||
if *input.ServerVNCAllowed {
|
||||
log.Infof("enabling VNC server")
|
||||
} else {
|
||||
log.Infof("disabling VNC server")
|
||||
}
|
||||
config.ServerVNCAllowed = input.ServerVNCAllowed
|
||||
updated = true
|
||||
}
|
||||
} else if config.ServerVNCAllowed == nil {
|
||||
config.ServerVNCAllowed = util.False()
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.DisableVNCApproval != nil {
|
||||
if config.DisableVNCApproval == nil || *input.DisableVNCApproval != *config.DisableVNCApproval {
|
||||
if *input.DisableVNCApproval {
|
||||
log.Infof("disabling VNC connection approval prompt")
|
||||
} else {
|
||||
log.Infof("enabling VNC connection approval prompt")
|
||||
}
|
||||
config.DisableVNCApproval = input.DisableVNCApproval
|
||||
updated = true
|
||||
}
|
||||
}
|
||||
|
||||
if input.EnableSSHRoot != nil && input.EnableSSHRoot != config.EnableSSHRoot {
|
||||
if *input.EnableSSHRoot {
|
||||
log.Infof("enabling SSH root login")
|
||||
|
||||
@@ -188,9 +188,7 @@ func (d *Detector) triggerCallback(event EventType, cb func(event EventType), do
|
||||
}
|
||||
|
||||
doneChan := make(chan struct{})
|
||||
// macOS forces sleep ~30s after kIOMessageSystemWillSleep, so block long
|
||||
// enough for teardown to finish while staying under that deadline.
|
||||
timeout := time.NewTimer(20 * time.Second)
|
||||
timeout := time.NewTimer(500 * time.Millisecond)
|
||||
defer timeout.Stop()
|
||||
|
||||
go func() {
|
||||
|
||||
@@ -74,14 +74,6 @@ func New(filePath string) *Manager {
|
||||
}
|
||||
}
|
||||
|
||||
// FilePath returns the path of the underlying state file.
|
||||
func (m *Manager) FilePath() string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
return m.filePath
|
||||
}
|
||||
|
||||
// Start starts the state manager periodic save routine
|
||||
func (m *Manager) Start() {
|
||||
if m == nil {
|
||||
|
||||
@@ -64,13 +64,6 @@
|
||||
<RegistryValue Name="InstalledByMSI" Type="integer" Value="1" KeyPath="yes" />
|
||||
</RegistryKey>
|
||||
</Component>
|
||||
<!-- Drop the HKCU Run\Netbird value written by legacy NSIS installers. -->
|
||||
<Component Id="NetbirdLegacyHKCUCleanup" Guid="*">
|
||||
<RegistryValue Root="HKCU" Key="Software\NetBird GmbH\Installer"
|
||||
Name="LegacyHKCUCleanup" Type="integer" Value="1" KeyPath="yes" />
|
||||
<RemoveRegistryValue Root="HKCU"
|
||||
Key="Software\Microsoft\Windows\CurrentVersion\Run" Name="Netbird" />
|
||||
</Component>
|
||||
</StandardDirectory>
|
||||
|
||||
<StandardDirectory Id="CommonAppDataFolder">
|
||||
@@ -83,28 +76,10 @@
|
||||
</Directory>
|
||||
</StandardDirectory>
|
||||
|
||||
<!-- Drop Run, App Paths and Uninstall entries written by legacy NSIS
|
||||
installers into the 32-bit registry view (HKLM\Software\Wow6432Node). -->
|
||||
<Component Id="NetbirdLegacyWow6432Cleanup" Directory="NetbirdInstallDir"
|
||||
Guid="bda5d628-16bd-4086-b2c1-5099d8d51763" Bitness="always32">
|
||||
<RegistryValue Root="HKLM" Key="Software\NetBird GmbH\Installer"
|
||||
Name="LegacyWow6432Cleanup" Type="integer" Value="1" KeyPath="yes" />
|
||||
<RemoveRegistryValue Root="HKLM"
|
||||
Key="Software\Microsoft\Windows\CurrentVersion\Run" Name="Netbird" />
|
||||
<RemoveRegistryKey Action="removeOnInstall" Root="HKLM"
|
||||
Key="Software\Microsoft\Windows\CurrentVersion\App Paths\Netbird" />
|
||||
<RemoveRegistryKey Action="removeOnInstall" Root="HKLM"
|
||||
Key="Software\Microsoft\Windows\CurrentVersion\App Paths\Netbird-ui" />
|
||||
<RemoveRegistryKey Action="removeOnInstall" Root="HKLM"
|
||||
Key="Software\Microsoft\Windows\CurrentVersion\Uninstall\Netbird" />
|
||||
</Component>
|
||||
|
||||
<ComponentGroup Id="NetbirdFilesComponent">
|
||||
<ComponentRef Id="NetbirdFiles" />
|
||||
<ComponentRef Id="NetbirdAumidRegistry" />
|
||||
<ComponentRef Id="NetbirdAutoStart" />
|
||||
<ComponentRef Id="NetbirdLegacyHKCUCleanup" />
|
||||
<ComponentRef Id="NetbirdLegacyWow6432Cleanup" />
|
||||
</ComponentGroup>
|
||||
|
||||
<util:CloseApplication Id="CloseNetBird" CloseMessage="no" Target="netbird.exe" RebootPrompt="no" />
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -119,14 +119,6 @@ service DaemonService {
|
||||
|
||||
// ExposeService exposes a local port via the NetBird reverse proxy
|
||||
rpc ExposeService(ExposeServiceRequest) returns (stream ExposeServiceEvent) {}
|
||||
|
||||
// RespondApproval delivers the user's accept/deny decision for a
|
||||
// pending user-approval prompt. The daemon pushes the prompt as a
|
||||
// SystemEvent with category APPROVAL and metadata key "request_id";
|
||||
// the UI calls this RPC with the same request_id to unblock whichever
|
||||
// subsystem (VNC, SSH, ...) is waiting. The "kind" metadata key tells
|
||||
// the UI which subsystem the prompt belongs to.
|
||||
rpc RespondApproval(RespondApprovalRequest) returns (RespondApprovalResponse) {}
|
||||
}
|
||||
|
||||
|
||||
@@ -213,10 +205,6 @@ message LoginRequest {
|
||||
optional bool disableSSHAuth = 38;
|
||||
optional int32 sshJWTCacheTTL = 39;
|
||||
optional bool disable_ipv6 = 40;
|
||||
|
||||
optional bool serverVNCAllowed = 41;
|
||||
|
||||
optional bool disableVNCApproval = 42;
|
||||
}
|
||||
|
||||
message LoginResponse {
|
||||
@@ -326,10 +314,6 @@ message GetConfigResponse {
|
||||
int32 sshJWTCacheTTL = 26;
|
||||
|
||||
bool disable_ipv6 = 27;
|
||||
|
||||
bool serverVNCAllowed = 28;
|
||||
|
||||
bool disableVNCApproval = 29;
|
||||
}
|
||||
|
||||
// PeerState contains the latest state of a peer
|
||||
@@ -410,25 +394,6 @@ message SSHServerState {
|
||||
repeated SSHSessionInfo sessions = 2;
|
||||
}
|
||||
|
||||
// VNCSessionInfo contains information about an active VNC session
|
||||
message VNCSessionInfo {
|
||||
string remoteAddress = 1;
|
||||
string mode = 2;
|
||||
string username = 3;
|
||||
// userID is the Noise-verified session identity (hashed user ID from
|
||||
// the ACL session-key entry), empty when auth is disabled.
|
||||
string userID = 4;
|
||||
// initiator is the human-readable display name of the dashboard user
|
||||
// who minted the SessionPubKey, when known.
|
||||
string initiator = 5;
|
||||
}
|
||||
|
||||
// VNCServerState contains the latest state of the VNC server
|
||||
message VNCServerState {
|
||||
bool enabled = 1;
|
||||
repeated VNCSessionInfo sessions = 2;
|
||||
}
|
||||
|
||||
// FullStatus contains the full state held by the Status instance
|
||||
message FullStatus {
|
||||
ManagementState managementState = 1;
|
||||
@@ -443,7 +408,6 @@ message FullStatus {
|
||||
|
||||
bool lazyConnectionEnabled = 9;
|
||||
SSHServerState sshServerState = 10;
|
||||
VNCServerState vncServerState = 11;
|
||||
}
|
||||
|
||||
// Networks
|
||||
@@ -631,7 +595,6 @@ message SystemEvent {
|
||||
AUTHENTICATION = 2;
|
||||
CONNECTIVITY = 3;
|
||||
SYSTEM = 4;
|
||||
APPROVAL = 5;
|
||||
}
|
||||
|
||||
string id = 1;
|
||||
@@ -715,10 +678,6 @@ message SetConfigRequest {
|
||||
optional bool disableSSHAuth = 33;
|
||||
optional int32 sshJWTCacheTTL = 34;
|
||||
optional bool disable_ipv6 = 35;
|
||||
|
||||
optional bool serverVNCAllowed = 36;
|
||||
|
||||
optional bool disableVNCApproval = 37;
|
||||
}
|
||||
|
||||
message SetConfigResponse{}
|
||||
@@ -913,18 +872,3 @@ message StartBundleCaptureRequest {
|
||||
message StartBundleCaptureResponse {}
|
||||
message StopBundleCaptureRequest {}
|
||||
message StopBundleCaptureResponse {}
|
||||
|
||||
message RespondApprovalRequest {
|
||||
// request_id matches the SystemEvent metadata key emitted by the daemon
|
||||
// when a subsystem awaits user approval for an inbound connection.
|
||||
string request_id = 1;
|
||||
// accept is true if the user approved the request, false if they
|
||||
// denied it. A missing or unknown request_id is treated as a no-op.
|
||||
bool accept = 2;
|
||||
// view_only signals that the user granted the connection but withheld
|
||||
// input control. Only meaningful when accept is true; ignored when
|
||||
// accept is false.
|
||||
bool view_only = 3;
|
||||
}
|
||||
|
||||
message RespondApprovalResponse {}
|
||||
|
||||
@@ -58,7 +58,6 @@ const (
|
||||
DaemonService_StopCPUProfile_FullMethodName = "/daemon.DaemonService/StopCPUProfile"
|
||||
DaemonService_GetInstallerResult_FullMethodName = "/daemon.DaemonService/GetInstallerResult"
|
||||
DaemonService_ExposeService_FullMethodName = "/daemon.DaemonService/ExposeService"
|
||||
DaemonService_RespondApproval_FullMethodName = "/daemon.DaemonService/RespondApproval"
|
||||
)
|
||||
|
||||
// DaemonServiceClient is the client API for DaemonService service.
|
||||
@@ -135,13 +134,6 @@ type DaemonServiceClient interface {
|
||||
GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error)
|
||||
// ExposeService exposes a local port via the NetBird reverse proxy
|
||||
ExposeService(ctx context.Context, in *ExposeServiceRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[ExposeServiceEvent], error)
|
||||
// RespondApproval delivers the user's accept/deny decision for a
|
||||
// pending user-approval prompt. The daemon pushes the prompt as a
|
||||
// SystemEvent with category APPROVAL and metadata key "request_id";
|
||||
// the UI calls this RPC with the same request_id to unblock whichever
|
||||
// subsystem (VNC, SSH, ...) is waiting. The "kind" metadata key tells
|
||||
// the UI which subsystem the prompt belongs to.
|
||||
RespondApproval(ctx context.Context, in *RespondApprovalRequest, opts ...grpc.CallOption) (*RespondApprovalResponse, error)
|
||||
}
|
||||
|
||||
type daemonServiceClient struct {
|
||||
@@ -569,16 +561,6 @@ func (c *daemonServiceClient) ExposeService(ctx context.Context, in *ExposeServi
|
||||
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
|
||||
type DaemonService_ExposeServiceClient = grpc.ServerStreamingClient[ExposeServiceEvent]
|
||||
|
||||
func (c *daemonServiceClient) RespondApproval(ctx context.Context, in *RespondApprovalRequest, opts ...grpc.CallOption) (*RespondApprovalResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(RespondApprovalResponse)
|
||||
err := c.cc.Invoke(ctx, DaemonService_RespondApproval_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// DaemonServiceServer is the server API for DaemonService service.
|
||||
// All implementations must embed UnimplementedDaemonServiceServer
|
||||
// for forward compatibility.
|
||||
@@ -653,13 +635,6 @@ type DaemonServiceServer interface {
|
||||
GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error)
|
||||
// ExposeService exposes a local port via the NetBird reverse proxy
|
||||
ExposeService(*ExposeServiceRequest, grpc.ServerStreamingServer[ExposeServiceEvent]) error
|
||||
// RespondApproval delivers the user's accept/deny decision for a
|
||||
// pending user-approval prompt. The daemon pushes the prompt as a
|
||||
// SystemEvent with category APPROVAL and metadata key "request_id";
|
||||
// the UI calls this RPC with the same request_id to unblock whichever
|
||||
// subsystem (VNC, SSH, ...) is waiting. The "kind" metadata key tells
|
||||
// the UI which subsystem the prompt belongs to.
|
||||
RespondApproval(context.Context, *RespondApprovalRequest) (*RespondApprovalResponse, error)
|
||||
mustEmbedUnimplementedDaemonServiceServer()
|
||||
}
|
||||
|
||||
@@ -787,9 +762,6 @@ func (UnimplementedDaemonServiceServer) GetInstallerResult(context.Context, *Ins
|
||||
func (UnimplementedDaemonServiceServer) ExposeService(*ExposeServiceRequest, grpc.ServerStreamingServer[ExposeServiceEvent]) error {
|
||||
return status.Error(codes.Unimplemented, "method ExposeService not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) RespondApproval(context.Context, *RespondApprovalRequest) (*RespondApprovalResponse, error) {
|
||||
return nil, status.Error(codes.Unimplemented, "method RespondApproval not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
|
||||
func (UnimplementedDaemonServiceServer) testEmbeddedByValue() {}
|
||||
|
||||
@@ -1492,24 +1464,6 @@ func _DaemonService_ExposeService_Handler(srv interface{}, stream grpc.ServerStr
|
||||
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
|
||||
type DaemonService_ExposeServiceServer = grpc.ServerStreamingServer[ExposeServiceEvent]
|
||||
|
||||
func _DaemonService_RespondApproval_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(RespondApprovalRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DaemonServiceServer).RespondApproval(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: DaemonService_RespondApproval_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).RespondApproval(ctx, req.(*RespondApprovalRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
// and not to be introspected or modified (even as a copy)
|
||||
@@ -1661,10 +1615,6 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
|
||||
MethodName: "GetInstallerResult",
|
||||
Handler: _DaemonService_GetInstallerResult_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "RespondApproval",
|
||||
Handler: _DaemonService_RespondApproval_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{
|
||||
{
|
||||
|
||||
@@ -111,7 +111,7 @@ func (s *Server) StartCapture(req *proto.StartCaptureRequest, stream proto.Daemo
|
||||
return status.Errorf(codes.Internal, "create capture session: %v", err)
|
||||
}
|
||||
|
||||
engine, err := s.claimCapture(sess, func() { pw.Close() })
|
||||
engine, err := s.claimCapture(sess)
|
||||
if err != nil {
|
||||
sess.Stop()
|
||||
pw.Close()
|
||||
@@ -190,7 +190,10 @@ func (s *Server) StartBundleCapture(_ context.Context, req *proto.StartBundleCap
|
||||
|
||||
s.stopBundleCaptureLocked()
|
||||
s.cleanupBundleCapture()
|
||||
s.evictActiveCaptureLocked()
|
||||
|
||||
if s.activeCapture != nil {
|
||||
return nil, status.Error(codes.FailedPrecondition, "another capture is already running")
|
||||
}
|
||||
|
||||
engine, err := s.getCaptureEngineLocked()
|
||||
if err != nil {
|
||||
@@ -301,58 +304,29 @@ func (s *Server) cleanupBundleCapture() {
|
||||
s.bundleCapture = nil
|
||||
}
|
||||
|
||||
// claimCapture reserves the engine's capture slot for sess. If another
|
||||
// capture is already running it is evicted: a previous streaming session
|
||||
// whose gRPC client died and never freed the slot stays stuck otherwise,
|
||||
// and a bundle capture is just informational state.
|
||||
func (s *Server) claimCapture(sess *capture.Session, cancel func()) (*internal.Engine, error) {
|
||||
// claimCapture reserves the engine's capture slot for sess. Returns
|
||||
// FailedPrecondition if another capture is already active.
|
||||
func (s *Server) claimCapture(sess *capture.Session) (*internal.Engine, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
s.evictActiveCaptureLocked()
|
||||
if s.activeCapture != nil {
|
||||
return nil, status.Error(codes.FailedPrecondition, "another capture is already running")
|
||||
}
|
||||
engine, err := s.getCaptureEngineLocked()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.activeCapture = sess
|
||||
s.activeCaptureCancel = cancel
|
||||
return engine, nil
|
||||
}
|
||||
|
||||
// evictActiveCaptureLocked tears down whatever capture currently owns
|
||||
// the engine slot so a fresh claim can succeed. Caller must hold mutex.
|
||||
func (s *Server) evictActiveCaptureLocked() {
|
||||
if s.activeCapture == nil {
|
||||
return
|
||||
}
|
||||
if s.bundleCapture != nil && s.bundleCapture.sess == s.activeCapture {
|
||||
log.Infof("evicting running bundle capture to start a new capture")
|
||||
s.stopBundleCaptureLocked()
|
||||
return
|
||||
}
|
||||
log.Infof("evicting previous streaming capture to start a new one")
|
||||
prev := s.activeCapture
|
||||
cancel := s.activeCaptureCancel
|
||||
if engine, err := s.getCaptureEngineLocked(); err == nil {
|
||||
if err := engine.SetCapture(nil); err != nil {
|
||||
log.Debugf("clear previous capture: %v", err)
|
||||
}
|
||||
}
|
||||
s.activeCapture = nil
|
||||
s.activeCaptureCancel = nil
|
||||
prev.Stop()
|
||||
if cancel != nil {
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
|
||||
// releaseCapture clears the active-capture owner if it still matches sess.
|
||||
func (s *Server) releaseCapture(sess *capture.Session) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
if s.activeCapture == sess {
|
||||
s.activeCapture = nil
|
||||
s.activeCaptureCancel = nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -367,7 +341,6 @@ func (s *Server) clearCaptureIfOwner(sess *capture.Session, engine *internal.Eng
|
||||
log.Debugf("clear capture: %v", err)
|
||||
}
|
||||
s.activeCapture = nil
|
||||
s.activeCaptureCancel = nil
|
||||
}
|
||||
|
||||
func (s *Server) getCaptureEngineLocked() (*internal.Engine, error) {
|
||||
|
||||
@@ -93,12 +93,8 @@ type Server struct {
|
||||
captureEnabled bool
|
||||
bundleCapture *bundleCapture
|
||||
// activeCapture is the session currently installed on the engine; guarded by s.mutex.
|
||||
activeCapture *capture.Session
|
||||
// activeCaptureCancel tears down the streaming pipe/cancel for the
|
||||
// active streaming capture so eviction unblocks the StartCapture RPC
|
||||
// handler. Nil for bundle captures (they own their own context).
|
||||
activeCaptureCancel func()
|
||||
networksDisabled bool
|
||||
activeCapture *capture.Session
|
||||
networksDisabled bool
|
||||
|
||||
sleepHandler *sleephandler.SleepHandler
|
||||
|
||||
@@ -380,8 +376,6 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
|
||||
config.RosenpassPermissive = msg.RosenpassPermissive
|
||||
config.DisableAutoConnect = msg.DisableAutoConnect
|
||||
config.ServerSSHAllowed = msg.ServerSSHAllowed
|
||||
config.ServerVNCAllowed = msg.ServerVNCAllowed
|
||||
config.DisableVNCApproval = msg.DisableVNCApproval
|
||||
config.NetworkMonitor = msg.NetworkMonitor
|
||||
config.DisableClientRoutes = msg.DisableClientRoutes
|
||||
config.DisableServerRoutes = msg.DisableServerRoutes
|
||||
@@ -1142,7 +1136,6 @@ func (s *Server) Status(
|
||||
pbFullStatus := fullStatus.ToProto()
|
||||
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
|
||||
pbFullStatus.SshServerState = s.getSSHServerState()
|
||||
pbFullStatus.VncServerState = s.getVNCServerState()
|
||||
statusResponse.FullStatus = pbFullStatus
|
||||
}
|
||||
|
||||
@@ -1182,38 +1175,6 @@ func (s *Server) getSSHServerState() *proto.SSHServerState {
|
||||
return sshServerState
|
||||
}
|
||||
|
||||
// getVNCServerState retrieves the current VNC server state.
|
||||
func (s *Server) getVNCServerState() *proto.VNCServerState {
|
||||
s.mutex.Lock()
|
||||
connectClient := s.connectClient
|
||||
s.mutex.Unlock()
|
||||
|
||||
if connectClient == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
engine := connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
enabled, sessions := engine.GetVNCServerStatus()
|
||||
pbSessions := make([]*proto.VNCSessionInfo, 0, len(sessions))
|
||||
for _, sess := range sessions {
|
||||
pbSessions = append(pbSessions, &proto.VNCSessionInfo{
|
||||
RemoteAddress: sess.RemoteAddress,
|
||||
Mode: sess.Mode,
|
||||
Username: sess.Username,
|
||||
UserID: sess.UserID,
|
||||
Initiator: sess.Initiator,
|
||||
})
|
||||
}
|
||||
return &proto.VNCServerState{
|
||||
Enabled: enabled,
|
||||
Sessions: pbSessions,
|
||||
}
|
||||
}
|
||||
|
||||
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
|
||||
func (s *Server) GetPeerSSHHostKey(
|
||||
ctx context.Context,
|
||||
@@ -1454,27 +1415,6 @@ func (s *Server) ExposeService(req *proto.ExposeServiceRequest, srv proto.Daemon
|
||||
return nil
|
||||
}
|
||||
|
||||
// RespondApproval relays the user's accept/deny decision for a pending
|
||||
// approval prompt to the engine's broker. Unknown or already-resolved
|
||||
// request_ids are silently no-op'd so a slow UI cannot deny a prompt the
|
||||
// user already handled (or that already timed out).
|
||||
func (s *Server) RespondApproval(_ context.Context, msg *proto.RespondApprovalRequest) (*proto.RespondApprovalResponse, error) {
|
||||
s.mutex.Lock()
|
||||
connectClient := s.connectClient
|
||||
s.mutex.Unlock()
|
||||
if connectClient == nil {
|
||||
return nil, gstatus.Errorf(codes.FailedPrecondition, "client not initialized")
|
||||
}
|
||||
engine := connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil, gstatus.Errorf(codes.FailedPrecondition, "engine not running")
|
||||
}
|
||||
if !engine.RespondApproval(msg.GetRequestId(), msg.GetAccept(), msg.GetViewOnly()) {
|
||||
log.Debugf("approval response for unknown request_id %s", msg.GetRequestId())
|
||||
}
|
||||
return &proto.RespondApprovalResponse{}, nil
|
||||
}
|
||||
|
||||
func isUnixRunningDesktop() bool {
|
||||
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||
return false
|
||||
@@ -1591,8 +1531,6 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
|
||||
Mtu: int64(cfg.MTU),
|
||||
DisableAutoConnect: cfg.DisableAutoConnect,
|
||||
ServerSSHAllowed: *cfg.ServerSSHAllowed,
|
||||
ServerVNCAllowed: cfg.ServerVNCAllowed != nil && *cfg.ServerVNCAllowed,
|
||||
DisableVNCApproval: cfg.DisableVNCApproval != nil && *cfg.DisableVNCApproval,
|
||||
RosenpassEnabled: cfg.RosenpassEnabled,
|
||||
RosenpassPermissive: cfg.RosenpassPermissive,
|
||||
LazyConnectionEnabled: cfg.LazyConnectionEnabled,
|
||||
|
||||
@@ -58,8 +58,6 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
rosenpassEnabled := true
|
||||
rosenpassPermissive := true
|
||||
serverSSHAllowed := true
|
||||
serverVNCAllowed := true
|
||||
disableVNCApproval := true
|
||||
interfaceName := "utun100"
|
||||
wireguardPort := int64(51820)
|
||||
preSharedKey := "test-psk"
|
||||
@@ -85,8 +83,6 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
RosenpassEnabled: &rosenpassEnabled,
|
||||
RosenpassPermissive: &rosenpassPermissive,
|
||||
ServerSSHAllowed: &serverSSHAllowed,
|
||||
ServerVNCAllowed: &serverVNCAllowed,
|
||||
DisableVNCApproval: &disableVNCApproval,
|
||||
InterfaceName: &interfaceName,
|
||||
WireguardPort: &wireguardPort,
|
||||
OptionalPreSharedKey: &preSharedKey,
|
||||
@@ -131,10 +127,6 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
require.Equal(t, rosenpassPermissive, cfg.RosenpassPermissive)
|
||||
require.NotNil(t, cfg.ServerSSHAllowed)
|
||||
require.Equal(t, serverSSHAllowed, *cfg.ServerSSHAllowed)
|
||||
require.NotNil(t, cfg.ServerVNCAllowed)
|
||||
require.Equal(t, serverVNCAllowed, *cfg.ServerVNCAllowed)
|
||||
require.NotNil(t, cfg.DisableVNCApproval)
|
||||
require.Equal(t, disableVNCApproval, *cfg.DisableVNCApproval)
|
||||
require.Equal(t, interfaceName, cfg.WgIface)
|
||||
require.Equal(t, int(wireguardPort), cfg.WgPort)
|
||||
require.Equal(t, preSharedKey, cfg.PreSharedKey)
|
||||
@@ -187,8 +179,6 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
|
||||
"RosenpassEnabled": true,
|
||||
"RosenpassPermissive": true,
|
||||
"ServerSSHAllowed": true,
|
||||
"ServerVNCAllowed": true,
|
||||
"DisableVNCApproval": true,
|
||||
"InterfaceName": true,
|
||||
"WireguardPort": true,
|
||||
"OptionalPreSharedKey": true,
|
||||
@@ -250,8 +240,6 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) {
|
||||
"enable-rosenpass": "RosenpassEnabled",
|
||||
"rosenpass-permissive": "RosenpassPermissive",
|
||||
"allow-server-ssh": "ServerSSHAllowed",
|
||||
"allow-server-vnc": "ServerVNCAllowed",
|
||||
"disable-vnc-approval": "DisableVNCApproval",
|
||||
"interface-name": "InterfaceName",
|
||||
"wireguard-port": "WireguardPort",
|
||||
"preshared-key": "OptionalPreSharedKey",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package sessionauth
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
@@ -15,16 +15,13 @@ const (
|
||||
DefaultUserIDClaim = "sub"
|
||||
// Wildcard is a special user ID that matches all users
|
||||
Wildcard = "*"
|
||||
// sessionPubKeyLen is the size of an X25519 static public key in bytes.
|
||||
sessionPubKeyLen = 32
|
||||
)
|
||||
|
||||
var (
|
||||
ErrEmptyUserID = errors.New("JWT user ID is empty")
|
||||
ErrUserNotAuthorized = errors.New("user is not authorized to access this peer")
|
||||
ErrNoMachineUserMapping = errors.New("no authorization mapping for OS user")
|
||||
ErrUserNotMappedToOSUser = errors.New("user is not authorized to login as OS user")
|
||||
ErrSessionKeyNotKnown = errors.New("session pubkey not registered")
|
||||
ErrEmptyUserID = errors.New("JWT user ID is empty")
|
||||
ErrUserNotAuthorized = errors.New("user is not authorized to access this peer")
|
||||
ErrNoMachineUserMapping = errors.New("no authorization mapping for OS user")
|
||||
ErrUserNotMappedToOSUser = errors.New("user is not authorized to login as OS user")
|
||||
)
|
||||
|
||||
// Authorizer handles SSH fine-grained access control authorization
|
||||
@@ -38,17 +35,6 @@ type Authorizer struct {
|
||||
// machineUsers maps OS login usernames to lists of authorized user indexes
|
||||
machineUsers map[string][]uint32
|
||||
|
||||
// sessionPubKeys maps an X25519 static public key (as map-safe
|
||||
// array) to the hashed user identity that key authenticates as.
|
||||
// Populated from management's temporary-access flow; used by VNC to
|
||||
// authenticate via the Noise_IK handshake.
|
||||
sessionPubKeys map[[sessionPubKeyLen]byte]sshuserhash.UserIDHash
|
||||
// sessionDisplayNames mirrors sessionPubKeys with the optional
|
||||
// human-readable display name management associated with each
|
||||
// session key. Used by the per-connection UI approval prompt; not
|
||||
// consulted by any authorization decision.
|
||||
sessionDisplayNames map[[sessionPubKeyLen]byte]string
|
||||
|
||||
// mu protects the list of users
|
||||
mu sync.RWMutex
|
||||
}
|
||||
@@ -64,29 +50,13 @@ type Config struct {
|
||||
// MachineUsers maps OS login usernames to indexes in AuthorizedUsers
|
||||
// If a user wants to login as a specific OS user, their index must be in the corresponding list
|
||||
MachineUsers map[string][]uint32
|
||||
|
||||
// SessionPubKeys binds ephemeral X25519 static public keys to hashed
|
||||
// user identities. Populated for VNC; ignored on the SSH side.
|
||||
SessionPubKeys []SessionPubKey
|
||||
}
|
||||
|
||||
// SessionPubKey is a single ephemeral-key entry: the 32-byte X25519
|
||||
// static public key plus the hashed user identity it authenticates as,
|
||||
// optionally plus a human-readable display name for the UI approval
|
||||
// prompt to identify the requester.
|
||||
type SessionPubKey struct {
|
||||
PubKey []byte
|
||||
UserIDHash sshuserhash.UserIDHash
|
||||
DisplayName string
|
||||
}
|
||||
|
||||
// NewAuthorizer creates a new SSH authorizer with empty configuration
|
||||
func NewAuthorizer() *Authorizer {
|
||||
a := &Authorizer{
|
||||
userIDClaim: DefaultUserIDClaim,
|
||||
machineUsers: make(map[string][]uint32),
|
||||
sessionPubKeys: make(map[[sessionPubKeyLen]byte]sshuserhash.UserIDHash),
|
||||
sessionDisplayNames: make(map[[sessionPubKeyLen]byte]string),
|
||||
userIDClaim: DefaultUserIDClaim,
|
||||
machineUsers: make(map[string][]uint32),
|
||||
}
|
||||
|
||||
return a
|
||||
@@ -102,8 +72,6 @@ func (a *Authorizer) Update(config *Config) {
|
||||
a.userIDClaim = DefaultUserIDClaim
|
||||
a.authorizedUsers = []sshuserhash.UserIDHash{}
|
||||
a.machineUsers = make(map[string][]uint32)
|
||||
a.sessionPubKeys = make(map[[sessionPubKeyLen]byte]sshuserhash.UserIDHash)
|
||||
a.sessionDisplayNames = make(map[[sessionPubKeyLen]byte]string)
|
||||
log.Info("SSH authorization cleared")
|
||||
return
|
||||
}
|
||||
@@ -126,35 +94,8 @@ func (a *Authorizer) Update(config *Config) {
|
||||
}
|
||||
a.machineUsers = machineUsers
|
||||
|
||||
sessionPubKeys := make(map[[sessionPubKeyLen]byte]sshuserhash.UserIDHash, len(config.SessionPubKeys))
|
||||
sessionDisplayNames := make(map[[sessionPubKeyLen]byte]string, len(config.SessionPubKeys))
|
||||
conflicted := make(map[[sessionPubKeyLen]byte]struct{})
|
||||
for _, e := range config.SessionPubKeys {
|
||||
if len(e.PubKey) != sessionPubKeyLen {
|
||||
continue
|
||||
}
|
||||
var key [sessionPubKeyLen]byte
|
||||
copy(key[:], e.PubKey)
|
||||
if _, bad := conflicted[key]; bad {
|
||||
continue
|
||||
}
|
||||
if existing, ok := sessionPubKeys[key]; ok && existing != e.UserIDHash {
|
||||
log.Warnf("SSH auth: session pubkey bound to conflicting user hashes; dropping binding")
|
||||
delete(sessionPubKeys, key)
|
||||
delete(sessionDisplayNames, key)
|
||||
conflicted[key] = struct{}{}
|
||||
continue
|
||||
}
|
||||
sessionPubKeys[key] = e.UserIDHash
|
||||
if e.DisplayName != "" {
|
||||
sessionDisplayNames[key] = e.DisplayName
|
||||
}
|
||||
}
|
||||
a.sessionPubKeys = sessionPubKeys
|
||||
a.sessionDisplayNames = sessionDisplayNames
|
||||
|
||||
log.Debugf("SSH auth: updated with %d authorized users, %d machine user mappings, %d session pubkeys",
|
||||
len(config.AuthorizedUsers), len(machineUsers), len(sessionPubKeys))
|
||||
log.Debugf("SSH auth: updated with %d authorized users, %d machine user mappings",
|
||||
len(config.AuthorizedUsers), len(machineUsers))
|
||||
}
|
||||
|
||||
// Authorize validates if a user is authorized to login as the specified OS user.
|
||||
@@ -214,54 +155,6 @@ func (a *Authorizer) GetUserIDClaim() string {
|
||||
return a.userIDClaim
|
||||
}
|
||||
|
||||
// LookupSessionKey resolves a Noise-verified static public key to the
|
||||
// hashed user identity registered with it. Fails closed when the key is
|
||||
// unknown.
|
||||
func (a *Authorizer) LookupSessionKey(pubKey []byte) (sshuserhash.UserIDHash, error) {
|
||||
var zero sshuserhash.UserIDHash
|
||||
if len(pubKey) != sessionPubKeyLen {
|
||||
return zero, fmt.Errorf("session pubkey wrong length: %d", len(pubKey))
|
||||
}
|
||||
var key [sessionPubKeyLen]byte
|
||||
copy(key[:], pubKey)
|
||||
a.mu.RLock()
|
||||
hash, ok := a.sessionPubKeys[key]
|
||||
a.mu.RUnlock()
|
||||
if !ok {
|
||||
return zero, ErrSessionKeyNotKnown
|
||||
}
|
||||
return hash, nil
|
||||
}
|
||||
|
||||
// LookupSessionDisplayName returns the human-readable display name
|
||||
// management associated with a session pubkey, or empty string when none
|
||||
// is recorded. Never returns an error: a missing/unknown key reports as
|
||||
// "" and the caller falls back to other identifiers.
|
||||
func (a *Authorizer) LookupSessionDisplayName(pubKey []byte) string {
|
||||
if len(pubKey) != sessionPubKeyLen {
|
||||
return ""
|
||||
}
|
||||
var key [sessionPubKeyLen]byte
|
||||
copy(key[:], pubKey)
|
||||
a.mu.RLock()
|
||||
name := a.sessionDisplayNames[key]
|
||||
a.mu.RUnlock()
|
||||
return name
|
||||
}
|
||||
|
||||
// AuthorizeOSUserBySessionKey resolves the OS-user mapping for a session
|
||||
// key. Mirrors Authorize but skips the JWT-hash step since the key has
|
||||
// already been verified and the user identity hash is in hand.
|
||||
func (a *Authorizer) AuthorizeOSUserBySessionKey(userIDHash sshuserhash.UserIDHash, osUsername string) (string, error) {
|
||||
a.mu.RLock()
|
||||
defer a.mu.RUnlock()
|
||||
userIndex, found := a.findUserIndex(userIDHash)
|
||||
if !found {
|
||||
return "", fmt.Errorf("session user (hash: %s) not in authorized list for OS user %q: %w", userIDHash, osUsername, ErrUserNotAuthorized)
|
||||
}
|
||||
return a.checkMachineUserMapping("session", osUsername, userIndex)
|
||||
}
|
||||
|
||||
// findUserIndex finds the index of a hashed user ID in the authorized users list
|
||||
// Returns the index and true if found, 0 and false if not found
|
||||
func (a *Authorizer) findUserIndex(hashedUserID sshuserhash.UserIDHash) (int, bool) {
|
||||
@@ -1,7 +1,6 @@
|
||||
package sessionauth
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -611,61 +610,3 @@ func TestAuthorizer_Wildcard_WithPartialIndexes_AllowsAllUsers(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrUserNotAuthorized, "unauthorized user should be denied")
|
||||
}
|
||||
|
||||
func TestAuthorizer_LookupSessionKey_Valid(t *testing.T) {
|
||||
pub := bytesRepeat(0x11, sessionPubKeyLen)
|
||||
userHash, err := sshauth.HashUserID("alice")
|
||||
require.NoError(t, err)
|
||||
|
||||
a := NewAuthorizer()
|
||||
a.Update(&Config{
|
||||
AuthorizedUsers: []sshauth.UserIDHash{userHash},
|
||||
MachineUsers: map[string][]uint32{Wildcard: {0}},
|
||||
SessionPubKeys: []SessionPubKey{{PubKey: pub, UserIDHash: userHash}},
|
||||
})
|
||||
|
||||
got, err := a.LookupSessionKey(pub)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, userHash, got)
|
||||
|
||||
if _, err := a.AuthorizeOSUserBySessionKey(got, "alice"); err != nil {
|
||||
t.Fatalf("AuthorizeOSUserBySessionKey: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorizer_LookupSessionKey_UnknownPub(t *testing.T) {
|
||||
a := NewAuthorizer()
|
||||
a.Update(&Config{})
|
||||
_, err := a.LookupSessionKey(bytesRepeat(0x22, sessionPubKeyLen))
|
||||
require.ErrorIs(t, err, ErrSessionKeyNotKnown)
|
||||
}
|
||||
|
||||
func TestAuthorizer_LookupSessionKey_WrongLength(t *testing.T) {
|
||||
a := NewAuthorizer()
|
||||
_, err := a.LookupSessionKey([]byte("short"))
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestAuthorizer_LookupSessionKey_UpdateClears(t *testing.T) {
|
||||
pub := bytesRepeat(0x33, sessionPubKeyLen)
|
||||
userHash, err := sshauth.HashUserID("alice")
|
||||
require.NoError(t, err)
|
||||
|
||||
a := NewAuthorizer()
|
||||
a.Update(&Config{SessionPubKeys: []SessionPubKey{{PubKey: pub, UserIDHash: userHash}}})
|
||||
if _, err := a.LookupSessionKey(pub); err != nil {
|
||||
t.Fatalf("setup lookup: %v", err)
|
||||
}
|
||||
a.Update(&Config{})
|
||||
if _, err := a.LookupSessionKey(pub); !errors.Is(err, ErrSessionKeyNotKnown) {
|
||||
t.Fatalf("expected ErrSessionKeyNotKnown, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func bytesRepeat(b byte, n int) []byte {
|
||||
out := make([]byte, n)
|
||||
for i := range out {
|
||||
out[i] = b
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -28,7 +28,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
sshauth "github.com/netbirdio/netbird/shared/sessionauth"
|
||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||
"github.com/netbirdio/netbird/client/ssh/server"
|
||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
|
||||
@@ -23,7 +23,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
sshauth "github.com/netbirdio/netbird/shared/sessionauth"
|
||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||
"github.com/netbirdio/netbird/client/ssh/client"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||
|
||||
@@ -23,7 +23,7 @@ import (
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
sshauth "github.com/netbirdio/netbird/shared/sessionauth"
|
||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
@@ -197,14 +197,6 @@ type Config struct {
|
||||
|
||||
// HostKey is the SSH server host key in PEM format
|
||||
HostKeyPEM []byte
|
||||
|
||||
// NetstackNet, when non-nil, makes the SSH server listen via the
|
||||
// supplied userspace network stack instead of an OS socket.
|
||||
NetstackNet *netstack.Net
|
||||
|
||||
// NetworkValidation, when non-zero, restricts inbound connections to
|
||||
// peers inside the NetBird overlay defined by this WireGuard address.
|
||||
NetworkValidation wgaddr.Address
|
||||
}
|
||||
|
||||
// SessionInfo contains information about an active SSH session
|
||||
@@ -216,15 +208,12 @@ type SessionInfo struct {
|
||||
PortForwards []string
|
||||
}
|
||||
|
||||
// New creates an SSH server instance from the supplied Config. Fields are
|
||||
// read once at construction; mutating Config afterwards has no effect.
|
||||
// JWT == nil disables JWT authentication.
|
||||
// New creates an SSH server instance with the provided host key and optional JWT configuration
|
||||
// If jwtConfig is nil, JWT authentication is disabled
|
||||
func New(config *Config) *Server {
|
||||
s := &Server{
|
||||
mu: sync.RWMutex{},
|
||||
hostKeyPEM: config.HostKeyPEM,
|
||||
netstackNet: config.NetstackNet,
|
||||
wgAddress: config.NetworkValidation,
|
||||
sessions: make(map[sessionKey]*sessionState),
|
||||
pendingAuthJWT: make(map[authKey]string),
|
||||
remoteForwardListeners: make(map[forwardKey]net.Listener),
|
||||
@@ -445,6 +434,20 @@ func (s *Server) buildSessionInfo(state *sessionState) SessionInfo {
|
||||
return info
|
||||
}
|
||||
|
||||
// SetNetstackNet sets the netstack network for userspace networking
|
||||
func (s *Server) SetNetstackNet(net *netstack.Net) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.netstackNet = net
|
||||
}
|
||||
|
||||
// SetNetworkValidation configures network-based connection filtering
|
||||
func (s *Server) SetNetworkValidation(addr wgaddr.Address) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.wgAddress = addr
|
||||
}
|
||||
|
||||
// UpdateSSHAuth updates the SSH fine-grained access control configuration
|
||||
// This should be called when network map updates include new SSH auth configuration
|
||||
func (s *Server) UpdateSSHAuth(config *sshauth.Config) {
|
||||
|
||||
@@ -131,19 +131,6 @@ type SSHServerStateOutput struct {
|
||||
Sessions []SSHSessionOutput `json:"sessions" yaml:"sessions"`
|
||||
}
|
||||
|
||||
type VNCSessionOutput struct {
|
||||
RemoteAddress string `json:"remoteAddress" yaml:"remoteAddress"`
|
||||
Mode string `json:"mode" yaml:"mode"`
|
||||
Username string `json:"username,omitempty" yaml:"username,omitempty"`
|
||||
UserID string `json:"userID,omitempty" yaml:"userID,omitempty"`
|
||||
Initiator string `json:"initiator,omitempty" yaml:"initiator,omitempty"`
|
||||
}
|
||||
|
||||
type VNCServerStateOutput struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
Sessions []VNCSessionOutput `json:"sessions" yaml:"sessions"`
|
||||
}
|
||||
|
||||
type OutputOverview struct {
|
||||
Peers PeersStateOutput `json:"peers" yaml:"peers"`
|
||||
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
|
||||
@@ -166,7 +153,6 @@ type OutputOverview struct {
|
||||
LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"`
|
||||
ProfileName string `json:"profileName" yaml:"profileName"`
|
||||
SSHServerState SSHServerStateOutput `json:"sshServer" yaml:"sshServer"`
|
||||
VNCServerState VNCServerStateOutput `json:"vncServer" yaml:"vncServer"`
|
||||
}
|
||||
|
||||
// ConvertToStatusOutputOverview converts protobuf status to the output overview.
|
||||
@@ -187,7 +173,6 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, opts ConvertO
|
||||
|
||||
relayOverview := mapRelays(pbFullStatus.GetRelays())
|
||||
sshServerOverview := mapSSHServer(pbFullStatus.GetSshServerState())
|
||||
vncServerOverview := mapVNCServer(pbFullStatus.GetVncServerState())
|
||||
peersOverview := mapPeers(pbFullStatus.GetPeers(), opts.StatusFilter, opts.PrefixNamesFilter, opts.PrefixNamesFilterMap, opts.IPsFilter, opts.ConnectionTypeFilter)
|
||||
|
||||
overview := OutputOverview{
|
||||
@@ -212,7 +197,6 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, opts ConvertO
|
||||
LazyConnectionEnabled: pbFullStatus.GetLazyConnectionEnabled(),
|
||||
ProfileName: opts.ProfileName,
|
||||
SSHServerState: sshServerOverview,
|
||||
VNCServerState: vncServerOverview,
|
||||
}
|
||||
|
||||
if opts.Anonymize {
|
||||
@@ -287,26 +271,6 @@ func mapSSHServer(sshServerState *proto.SSHServerState) SSHServerStateOutput {
|
||||
}
|
||||
}
|
||||
|
||||
func mapVNCServer(state *proto.VNCServerState) VNCServerStateOutput {
|
||||
if state == nil {
|
||||
return VNCServerStateOutput{Sessions: []VNCSessionOutput{}}
|
||||
}
|
||||
sessions := make([]VNCSessionOutput, 0, len(state.GetSessions()))
|
||||
for _, sess := range state.GetSessions() {
|
||||
sessions = append(sessions, VNCSessionOutput{
|
||||
RemoteAddress: sess.GetRemoteAddress(),
|
||||
Mode: sess.GetMode(),
|
||||
Username: sess.GetUsername(),
|
||||
UserID: sess.GetUserID(),
|
||||
Initiator: sess.GetInitiator(),
|
||||
})
|
||||
}
|
||||
return VNCServerStateOutput{
|
||||
Enabled: state.GetEnabled(),
|
||||
Sessions: sessions,
|
||||
}
|
||||
}
|
||||
|
||||
func mapPeers(
|
||||
peers []*proto.PeerState,
|
||||
statusFilter string,
|
||||
@@ -569,26 +533,6 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
||||
}
|
||||
}
|
||||
|
||||
vncServerStatus := "Disabled"
|
||||
if o.VNCServerState.Enabled {
|
||||
vncSessionCount := len(o.VNCServerState.Sessions)
|
||||
if vncSessionCount > 0 {
|
||||
sessionWord := "session"
|
||||
if vncSessionCount > 1 {
|
||||
sessionWord = "sessions"
|
||||
}
|
||||
vncServerStatus = fmt.Sprintf("Enabled (%d active %s)", vncSessionCount, sessionWord)
|
||||
} else {
|
||||
vncServerStatus = "Enabled"
|
||||
}
|
||||
|
||||
if showSSHSessions && vncSessionCount > 0 {
|
||||
for _, sess := range o.VNCServerState.Sessions {
|
||||
vncServerStatus += "\n " + formatVNCSessionLine(sess)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
peersCountString := fmt.Sprintf("%d/%d Connected", o.Peers.Connected, o.Peers.Total)
|
||||
|
||||
var forwardingRulesString string
|
||||
@@ -619,7 +563,6 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
||||
"Quantum resistance: %s\n"+
|
||||
"Lazy connection: %s\n"+
|
||||
"SSH Server: %s\n"+
|
||||
"VNC Server: %s\n"+
|
||||
"Networks: %s\n"+
|
||||
"%s"+
|
||||
"Peers count: %s\n",
|
||||
@@ -638,7 +581,6 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
||||
rosenpassEnabledStatus,
|
||||
lazyConnectionEnabledStatus,
|
||||
sshServerStatus,
|
||||
vncServerStatus,
|
||||
networks,
|
||||
forwardingRulesString,
|
||||
peersCountString,
|
||||
@@ -998,26 +940,6 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *PeerStateDetailOutput) {
|
||||
}
|
||||
}
|
||||
|
||||
// formatVNCSessionLine renders a single VNC session row for the detailed
|
||||
// status output. The leading slot identifies the initiator (display name
|
||||
// when known, hashed UserID otherwise); the post-arrow slot is the OS
|
||||
// user the session targets and is omitted in attach mode where the
|
||||
// destination is the current console user (unknown to the daemon).
|
||||
func formatVNCSessionLine(sess VNCSessionOutput) string {
|
||||
who := sess.Initiator
|
||||
if who == "" {
|
||||
who = sess.UserID
|
||||
}
|
||||
prefix := sess.RemoteAddress
|
||||
if who != "" {
|
||||
prefix = fmt.Sprintf("%s@%s", who, sess.RemoteAddress)
|
||||
}
|
||||
if sess.Username != "" {
|
||||
return fmt.Sprintf("[%s -> %s] mode=%s", prefix, sess.Username, sess.Mode)
|
||||
}
|
||||
return fmt.Sprintf("[%s] mode=%s", prefix, sess.Mode)
|
||||
}
|
||||
|
||||
func anonymizeOverview(a *anonymize.Anonymizer, overview *OutputOverview) {
|
||||
for i, peer := range overview.Peers.Details {
|
||||
peer := peer
|
||||
@@ -1038,19 +960,6 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *OutputOverview) {
|
||||
overview.Relays.Details[i] = detail
|
||||
}
|
||||
|
||||
anonymizeNSServerGroups(a, overview)
|
||||
|
||||
for i, route := range overview.Networks {
|
||||
overview.Networks[i] = a.AnonymizeRoute(route)
|
||||
}
|
||||
|
||||
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
|
||||
|
||||
anonymizeEvents(a, overview)
|
||||
anonymizeServerSessions(a, overview)
|
||||
}
|
||||
|
||||
func anonymizeNSServerGroups(a *anonymize.Anonymizer, overview *OutputOverview) {
|
||||
for i, nsGroup := range overview.NSServerGroups {
|
||||
for j, domain := range nsGroup.Domains {
|
||||
overview.NSServerGroups[i].Domains[j] = a.AnonymizeDomain(domain)
|
||||
@@ -1062,9 +971,13 @@ func anonymizeNSServerGroups(a *anonymize.Anonymizer, overview *OutputOverview)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func anonymizeEvents(a *anonymize.Anonymizer, overview *OutputOverview) {
|
||||
for i, route := range overview.Networks {
|
||||
overview.Networks[i] = a.AnonymizeRoute(route)
|
||||
}
|
||||
|
||||
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
|
||||
|
||||
for i, event := range overview.Events {
|
||||
overview.Events[i].Message = a.AnonymizeString(event.Message)
|
||||
overview.Events[i].UserMessage = a.AnonymizeString(event.UserMessage)
|
||||
@@ -1073,24 +986,13 @@ func anonymizeEvents(a *anonymize.Anonymizer, overview *OutputOverview) {
|
||||
event.Metadata[k] = a.AnonymizeString(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func anonymizeRemoteAddress(a *anonymize.Anonymizer, addr string) string {
|
||||
if host, port, err := net.SplitHostPort(addr); err == nil {
|
||||
return fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port)
|
||||
}
|
||||
return a.AnonymizeIPString(addr)
|
||||
}
|
||||
|
||||
func anonymizeServerSessions(a *anonymize.Anonymizer, overview *OutputOverview) {
|
||||
for i, session := range overview.SSHServerState.Sessions {
|
||||
overview.SSHServerState.Sessions[i].RemoteAddress = anonymizeRemoteAddress(a, session.RemoteAddress)
|
||||
if host, port, err := net.SplitHostPort(session.RemoteAddress); err == nil {
|
||||
overview.SSHServerState.Sessions[i].RemoteAddress = fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port)
|
||||
} else {
|
||||
overview.SSHServerState.Sessions[i].RemoteAddress = a.AnonymizeIPString(session.RemoteAddress)
|
||||
}
|
||||
overview.SSHServerState.Sessions[i].Command = a.AnonymizeString(session.Command)
|
||||
}
|
||||
for i, sess := range overview.VNCServerState.Sessions {
|
||||
overview.VNCServerState.Sessions[i].RemoteAddress = anonymizeRemoteAddress(a, sess.RemoteAddress)
|
||||
overview.VNCServerState.Sessions[i].Username = a.AnonymizeString(sess.Username)
|
||||
overview.VNCServerState.Sessions[i].UserID = a.AnonymizeString(sess.UserID)
|
||||
overview.VNCServerState.Sessions[i].Initiator = a.AnonymizeString(sess.Initiator)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -240,10 +240,6 @@ var overview = OutputOverview{
|
||||
Enabled: false,
|
||||
Sessions: []SSHSessionOutput{},
|
||||
},
|
||||
VNCServerState: VNCServerStateOutput{
|
||||
Enabled: false,
|
||||
Sessions: []VNCSessionOutput{},
|
||||
},
|
||||
}
|
||||
|
||||
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
|
||||
@@ -408,10 +404,6 @@ func TestParsingToJSON(t *testing.T) {
|
||||
"sshServer":{
|
||||
"enabled":false,
|
||||
"sessions":[]
|
||||
},
|
||||
"vncServer":{
|
||||
"enabled":false,
|
||||
"sessions":[]
|
||||
}
|
||||
}`
|
||||
// @formatter:on
|
||||
@@ -521,9 +513,6 @@ profileName: ""
|
||||
sshServer:
|
||||
enabled: false
|
||||
sessions: []
|
||||
vncServer:
|
||||
enabled: false
|
||||
sessions: []
|
||||
`
|
||||
|
||||
assert.Equal(t, expectedYAML, yaml)
|
||||
@@ -593,7 +582,6 @@ Interface type: Kernel
|
||||
Quantum resistance: false
|
||||
Lazy connection: false
|
||||
SSH Server: Disabled
|
||||
VNC Server: Disabled
|
||||
Networks: 10.10.0.0/24
|
||||
Peers count: 2/2 Connected
|
||||
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
|
||||
@@ -619,7 +607,6 @@ Interface type: Kernel
|
||||
Quantum resistance: false
|
||||
Lazy connection: false
|
||||
SSH Server: Disabled
|
||||
VNC Server: Disabled
|
||||
Networks: 10.10.0.0/24
|
||||
Peers count: 2/2 Connected
|
||||
`
|
||||
|
||||
@@ -62,7 +62,6 @@ type Info struct {
|
||||
RosenpassEnabled bool
|
||||
RosenpassPermissive bool
|
||||
ServerSSHAllowed bool
|
||||
ServerVNCAllowed bool
|
||||
|
||||
DisableClientRoutes bool
|
||||
DisableServerRoutes bool
|
||||
@@ -84,7 +83,6 @@ type Info struct {
|
||||
func (i *Info) SetFlags(
|
||||
rosenpassEnabled, rosenpassPermissive bool,
|
||||
serverSSHAllowed *bool,
|
||||
serverVNCAllowed *bool,
|
||||
disableClientRoutes, disableServerRoutes,
|
||||
disableDNS, disableFirewall, blockLANAccess, blockInbound, disableIPv6, lazyConnectionEnabled bool,
|
||||
enableSSHRoot, enableSSHSFTP, enableSSHLocalPortForwarding, enableSSHRemotePortForwarding *bool,
|
||||
@@ -95,9 +93,6 @@ func (i *Info) SetFlags(
|
||||
if serverSSHAllowed != nil {
|
||||
i.ServerSSHAllowed = *serverSSHAllowed
|
||||
}
|
||||
if serverVNCAllowed != nil {
|
||||
i.ServerVNCAllowed = *serverVNCAllowed
|
||||
}
|
||||
|
||||
i.DisableClientRoutes = disableClientRoutes
|
||||
i.DisableServerRoutes = disableServerRoutes
|
||||
|
||||
@@ -1,206 +0,0 @@
|
||||
//go:build !(linux && 386)
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"fyne.io/fyne/v2"
|
||||
"fyne.io/fyne/v2/container"
|
||||
"fyne.io/fyne/v2/widget"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/approval"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
// handleApprovalEvent forks a netbird-ui child process to render the
|
||||
// dialog on its own fyne main loop. Top-level windows opened from a
|
||||
// background goroutine of the tray process don't render reliably on
|
||||
// Linux/GTK, so the rest of the UI (settings, login URL, update) uses
|
||||
// the same fork pattern.
|
||||
func (s *serviceClient) handleApprovalEvent(ev *proto.SystemEvent) {
|
||||
if ev == nil || ev.Category != proto.SystemEvent_APPROVAL {
|
||||
return
|
||||
}
|
||||
requestID := ev.Metadata["request_id"]
|
||||
if requestID == "" {
|
||||
log.Warnf("approval event missing request_id: %v", ev.Metadata)
|
||||
return
|
||||
}
|
||||
args := []string{
|
||||
"--approval-request-id=" + requestID,
|
||||
"--approval-kind=" + ev.Metadata["kind"],
|
||||
"--approval-initiator=" + ev.Metadata["initiator"],
|
||||
"--approval-peer-name=" + ev.Metadata["peer_name"],
|
||||
"--approval-source-ip=" + ev.Metadata["source_ip"],
|
||||
"--approval-username=" + ev.Metadata["username"],
|
||||
"--approval-expires-at=" + ev.Metadata["expires_at"],
|
||||
"--approval-key-fingerprint=" + ev.Metadata["peer_pubkey"],
|
||||
"--approval-subject=" + ev.UserMessage,
|
||||
}
|
||||
go s.eventHandler.runSelfCommand(s.ctx, "approval", args...)
|
||||
}
|
||||
|
||||
// showApprovalUI runs the dialog on the forked process's fyne main loop
|
||||
// and forwards the user's decision to the daemon via RespondApproval.
|
||||
func (s *serviceClient) showApprovalUI(req approvalRequest) {
|
||||
w := s.app.NewWindow(approvalTitle(req.kind))
|
||||
w.Resize(fyne.NewSize(480, 260))
|
||||
w.CenterOnScreen()
|
||||
w.RequestFocus()
|
||||
|
||||
var rows []string
|
||||
if req.initiator != "" {
|
||||
// The display name comes from the management dashboard and is
|
||||
// not cryptographically asserted by the connecting client. The
|
||||
// key fingerprint that follows IS: it's the Noise_IK static
|
||||
// public key the client just proved possession of. Show both
|
||||
// so the user can sanity-check that "Alice" is really the
|
||||
// Alice they trust.
|
||||
rows = append(rows, "From user: "+req.initiator)
|
||||
}
|
||||
if fp := approval.ShortKeyFingerprint(req.keyFingerprint); fp != "" {
|
||||
rows = append(rows, "Key fp: "+fp)
|
||||
}
|
||||
if req.peerName != "" {
|
||||
rows = append(rows, "Via peer: "+req.peerName)
|
||||
}
|
||||
if req.sourceIP != "" && req.sourceIP != req.peerName {
|
||||
rows = append(rows, "Source IP: "+req.sourceIP)
|
||||
}
|
||||
if req.username != "" {
|
||||
rows = append(rows, "OS user: "+req.username)
|
||||
}
|
||||
if len(rows) == 0 {
|
||||
rows = []string{"Remote: " + req.displayPeer()}
|
||||
}
|
||||
body := strings.Join(rows, "\n")
|
||||
bodyLabel := widget.NewLabel(body)
|
||||
bodyLabel.Wrapping = fyne.TextWrapWord
|
||||
|
||||
countdown := widget.NewLabel("")
|
||||
deadline := req.deadline()
|
||||
updateCountdown := func() {
|
||||
remaining := time.Until(deadline).Round(time.Second)
|
||||
if remaining < 0 {
|
||||
remaining = 0
|
||||
}
|
||||
countdown.SetText(fmt.Sprintf("Auto-deny in %s", remaining))
|
||||
}
|
||||
updateCountdown()
|
||||
|
||||
type outcome struct {
|
||||
accept bool
|
||||
viewOnly bool
|
||||
}
|
||||
decided := make(chan outcome, 1)
|
||||
decide := func(o outcome) {
|
||||
select {
|
||||
case decided <- o:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
allow := widget.NewButton("Allow", func() { decide(outcome{accept: true}) })
|
||||
allow.Importance = widget.HighImportance
|
||||
allowView := widget.NewButton("Allow (view only)", func() { decide(outcome{accept: true, viewOnly: true}) })
|
||||
deny := widget.NewButton("Deny", func() { decide(outcome{accept: false}) })
|
||||
|
||||
header := widget.NewLabelWithStyle(req.subject, fyne.TextAlignLeading, fyne.TextStyle{Bold: true})
|
||||
buttonRow := container.NewGridWithColumns(3, allow, allowView, deny)
|
||||
info := container.NewVBox(header, widget.NewSeparator(), bodyLabel, widget.NewSeparator(), countdown)
|
||||
w.SetContent(container.NewPadded(container.NewBorder(nil, buttonRow, nil, nil, info)))
|
||||
w.SetCloseIntercept(func() { decide(outcome{}) })
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
if time.Until(deadline) <= 0 {
|
||||
decide(outcome{})
|
||||
return
|
||||
}
|
||||
fyne.Do(updateCountdown)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
o := <-decided
|
||||
s.sendApprovalResponse(req.requestID, o.accept, o.viewOnly)
|
||||
fyne.Do(func() {
|
||||
w.Close()
|
||||
s.app.Quit()
|
||||
})
|
||||
}()
|
||||
|
||||
w.Show()
|
||||
}
|
||||
|
||||
func (s *serviceClient) sendApprovalResponse(requestID string, accept, viewOnly bool) {
|
||||
conn, err := s.getSrvClient(defaultFailTimeout)
|
||||
if err != nil {
|
||||
log.Warnf("approval response: get daemon client: %v", err)
|
||||
return
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(s.ctx, defaultFailTimeout)
|
||||
defer cancel()
|
||||
if _, err := conn.RespondApproval(ctx, &proto.RespondApprovalRequest{
|
||||
RequestId: requestID,
|
||||
Accept: accept,
|
||||
ViewOnly: viewOnly,
|
||||
}); err != nil {
|
||||
log.Warnf("approval response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// approvalRequest is the parsed --approval-* CLI args that the forked
|
||||
// dialog process consumes.
|
||||
type approvalRequest struct {
|
||||
requestID string
|
||||
kind string
|
||||
initiator string
|
||||
peerName string
|
||||
sourceIP string
|
||||
username string
|
||||
subject string
|
||||
expiresAt string
|
||||
keyFingerprint string
|
||||
}
|
||||
|
||||
func (r approvalRequest) displayPeer() string {
|
||||
switch {
|
||||
case r.initiator != "":
|
||||
return r.initiator
|
||||
case r.peerName != "":
|
||||
return r.peerName
|
||||
case r.sourceIP != "":
|
||||
return r.sourceIP
|
||||
default:
|
||||
return "unknown peer"
|
||||
}
|
||||
}
|
||||
|
||||
// deadline returns the wall-clock auto-deny moment. Falls back to a short
|
||||
// local window when the daemon's expires_at is missing/unparsable, so a
|
||||
// stale value never leaves the dialog open indefinitely.
|
||||
func (r approvalRequest) deadline() time.Time {
|
||||
if t, err := time.Parse(time.RFC3339, r.expiresAt); err == nil {
|
||||
return t
|
||||
}
|
||||
return time.Now().Add(13 * time.Second)
|
||||
}
|
||||
|
||||
func approvalTitle(kind string) string {
|
||||
switch kind {
|
||||
case "vnc":
|
||||
return "Allow VNC Connection?"
|
||||
case "ssh":
|
||||
return "Allow SSH Connection?"
|
||||
default:
|
||||
return "Allow Incoming Connection?"
|
||||
}
|
||||
}
|
||||
@@ -97,25 +97,13 @@ func main() {
|
||||
showQuickActions: flags.showQuickActions,
|
||||
showUpdate: flags.showUpdate,
|
||||
showUpdateVersion: flags.showUpdateVersion,
|
||||
showApproval: flags.showApproval,
|
||||
approvalRequest: approvalRequest{
|
||||
requestID: flags.approvalRequestID,
|
||||
kind: flags.approvalKind,
|
||||
initiator: flags.approvalInitiator,
|
||||
peerName: flags.approvalPeerName,
|
||||
sourceIP: flags.approvalSourceIP,
|
||||
username: flags.approvalUsername,
|
||||
subject: flags.approvalSubject,
|
||||
expiresAt: flags.approvalExpiresAt,
|
||||
keyFingerprint: flags.approvalKeyFingerprint,
|
||||
},
|
||||
})
|
||||
|
||||
// Watch for theme/settings changes to update the icon.
|
||||
go watchSettingsChanges(a, client)
|
||||
|
||||
// Run in window mode if any UI flag was set.
|
||||
if flags.showSettings || flags.showNetworks || flags.showDebug || flags.showLoginURL || flags.showProfiles || flags.showQuickActions || flags.showUpdate || flags.showApproval {
|
||||
if flags.showSettings || flags.showNetworks || flags.showDebug || flags.showLoginURL || flags.showProfiles || flags.showQuickActions || flags.showUpdate {
|
||||
a.Run()
|
||||
return
|
||||
}
|
||||
@@ -152,17 +140,6 @@ type cliFlags struct {
|
||||
saveLogsInFile bool
|
||||
showUpdate bool
|
||||
showUpdateVersion string
|
||||
showApproval bool
|
||||
|
||||
approvalRequestID string
|
||||
approvalKind string
|
||||
approvalInitiator string
|
||||
approvalPeerName string
|
||||
approvalSourceIP string
|
||||
approvalUsername string
|
||||
approvalSubject string
|
||||
approvalExpiresAt string
|
||||
approvalKeyFingerprint string
|
||||
}
|
||||
|
||||
// parseFlags reads and returns all needed command-line flags.
|
||||
@@ -184,16 +161,6 @@ func parseFlags() *cliFlags {
|
||||
flag.BoolVar(&flags.showLoginURL, "login-url", false, "show login URL in a popup window")
|
||||
flag.BoolVar(&flags.showUpdate, "update", false, "show update progress window")
|
||||
flag.StringVar(&flags.showUpdateVersion, "update-version", "", "version to update to")
|
||||
flag.BoolVar(&flags.showApproval, "approval", false, "show inbound-connection approval prompt window")
|
||||
flag.StringVar(&flags.approvalRequestID, "approval-request-id", "", "approval prompt: daemon-issued request id")
|
||||
flag.StringVar(&flags.approvalKind, "approval-kind", "", "approval prompt: subsystem kind (vnc, ssh, ...)")
|
||||
flag.StringVar(&flags.approvalInitiator, "approval-initiator", "", "approval prompt: display name of the user who initiated the connection")
|
||||
flag.StringVar(&flags.approvalPeerName, "approval-peer-name", "", "approval prompt: remote peer FQDN")
|
||||
flag.StringVar(&flags.approvalSourceIP, "approval-source-ip", "", "approval prompt: remote source IP")
|
||||
flag.StringVar(&flags.approvalUsername, "approval-username", "", "approval prompt: requested OS username")
|
||||
flag.StringVar(&flags.approvalSubject, "approval-subject", "", "approval prompt: human-readable subject line")
|
||||
flag.StringVar(&flags.approvalExpiresAt, "approval-expires-at", "", "approval prompt: RFC3339 deadline at which the daemon auto-denies")
|
||||
flag.StringVar(&flags.approvalKeyFingerprint, "approval-key-fingerprint", "", "approval prompt: hex-encoded Noise static pubkey of the connecting client")
|
||||
flag.Parse()
|
||||
return &flags
|
||||
}
|
||||
@@ -282,7 +249,6 @@ type serviceClient struct {
|
||||
mQuit *systray.MenuItem
|
||||
mNetworks *systray.MenuItem
|
||||
mAllowSSH *systray.MenuItem
|
||||
mAllowVNC *systray.MenuItem
|
||||
mAutoConnect *systray.MenuItem
|
||||
mEnableRosenpass *systray.MenuItem
|
||||
mLazyConnEnabled *systray.MenuItem
|
||||
@@ -321,8 +287,6 @@ type serviceClient struct {
|
||||
sEnableSSHRemotePortForward *widget.Check
|
||||
sDisableSSHAuth *widget.Check
|
||||
iSSHJWTCacheTTL *widget.Entry
|
||||
sServerVNCAllowed *widget.Check
|
||||
sDisableVNCApproval *widget.Check
|
||||
|
||||
// observable settings over corresponding iMngURL and iPreSharedKey values.
|
||||
managementURL string
|
||||
@@ -344,8 +308,6 @@ type serviceClient struct {
|
||||
enableSSHRemotePortForward bool
|
||||
disableSSHAuth bool
|
||||
sshJWTCacheTTL int
|
||||
serverVNCAllowed bool
|
||||
disableVNCApproval bool
|
||||
|
||||
connected bool
|
||||
daemonVersion string
|
||||
@@ -393,8 +355,6 @@ type newServiceClientArgs struct {
|
||||
showQuickActions bool
|
||||
showUpdate bool
|
||||
showUpdateVersion string
|
||||
showApproval bool
|
||||
approvalRequest approvalRequest
|
||||
}
|
||||
|
||||
// newServiceClient instance constructor
|
||||
@@ -435,8 +395,6 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
|
||||
s.showQuickActionsUI()
|
||||
case args.showUpdate:
|
||||
s.showUpdateProgress(ctx, args.showUpdateVersion)
|
||||
case args.showApproval:
|
||||
s.showApprovalUI(args.approvalRequest)
|
||||
}
|
||||
|
||||
return s
|
||||
@@ -520,8 +478,6 @@ func (s *serviceClient) showSettingsUI() {
|
||||
s.sEnableSSHRemotePortForward = widget.NewCheck("Enable SSH Remote Port Forwarding", nil)
|
||||
s.sDisableSSHAuth = widget.NewCheck("Disable SSH Authentication", nil)
|
||||
s.iSSHJWTCacheTTL = widget.NewEntry()
|
||||
s.sServerVNCAllowed = widget.NewCheck("Allow embedded VNC server on this peer", nil)
|
||||
s.sDisableVNCApproval = widget.NewCheck("Skip per-connection approval prompt for VNC", nil)
|
||||
|
||||
s.wSettings.SetContent(s.getSettingsForm())
|
||||
s.wSettings.Resize(fyne.NewSize(600, 400))
|
||||
@@ -634,8 +590,7 @@ func (s *serviceClient) hasSettingsChanged(iMngURL string, port, mtu int64) bool
|
||||
s.disableServerRoutes != s.sDisableServerRoutes.Checked ||
|
||||
s.disableIPv6 != s.sDisableIPv6.Checked ||
|
||||
s.blockLANAccess != s.sBlockLANAccess.Checked ||
|
||||
s.hasSSHChanges() ||
|
||||
s.hasVNCChanges()
|
||||
s.hasSSHChanges()
|
||||
}
|
||||
|
||||
func (s *serviceClient) applySettingsChanges(iMngURL string, port, mtu int64) error {
|
||||
@@ -694,8 +649,6 @@ func (s *serviceClient) buildSetConfigRequest(iMngURL string, port, mtu int64) (
|
||||
req.EnableSSHLocalPortForwarding = &s.sEnableSSHLocalPortForward.Checked
|
||||
req.EnableSSHRemotePortForwarding = &s.sEnableSSHRemotePortForward.Checked
|
||||
req.DisableSSHAuth = &s.sDisableSSHAuth.Checked
|
||||
req.ServerVNCAllowed = &s.sServerVNCAllowed.Checked
|
||||
req.DisableVNCApproval = &s.sDisableVNCApproval.Checked
|
||||
|
||||
sshJWTCacheTTLText := strings.TrimSpace(s.iSSHJWTCacheTTL.Text)
|
||||
if sshJWTCacheTTLText != "" {
|
||||
@@ -756,12 +709,10 @@ func (s *serviceClient) getSettingsForm() fyne.CanvasObject {
|
||||
connectionForm := s.getConnectionForm()
|
||||
networkForm := s.getNetworkForm()
|
||||
sshForm := s.getSSHForm()
|
||||
vncForm := s.getVNCForm()
|
||||
tabs := container.NewAppTabs(
|
||||
container.NewTabItem("Connection", connectionForm),
|
||||
container.NewTabItem("Network", networkForm),
|
||||
container.NewTabItem("SSH", sshForm),
|
||||
container.NewTabItem("VNC", vncForm),
|
||||
)
|
||||
saveButton := widget.NewButtonWithIcon("Save", theme.ConfirmIcon(), s.saveSettings)
|
||||
saveButton.Importance = widget.HighImportance
|
||||
@@ -802,15 +753,6 @@ func (s *serviceClient) getSSHForm() *widget.Form {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *serviceClient) getVNCForm() *widget.Form {
|
||||
return &widget.Form{
|
||||
Items: []*widget.FormItem{
|
||||
{Text: "Allow VNC Server", Widget: s.sServerVNCAllowed},
|
||||
{Text: "Disable Connection Approval Prompt", Widget: s.sDisableVNCApproval},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *serviceClient) hasSSHChanges() bool {
|
||||
currentSSHJWTCacheTTL := s.sshJWTCacheTTL
|
||||
if text := strings.TrimSpace(s.iSSHJWTCacheTTL.Text); text != "" {
|
||||
@@ -829,11 +771,6 @@ func (s *serviceClient) hasSSHChanges() bool {
|
||||
s.sshJWTCacheTTL != currentSSHJWTCacheTTL
|
||||
}
|
||||
|
||||
func (s *serviceClient) hasVNCChanges() bool {
|
||||
return s.serverVNCAllowed != s.sServerVNCAllowed.Checked ||
|
||||
s.disableVNCApproval != s.sDisableVNCApproval.Checked
|
||||
}
|
||||
|
||||
func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginResponse, error) {
|
||||
conn, err := s.getSrvClient(defaultFailTimeout)
|
||||
if err != nil {
|
||||
@@ -1108,7 +1045,6 @@ func (s *serviceClient) onTrayReady() {
|
||||
|
||||
s.mSettings = systray.AddMenuItem("Settings", disabledMenuDescr)
|
||||
s.mAllowSSH = s.mSettings.AddSubMenuItemCheckbox("Allow SSH", allowSSHMenuDescr, false)
|
||||
s.mAllowVNC = s.mSettings.AddSubMenuItemCheckbox("Allow VNC", allowVNCMenuDescr, false)
|
||||
s.mAutoConnect = s.mSettings.AddSubMenuItemCheckbox("Connect on Startup", autoConnectMenuDescr, false)
|
||||
s.mEnableRosenpass = s.mSettings.AddSubMenuItemCheckbox("Enable Quantum-Resistance", quantumResistanceMenuDescr, false)
|
||||
s.mLazyConnEnabled = s.mSettings.AddSubMenuItemCheckbox("Enable Lazy Connections", lazyConnMenuDescr, false)
|
||||
@@ -1182,7 +1118,6 @@ func (s *serviceClient) onTrayReady() {
|
||||
|
||||
s.eventManager = event.NewManager(s.notifier, s.addr)
|
||||
s.eventManager.SetNotificationsEnabled(s.mNotifications.Checked())
|
||||
s.eventManager.AddHandler(s.handleApprovalEvent)
|
||||
s.eventManager.AddHandler(func(event *proto.SystemEvent) {
|
||||
if event.Category == proto.SystemEvent_SYSTEM {
|
||||
s.updateExitNodes()
|
||||
@@ -1418,12 +1353,6 @@ func (s *serviceClient) getSrvConfig() {
|
||||
if cfg.SSHJWTCacheTTL != nil {
|
||||
s.sshJWTCacheTTL = *cfg.SSHJWTCacheTTL
|
||||
}
|
||||
if cfg.ServerVNCAllowed != nil {
|
||||
s.serverVNCAllowed = *cfg.ServerVNCAllowed
|
||||
}
|
||||
if cfg.DisableVNCApproval != nil {
|
||||
s.disableVNCApproval = *cfg.DisableVNCApproval
|
||||
}
|
||||
|
||||
if s.showAdvancedSettings {
|
||||
s.iMngURL.SetText(s.managementURL)
|
||||
@@ -1464,12 +1393,6 @@ func (s *serviceClient) getSrvConfig() {
|
||||
if cfg.SSHJWTCacheTTL != nil {
|
||||
s.iSSHJWTCacheTTL.SetText(strconv.Itoa(*cfg.SSHJWTCacheTTL))
|
||||
}
|
||||
if cfg.ServerVNCAllowed != nil {
|
||||
s.sServerVNCAllowed.SetChecked(*cfg.ServerVNCAllowed)
|
||||
}
|
||||
if cfg.DisableVNCApproval != nil {
|
||||
s.sDisableVNCApproval.SetChecked(*cfg.DisableVNCApproval)
|
||||
}
|
||||
}
|
||||
|
||||
if s.mNotifications == nil {
|
||||
@@ -1529,8 +1452,6 @@ func protoConfigToConfig(cfg *proto.GetConfigResponse) *profilemanager.Config {
|
||||
|
||||
config.DisableAutoConnect = cfg.DisableAutoConnect
|
||||
config.ServerSSHAllowed = &cfg.ServerSSHAllowed
|
||||
config.ServerVNCAllowed = &cfg.ServerVNCAllowed
|
||||
config.DisableVNCApproval = &cfg.DisableVNCApproval
|
||||
config.RosenpassEnabled = cfg.RosenpassEnabled
|
||||
config.RosenpassPermissive = cfg.RosenpassPermissive
|
||||
config.DisableNotifications = &cfg.DisableNotifications
|
||||
@@ -1626,12 +1547,6 @@ func (s *serviceClient) loadSettings() {
|
||||
s.mAllowSSH.Uncheck()
|
||||
}
|
||||
|
||||
if cfg.ServerVNCAllowed {
|
||||
s.mAllowVNC.Check()
|
||||
} else {
|
||||
s.mAllowVNC.Uncheck()
|
||||
}
|
||||
|
||||
if cfg.DisableAutoConnect {
|
||||
s.mAutoConnect.Uncheck()
|
||||
} else {
|
||||
@@ -1671,7 +1586,6 @@ func (s *serviceClient) loadSettings() {
|
||||
func (s *serviceClient) updateConfig() error {
|
||||
disableAutoStart := !s.mAutoConnect.Checked()
|
||||
sshAllowed := s.mAllowSSH.Checked()
|
||||
vncAllowed := s.mAllowVNC.Checked()
|
||||
rosenpassEnabled := s.mEnableRosenpass.Checked()
|
||||
lazyConnectionEnabled := s.mLazyConnEnabled.Checked()
|
||||
blockInbound := s.mBlockInbound.Checked()
|
||||
@@ -1700,7 +1614,6 @@ func (s *serviceClient) updateConfig() error {
|
||||
Username: currUser.Username,
|
||||
DisableAutoConnect: &disableAutoStart,
|
||||
ServerSSHAllowed: &sshAllowed,
|
||||
ServerVNCAllowed: &vncAllowed,
|
||||
RosenpassEnabled: &rosenpassEnabled,
|
||||
LazyConnectionEnabled: &lazyConnectionEnabled,
|
||||
BlockInbound: &blockInbound,
|
||||
|
||||
@@ -2,7 +2,6 @@ package main
|
||||
|
||||
const (
|
||||
allowSSHMenuDescr = "Allow SSH connections"
|
||||
allowVNCMenuDescr = "Allow embedded VNC server"
|
||||
autoConnectMenuDescr = "Connect automatically when the service starts"
|
||||
quantumResistanceMenuDescr = "Enable post-quantum security via Rosenpass"
|
||||
lazyConnMenuDescr = "[Experimental] Enable lazy connections"
|
||||
|
||||
@@ -112,7 +112,7 @@ func (e *Manager) handleEvent(event *proto.SystemEvent) {
|
||||
handlers := slices.Clone(e.handlers)
|
||||
e.mu.Unlock()
|
||||
|
||||
if event.UserMessage != "" && (enabled || event.Severity == proto.SystemEvent_CRITICAL) && !isV6DefaultRoutePartner(event) && event.Category != proto.SystemEvent_APPROVAL {
|
||||
if event.UserMessage != "" && (enabled || event.Severity == proto.SystemEvent_CRITICAL) && !isV6DefaultRoutePartner(event) {
|
||||
title := e.getEventTitle(event)
|
||||
body := event.UserMessage
|
||||
id := event.Metadata["id"]
|
||||
|
||||
@@ -39,8 +39,6 @@ func (h *eventHandler) listen(ctx context.Context) {
|
||||
h.handleDisconnectClick()
|
||||
case <-h.client.mAllowSSH.ClickedCh:
|
||||
h.handleAllowSSHClick()
|
||||
case <-h.client.mAllowVNC.ClickedCh:
|
||||
h.handleAllowVNCClick()
|
||||
case <-h.client.mAutoConnect.ClickedCh:
|
||||
h.handleAutoConnectClick()
|
||||
case <-h.client.mEnableRosenpass.ClickedCh:
|
||||
@@ -136,15 +134,6 @@ func (h *eventHandler) handleAllowSSHClick() {
|
||||
|
||||
}
|
||||
|
||||
func (h *eventHandler) handleAllowVNCClick() {
|
||||
h.toggleCheckbox(h.client.mAllowVNC)
|
||||
if err := h.updateConfigWithErr(); err != nil {
|
||||
h.toggleCheckbox(h.client.mAllowVNC) // revert checkbox state on error
|
||||
log.Errorf("failed to update config: %v", err)
|
||||
h.client.notifier.Send("Error", "Failed to update VNC settings")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *eventHandler) handleAutoConnectClick() {
|
||||
h.toggleCheckbox(h.client.mAutoConnect)
|
||||
if err := h.updateConfigWithErr(); err != nil {
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
// Package vnc holds shared constants for the NetBird embedded VNC stack
|
||||
// so non-server consumers (CLI capture, debug tooling) can refer to the
|
||||
// well-known ports without depending on internal engine packages.
|
||||
package vnc
|
||||
|
||||
// External and internal listen ports for the embedded VNC server.
|
||||
// ExternalPort is what dashboard / browser clients see; the daemon
|
||||
// DNATs it to InternalPort, where the in-process VNC server actually
|
||||
// listens. Both flow over the WireGuard interface. AgentLegacyPort is
|
||||
// the TCP port the per-session agent used before it switched to Unix
|
||||
// sockets; kept here so packet captures from older builds still get
|
||||
// tagged, and so any future on-wire agent variant has a reserved port.
|
||||
const (
|
||||
ExternalPort uint16 = 5900
|
||||
InternalPort uint16 = 25900
|
||||
AgentLegacyPort uint16 = 15900
|
||||
)
|
||||
|
||||
// WellKnownPorts is the unordered set of ports a packet capture should
|
||||
// treat as carrying NetBird VNC traffic.
|
||||
var WellKnownPorts = [...]uint16{ExternalPort, InternalPort, AgentLegacyPort}
|
||||
|
||||
// IsWellKnownPort reports whether port matches any of WellKnownPorts.
|
||||
func IsWellKnownPort(port uint16) bool {
|
||||
for _, p := range WellKnownPorts {
|
||||
if port == p {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -1,434 +0,0 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/netbirdio/netbird/client/configs"
|
||||
)
|
||||
|
||||
// darwinAgentManager spawns a per-user VNC agent on demand and keeps it
|
||||
// alive across multiple client connections within the same console-user
|
||||
// session. A new agent is spawned the first time a client connects, or
|
||||
// whenever the console user changes underneath us.
|
||||
//
|
||||
// Lifecycle is lazy by design: a daemon that never receives a VNC
|
||||
// connection never spawns anything. The trade-off versus an eager spawn
|
||||
// (the Windows model) is that the first VNC client pays the launchctl
|
||||
// asuser + listen-readiness wait, ~hundreds of milliseconds in practice.
|
||||
// That cost only repeats on user switch.
|
||||
type darwinAgentManager struct {
|
||||
mu sync.Mutex
|
||||
authToken string
|
||||
socketPath string
|
||||
uid uint32
|
||||
running bool
|
||||
}
|
||||
|
||||
func newDarwinAgentManager(ctx context.Context) *darwinAgentManager {
|
||||
m := &darwinAgentManager{}
|
||||
go m.watchConsoleUser(ctx)
|
||||
return m
|
||||
}
|
||||
|
||||
// agentSocketName is the file name inside the per-uid socket directory
|
||||
// the agent binds. The directory itself is created and chowned by the
|
||||
// daemon (see prepareAgentSocketDir) so a non-root local user cannot
|
||||
// pre-create or symlink the path before the agent listens.
|
||||
const agentSocketName = "agent.sock"
|
||||
|
||||
// watchConsoleUser kills the cached agent whenever the console user
|
||||
// changes (logout, fast user switch, login window). Without it the daemon
|
||||
// keeps proxying to an agent whose TCC grant and WindowServer access
|
||||
// belong to a user who is no longer at the screen, so the new user only
|
||||
// ever sees the locked-screen wallpaper. Killing the agent breaks the
|
||||
// loopback TCP that the daemon proxies into, the client disconnects, and
|
||||
// the next reconnect runs ensure() against the new console uid.
|
||||
func (m *darwinAgentManager) watchConsoleUser(ctx context.Context) {
|
||||
t := time.NewTicker(2 * time.Second)
|
||||
defer t.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-t.C:
|
||||
uid, err := consoleUserID()
|
||||
m.mu.Lock()
|
||||
if !m.running {
|
||||
m.mu.Unlock()
|
||||
continue
|
||||
}
|
||||
if err != nil || uid != m.uid {
|
||||
prev := m.uid
|
||||
m.killLocked()
|
||||
m.mu.Unlock()
|
||||
if err != nil {
|
||||
log.Infof("console user gone (was uid=%d): %v; agent stopped", prev, err)
|
||||
} else {
|
||||
log.Infof("console user changed %d -> %d; agent stopped, will respawn on next connect", prev, uid)
|
||||
}
|
||||
continue
|
||||
}
|
||||
m.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve spawns or respawns the per-user agent process as needed and
|
||||
// returns its Unix-socket path, shared token, and the uid the agent was
|
||||
// spawned under (so the daemon can validate peer credentials before
|
||||
// dispatching the token). Each call is serialized so concurrent VNC
|
||||
// clients share the same agent.
|
||||
func (m *darwinAgentManager) Resolve(ctx context.Context) (string, string, uint32, error) {
|
||||
consoleUID, err := consoleUserID()
|
||||
if err != nil {
|
||||
return "", "", 0, fmt.Errorf("no console user: %w", err)
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if m.running && m.uid == consoleUID && vncAgentRunning() {
|
||||
return m.socketPath, m.authToken, m.uid, nil
|
||||
}
|
||||
m.killLocked()
|
||||
// Reap stray agents so the new token is the only accepted one.
|
||||
killAllVNCAgents()
|
||||
|
||||
socketDir, err := prepareAgentSocketDir(consoleUID)
|
||||
if err != nil {
|
||||
return "", "", 0, fmt.Errorf("prepare agent socket dir: %w", err)
|
||||
}
|
||||
socketPath := socketDir + "/" + agentSocketName
|
||||
if err := os.Remove(socketPath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
log.Debugf("clear stale agent socket %s: %v", socketPath, err)
|
||||
}
|
||||
|
||||
token, err := generateAuthToken()
|
||||
if err != nil {
|
||||
return "", "", 0, fmt.Errorf("generate agent auth token: %w", err)
|
||||
}
|
||||
if err := spawnAgentForUser(consoleUID, socketPath, token); err != nil {
|
||||
return "", "", 0, err
|
||||
}
|
||||
if err := waitForAgent(ctx, socketPath, 5*time.Second); err != nil {
|
||||
killAllVNCAgents()
|
||||
return "", "", 0, fmt.Errorf("agent did not start listening: %w", err)
|
||||
}
|
||||
m.authToken = token
|
||||
m.socketPath = socketPath
|
||||
m.uid = consoleUID
|
||||
m.running = true
|
||||
log.Infof("spawned VNC agent for console uid=%d on %s", consoleUID, socketPath)
|
||||
return socketPath, token, consoleUID, nil
|
||||
}
|
||||
|
||||
// prepareAgentSocketDir creates a per-uid subdirectory under the netbird
|
||||
// runtime directory where the agent will bind its Unix socket. The leaf is
|
||||
// owned by uid with mode 0700, so only the target user and root can write
|
||||
// there. The parent is created root-owned with mode 0755 if missing.
|
||||
// Symlinks at the per-uid level are refused (replaced with a fresh
|
||||
// directory) so a low-priv user cannot redirect the chown that follows.
|
||||
func prepareAgentSocketDir(uid uint32) (string, error) {
|
||||
parent := configs.RuntimeDir
|
||||
if err := ensureAgentSocketParent(parent); err != nil {
|
||||
return "", err
|
||||
}
|
||||
subdir := fmt.Sprintf("%s/vnc-%d", parent, uid)
|
||||
if err := purgeStaleAgentSubdir(subdir, uid); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := os.Mkdir(subdir, 0o700); err != nil && !errors.Is(err, os.ErrExist) {
|
||||
return "", fmt.Errorf("mkdir %s: %w", subdir, err)
|
||||
}
|
||||
if err := os.Chmod(subdir, 0o700); err != nil {
|
||||
return "", fmt.Errorf("chmod %s: %w", subdir, err)
|
||||
}
|
||||
if err := os.Chown(subdir, int(uid), -1); err != nil {
|
||||
return "", fmt.Errorf("chown %s -> uid %d: %w", subdir, uid, err)
|
||||
}
|
||||
return subdir, nil
|
||||
}
|
||||
|
||||
// ensureAgentSocketParent verifies the runtime parent dir exists, is not a
|
||||
// symlink, and is owned by root.
|
||||
func ensureAgentSocketParent(parent string) error {
|
||||
if parent == "" {
|
||||
return fmt.Errorf("no runtime directory configured for this platform")
|
||||
}
|
||||
if err := os.MkdirAll(parent, 0o755); err != nil {
|
||||
return fmt.Errorf("mkdir %s: %w", parent, err)
|
||||
}
|
||||
info, err := os.Lstat(parent)
|
||||
if err != nil {
|
||||
return fmt.Errorf("lstat %s: %w", parent, err)
|
||||
}
|
||||
if info.Mode()&os.ModeSymlink != 0 {
|
||||
return fmt.Errorf("%s is a symlink", parent)
|
||||
}
|
||||
if st, ok := info.Sys().(*syscall.Stat_t); ok && st.Uid != 0 {
|
||||
return fmt.Errorf("%s not owned by root (uid=%d)", parent, st.Uid)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// purgeStaleAgentSubdir removes a leftover subdir unless it is a real dir
|
||||
// owned by uid with mode 0700. Lstat (not Stat) so a symlink is detected.
|
||||
func purgeStaleAgentSubdir(subdir string, uid uint32) error {
|
||||
info, err := os.Lstat(subdir)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("lstat %s: %w", subdir, err)
|
||||
}
|
||||
if agentSubdirOK(info, uid) {
|
||||
return nil
|
||||
}
|
||||
if err := os.RemoveAll(subdir); err != nil {
|
||||
return fmt.Errorf("remove stale %s: %w", subdir, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func agentSubdirOK(info os.FileInfo, uid uint32) bool {
|
||||
if info.Mode()&os.ModeSymlink != 0 || !info.IsDir() {
|
||||
return false
|
||||
}
|
||||
st, ok := info.Sys().(*syscall.Stat_t)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return st.Uid == uid && info.Mode().Perm() == 0o700
|
||||
}
|
||||
|
||||
// stop terminates the spawned agent, if any. Intended for daemon shutdown.
|
||||
func (m *darwinAgentManager) stop() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.killLocked()
|
||||
}
|
||||
|
||||
func (m *darwinAgentManager) killLocked() {
|
||||
if !m.running {
|
||||
return
|
||||
}
|
||||
killAllVNCAgents()
|
||||
if m.socketPath != "" {
|
||||
if err := os.Remove(m.socketPath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
log.Debugf("remove agent socket %s: %v", m.socketPath, err)
|
||||
}
|
||||
}
|
||||
m.running = false
|
||||
m.authToken = ""
|
||||
m.socketPath = ""
|
||||
m.uid = 0
|
||||
}
|
||||
|
||||
// consoleUserID returns the uid of the user currently sitting at the
|
||||
// console (the one whose Aqua session is active). Returns
|
||||
// errNoConsoleUser when nobody is logged in: at the login window
|
||||
// /dev/console is owned by root.
|
||||
func consoleUserID() (uint32, error) {
|
||||
info, err := os.Stat("/dev/console")
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("stat /dev/console: %w", err)
|
||||
}
|
||||
st, ok := info.Sys().(*syscall.Stat_t)
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("/dev/console stat has unexpected type")
|
||||
}
|
||||
if st.Uid == 0 {
|
||||
return 0, errNoConsoleUser
|
||||
}
|
||||
return st.Uid, nil
|
||||
}
|
||||
|
||||
// spawnAgentForUser uses launchctl asuser to start a netbird vnc-agent
|
||||
// process inside the target user's launchd bootstrap namespace. That is
|
||||
// the only spawn mode on macOS that gives the child access to the user's
|
||||
// WindowServer. The agent's stderr is relogged into the daemon log so
|
||||
// startup failures are not silently lost when the readiness check times
|
||||
// out.
|
||||
func spawnAgentForUser(uid uint32, socketPath, token string) error {
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return fmt.Errorf("resolve own executable: %w", err)
|
||||
}
|
||||
cmd := exec.Command(
|
||||
"/bin/launchctl", "asuser", strconv.FormatUint(uint64(uid), 10),
|
||||
exe, vncAgentSubcommand,
|
||||
"--socket", socketPath,
|
||||
// Drop privs inside the agent: launchctl asuser preserves the
|
||||
// daemon's uid (root), so without this the capture/input/
|
||||
// encoder paths would run as root for the lifetime of the
|
||||
// session. validateAgentPeer on the daemon side also relies on
|
||||
// the agent's effective uid matching consoleUID.
|
||||
"--target-uid", strconv.FormatUint(uint64(uid), 10),
|
||||
)
|
||||
cmd.Env = append(os.Environ(), agentTokenEnvVar+"="+token)
|
||||
stderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("agent stderr pipe: %w", err)
|
||||
}
|
||||
if err := cmd.Start(); err != nil {
|
||||
return fmt.Errorf("launchctl asuser: %w", err)
|
||||
}
|
||||
go func() {
|
||||
defer stderr.Close()
|
||||
relogAgentStream(stderr)
|
||||
}()
|
||||
go func() { _ = cmd.Wait() }()
|
||||
return nil
|
||||
}
|
||||
|
||||
// waitForAgent dials the agent's Unix socket until it answers. Used to
|
||||
// gate proxy attempts until the spawned process has finished its Start.
|
||||
func waitForAgent(ctx context.Context, socketPath string, wait time.Duration) error {
|
||||
var d net.Dialer
|
||||
deadline := time.Now().Add(wait)
|
||||
for time.Now().Before(deadline) {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
dialCtx, cancel := context.WithTimeout(ctx, 200*time.Millisecond)
|
||||
c, err := d.DialContext(dialCtx, "unix", socketPath)
|
||||
cancel()
|
||||
if err == nil {
|
||||
_ = c.Close()
|
||||
return nil
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
return fmt.Errorf("timeout dialing %s", socketPath)
|
||||
}
|
||||
|
||||
// vncAgentRunning reports whether any vnc-agent process exists on the
|
||||
// system. There is at most one agent per machine, so any match is "the"
|
||||
// agent.
|
||||
func vncAgentRunning() bool {
|
||||
pids, err := vncAgentPIDs()
|
||||
if err != nil {
|
||||
log.Debugf("scan for vnc-agent: %v", err)
|
||||
return false
|
||||
}
|
||||
return len(pids) > 0
|
||||
}
|
||||
|
||||
// killAllVNCAgents sends SIGTERM to every process whose argv contains
|
||||
// "vnc-agent", waits briefly for them to exit, and escalates to SIGKILL
|
||||
// for any that remain. We enumerate kern.proc.all rather than
|
||||
// kern.proc.uid because launchctl asuser preserves the caller's uid
|
||||
// (root) on the spawned child, so a uid-scoped filter would never match.
|
||||
func killAllVNCAgents() {
|
||||
pids, err := vncAgentPIDs()
|
||||
if err != nil {
|
||||
log.Debugf("scan for vnc-agent: %v", err)
|
||||
return
|
||||
}
|
||||
for _, pid := range pids {
|
||||
_ = syscall.Kill(pid, syscall.SIGTERM)
|
||||
}
|
||||
if len(pids) == 0 {
|
||||
return
|
||||
}
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
remaining, _ := vncAgentPIDs()
|
||||
if len(remaining) == 0 {
|
||||
return
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
leftover, _ := vncAgentPIDs()
|
||||
for _, pid := range leftover {
|
||||
_ = syscall.Kill(pid, syscall.SIGKILL)
|
||||
}
|
||||
}
|
||||
|
||||
// vncAgentPIDs returns the pids of vnc-agent subprocesses spawned from
|
||||
// this binary. Matches exactly on argv[0] == our own executable path
|
||||
// AND argv[1] == "vnc-agent" so unrelated processes that happen to have
|
||||
// the same name elsewhere in argv are not targeted. Skips pid 0 and 1
|
||||
// defensively.
|
||||
func vncAgentPIDs() ([]int, error) {
|
||||
procs, err := unix.SysctlKinfoProcSlice("kern.proc.all")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sysctl kern.proc.all: %w", err)
|
||||
}
|
||||
ownExe, err := os.Executable()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolve own executable: %w", err)
|
||||
}
|
||||
var out []int
|
||||
for i := range procs {
|
||||
pid := int(procs[i].Proc.P_pid)
|
||||
if pid <= 1 {
|
||||
continue
|
||||
}
|
||||
argv, err := procArgv(pid)
|
||||
if err != nil || !argvIsVNCAgent(argv, ownExe) {
|
||||
continue
|
||||
}
|
||||
out = append(out, pid)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// procArgv reads the kernel's stored argv for pid via the kern.procargs2
|
||||
// sysctl. Format: 4-byte argc, then argv[0..argc) each NUL-terminated,
|
||||
// then envp, then padding. We only need argv so we stop after argc.
|
||||
func procArgv(pid int) ([]string, error) {
|
||||
raw, err := unix.SysctlRaw("kern.procargs2", pid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(raw) < 4 {
|
||||
return nil, fmt.Errorf("procargs2 truncated")
|
||||
}
|
||||
argc := int(raw[0]) | int(raw[1])<<8 | int(raw[2])<<16 | int(raw[3])<<24
|
||||
body := raw[4:]
|
||||
// Skip the executable path (NUL-terminated) and any zero padding that
|
||||
// follows before argv[0].
|
||||
end := bytes.IndexByte(body, 0)
|
||||
if end < 0 {
|
||||
return nil, fmt.Errorf("procargs2 path unterminated")
|
||||
}
|
||||
body = body[end+1:]
|
||||
for len(body) > 0 && body[0] == 0 {
|
||||
body = body[1:]
|
||||
}
|
||||
args := make([]string, 0, argc)
|
||||
for i := 0; i < argc; i++ {
|
||||
end := bytes.IndexByte(body, 0)
|
||||
if end < 0 {
|
||||
break
|
||||
}
|
||||
args = append(args, string(body[:end]))
|
||||
body = body[end+1:]
|
||||
}
|
||||
return args, nil
|
||||
}
|
||||
|
||||
// argvIsVNCAgent reports whether argv belongs to a vnc-agent subprocess
|
||||
// spawned from our binary. Requires argv[0] to match ownExe exactly and
|
||||
// argv[1] to be the vnc-agent subcommand. Matches the spawn shape in
|
||||
// spawnAgentForUser and rejects anything else.
|
||||
func argvIsVNCAgent(argv []string, ownExe string) bool {
|
||||
if len(argv) < 2 || ownExe == "" {
|
||||
return false
|
||||
}
|
||||
return argv[0] == ownExe && argv[1] == vncAgentSubcommand
|
||||
}
|
||||
@@ -1,305 +0,0 @@
|
||||
//go:build darwin || windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
crand "crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// errNoConsoleUser is the sentinel returned by sessionAgent.Resolve when
|
||||
// the platform has no interactive user to attach a capture agent to (the
|
||||
// macOS loginwindow state). Mapped to a distinct RFB reject code so the
|
||||
// browser can show a meaningful message.
|
||||
var errNoConsoleUser = errors.New("no user logged into console")
|
||||
|
||||
// sessionAgent abstracts the per-platform manager that spawns and tracks
|
||||
// the user-session VNC agent. Resolve returns the agent's Unix-socket
|
||||
// path, the shared per-spawn token, and the uid the agent was spawned
|
||||
// under (used to validate peer credentials before the daemon hands the
|
||||
// token to whoever is on the other end of the socket). Resolve may spawn
|
||||
// the agent lazily.
|
||||
type sessionAgent interface {
|
||||
Resolve(ctx context.Context) (socketPath, token string, peerUID uint32, err error)
|
||||
}
|
||||
|
||||
// prefixConn replays already-consumed header bytes ahead of the proxy
|
||||
// stream by swapping in a different Reader on the same underlying Conn.
|
||||
type prefixConn struct {
|
||||
io.Reader
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (p *prefixConn) Read(b []byte) (int, error) { return p.Reader.Read(b) }
|
||||
|
||||
// handleServiceConnection runs the connection-header handshake (source
|
||||
// check, Noise_IK auth) on conn, resolves the right per-session agent
|
||||
// via sa, and proxies to it. Every accepted connection emits exactly one
|
||||
// outcome line on the daemon log.
|
||||
func (s *Server) handleServiceConnection(conn net.Conn, sa sessionAgent) {
|
||||
start := time.Now()
|
||||
connLog := s.log.WithField("remote", conn.RemoteAddr().String())
|
||||
|
||||
if !s.isAllowedSource(conn.RemoteAddr()) {
|
||||
connLog.Info("VNC connection rejected: source not allowed")
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
var headerBuf bytes.Buffer
|
||||
tee := io.TeeReader(conn, &headerBuf)
|
||||
teeConn := &prefixConn{Reader: tee, Conn: conn}
|
||||
|
||||
header, err := s.readConnectionHeader(teeConn)
|
||||
if err != nil {
|
||||
connLog.Infof("VNC connection rejected: header read failed: %v", err)
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
authedLog, sessionUserID, ok := s.authorizeSession(conn, header, connLog)
|
||||
if !ok {
|
||||
authedLog.Info("VNC connection rejected: auth failed")
|
||||
return
|
||||
}
|
||||
if err := s.registerConnAuth(conn, header); err != nil {
|
||||
rejectConnection(conn, codeMessage(RejectCodeAuthForbidden, err.Error()))
|
||||
authedLog.Warnf("VNC connection rejected: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
decision, err := s.gateApproval(conn, header)
|
||||
if err != nil {
|
||||
authedLog.Infof("VNC connection rejected: %v", err)
|
||||
return
|
||||
}
|
||||
if decision.ViewOnly {
|
||||
authedLog.Info("VNC connection approved by user (view-only)")
|
||||
} else if s.requireApproval {
|
||||
authedLog.Info("VNC connection approved by user")
|
||||
}
|
||||
|
||||
socketPath, token, peerUID, err := sa.Resolve(s.ctx)
|
||||
if err != nil {
|
||||
code := RejectCodeCapturerError
|
||||
if errors.Is(err, errNoConsoleUser) {
|
||||
code = RejectCodeNoConsoleUser
|
||||
}
|
||||
rejectConnection(conn, codeMessage(code, err.Error()))
|
||||
authedLog.Warnf("VNC connection rejected: agent unavailable: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var initiator string
|
||||
if s.authorizer != nil {
|
||||
initiator = s.authorizer.LookupSessionDisplayName(header.clientStatic)
|
||||
}
|
||||
sessionID := s.addSession(ActiveSessionInfo{
|
||||
RemoteAddress: conn.RemoteAddr().String(),
|
||||
Mode: modeString(header.mode),
|
||||
Username: header.username,
|
||||
UserID: sessionUserID,
|
||||
Initiator: initiator,
|
||||
}, conn)
|
||||
defer s.removeSession(sessionID)
|
||||
|
||||
replayConn := &prefixConn{
|
||||
Reader: io.MultiReader(&headerBuf, conn),
|
||||
Conn: conn,
|
||||
}
|
||||
if err := proxyToAgent(s.ctx, replayConn, socketPath, token, peerUID, decision.ViewOnly, authedLog); err != nil {
|
||||
rejectConnection(conn, codeMessage(RejectCodeCapturerError, err.Error()))
|
||||
authedLog.Warnf("VNC connection rejected: agent unreachable: %v", err)
|
||||
return
|
||||
}
|
||||
authedLog.Infof("VNC connection closed (%dms)", time.Since(start).Milliseconds())
|
||||
}
|
||||
|
||||
const (
|
||||
// agentTokenLen is the size of the random per-spawn token in bytes.
|
||||
agentTokenLen = 32
|
||||
|
||||
// agentTokenEnvVar names the environment variable the daemon uses to
|
||||
// hand the per-spawn token to the agent child. Out-of-band channels
|
||||
// like this keep the secret out of the command line, where listings
|
||||
// such as `ps` or Windows tasklist would expose it.
|
||||
agentTokenEnvVar = "NB_VNC_AGENT_TOKEN" // #nosec G101 -- env var name, not a credential
|
||||
|
||||
// vncAgentSubcommand is the CLI subcommand the daemon invokes to start
|
||||
// the per-session agent process. Must match cmd.vncAgentCmd.Use in
|
||||
// client/cmd/vnc_agent.go.
|
||||
vncAgentSubcommand = "vnc-agent"
|
||||
)
|
||||
|
||||
// generateAuthToken returns a fresh hex-encoded random token for one
|
||||
// daemon→agent session. The daemon hands this to the spawned agent
|
||||
// out-of-band (env var on Windows) and verifies it on every connection
|
||||
// the agent accepts.
|
||||
func generateAuthToken() (string, error) {
|
||||
b := make([]byte, agentTokenLen)
|
||||
if _, err := crand.Read(b); err != nil {
|
||||
return "", fmt.Errorf("read random: %w", err)
|
||||
}
|
||||
return hex.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// proxyToAgent dials the per-session agent's Unix socket, validates the
|
||||
// peer's kernel-asserted uid (so the daemon never hands its per-spawn
|
||||
// token to an impostor that won the listen race), writes the raw token
|
||||
// bytes plus a single view-only flag byte, then copies bytes both ways
|
||||
// until either side closes. The token + flag prefix must precede any RFB
|
||||
// byte so the agent's verifyAgentToken can run first. Returns nil once a
|
||||
// stream is established; the caller is responsible for sending an
|
||||
// RFB-level rejection on error so the client sees a reason instead of a
|
||||
// bare timeout. authedLog receives one audit line per dispatched
|
||||
// preamble so an operator can correlate daemon→agent traffic with the
|
||||
// remote session that triggered it.
|
||||
func proxyToAgent(ctx context.Context, client net.Conn, socketPath, authToken string, peerUID uint32, viewOnly bool, authedLog *log.Entry) error {
|
||||
tokenBytes, err := hex.DecodeString(authToken)
|
||||
if err != nil || len(tokenBytes) != agentTokenLen {
|
||||
return fmt.Errorf("invalid auth token (len=%d): %w", len(tokenBytes), err)
|
||||
}
|
||||
|
||||
agentConn, err := dialAgentWithRetry(ctx, socketPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dial agent at %s: %w", socketPath, err)
|
||||
}
|
||||
|
||||
if err := validateAgentPeer(agentConn, peerUID); err != nil {
|
||||
_ = agentConn.Close()
|
||||
return fmt.Errorf("agent peer validation failed: %w", err)
|
||||
}
|
||||
|
||||
preamble := make([]byte, len(tokenBytes)+1)
|
||||
copy(preamble, tokenBytes)
|
||||
if viewOnly {
|
||||
preamble[len(tokenBytes)] = 1
|
||||
}
|
||||
if _, err := agentConn.Write(preamble); err != nil {
|
||||
_ = agentConn.Close()
|
||||
return fmt.Errorf("send auth preamble to agent: %w", err)
|
||||
}
|
||||
|
||||
// Audit: one line per successfully-dispatched daemon→agent preamble.
|
||||
// Token printed as its first 8 hex chars (enough to correlate, not
|
||||
// enough to use). Kept at Info so the default deployment captures it.
|
||||
tokenFp := authToken
|
||||
if len(tokenFp) > 8 {
|
||||
tokenFp = tokenFp[:8]
|
||||
}
|
||||
if authedLog != nil {
|
||||
authedLog.Infof("VNC IPC: dispatched preamble to agent socket=%s peer_uid=%d view_only=%v token_fp=%s", socketPath, peerUID, viewOnly, tokenFp)
|
||||
}
|
||||
|
||||
defer client.Close()
|
||||
defer agentConn.Close()
|
||||
log.Debugf("proxy connected to agent, starting bidirectional copy")
|
||||
done := make(chan struct{}, 2)
|
||||
cp := func(label string, dst, src net.Conn) {
|
||||
n, err := io.Copy(dst, src)
|
||||
log.Debugf("proxy %s: %d bytes, err=%v", label, n, err)
|
||||
done <- struct{}{}
|
||||
}
|
||||
go cp("client->agent", agentConn, client)
|
||||
go cp("agent->client", client, agentConn)
|
||||
<-done
|
||||
return nil
|
||||
}
|
||||
|
||||
// relogAgentStream reads log lines from the agent's stderr and re-emits
|
||||
// them through the daemon's logrus, so the merged log keeps a single
|
||||
// format. JSON lines (the agent's normal output) are parsed and dispatched
|
||||
// by level; plain-text lines (cobra errors, panic traces) are forwarded
|
||||
// verbatim so early-startup failures stay visible.
|
||||
func relogAgentStream(r io.Reader) {
|
||||
entry := log.WithField("component", "vnc-agent")
|
||||
scanner := bufio.NewScanner(r)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
if line[0] != '{' {
|
||||
entry.Warn(string(line))
|
||||
continue
|
||||
}
|
||||
var m map[string]any
|
||||
if err := json.Unmarshal(line, &m); err != nil {
|
||||
entry.Warn(string(line))
|
||||
continue
|
||||
}
|
||||
msg, _ := m["msg"].(string)
|
||||
if msg == "" {
|
||||
continue
|
||||
}
|
||||
fields := make(log.Fields)
|
||||
for k, v := range m {
|
||||
switch k {
|
||||
case "msg", "level", "time", "func":
|
||||
continue
|
||||
case "caller":
|
||||
fields["source"] = v
|
||||
default:
|
||||
fields[k] = v
|
||||
}
|
||||
}
|
||||
e := entry.WithFields(fields)
|
||||
switch m["level"] {
|
||||
case "error":
|
||||
e.Error(msg)
|
||||
case "warning":
|
||||
e.Warn(msg)
|
||||
case "debug":
|
||||
e.Debug(msg)
|
||||
case "trace":
|
||||
e.Trace(msg)
|
||||
default:
|
||||
e.Info(msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// dialAgentWithRetry retries the loopback connect for up to ~10 s so the
|
||||
// daemon does not race the agent's first listen. Returns the live conn or
|
||||
// the final error. Aborts early when ctx is cancelled so a Stop() during
|
||||
// service-mode startup doesn't leave a goroutine sleeping for 10 s.
|
||||
func dialAgentWithRetry(ctx context.Context, addr string) (net.Conn, error) {
|
||||
var d net.Dialer
|
||||
var lastErr error
|
||||
for range 50 {
|
||||
if err := ctx.Err(); err != nil {
|
||||
if lastErr == nil {
|
||||
lastErr = err
|
||||
}
|
||||
return nil, lastErr
|
||||
}
|
||||
dialCtx, cancel := context.WithTimeout(ctx, time.Second)
|
||||
c, err := d.DialContext(dialCtx, "unix", addr)
|
||||
cancel()
|
||||
if err == nil {
|
||||
return c, nil
|
||||
}
|
||||
lastErr = err
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if errors.Is(lastErr, context.Canceled) || errors.Is(lastErr, context.DeadlineExceeded) {
|
||||
lastErr = ctx.Err()
|
||||
}
|
||||
return nil, lastErr
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
return nil, lastErr
|
||||
}
|
||||
@@ -1,46 +0,0 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// validateAgentPeer enforces that the peer behind the just-connected Unix
|
||||
// socket is the agent we expect it to be: a process running under
|
||||
// expectedUID, with the right effective uid stamped by the kernel on the
|
||||
// socket. Refuses (with a non-nil error) if anything else is listening on
|
||||
// the path (an unrelated local process that won the listen race or
|
||||
// squatted the path before us). Defends against the daemon shipping its
|
||||
// per-spawn auth token to a process that isn't the spawned agent.
|
||||
func validateAgentPeer(conn net.Conn, expectedUID uint32) error {
|
||||
uconn, ok := conn.(*net.UnixConn)
|
||||
if !ok {
|
||||
return fmt.Errorf("peer cred: expected *net.UnixConn, got %T", conn)
|
||||
}
|
||||
raw, err := uconn.SyscallConn()
|
||||
if err != nil {
|
||||
return fmt.Errorf("peer cred: syscall conn: %w", err)
|
||||
}
|
||||
var cred *unix.Xucred
|
||||
var inner error
|
||||
ctlErr := raw.Control(func(fd uintptr) {
|
||||
cred, inner = unix.GetsockoptXucred(int(fd), unix.SOL_LOCAL, unix.LOCAL_PEERCRED)
|
||||
})
|
||||
if ctlErr != nil {
|
||||
return fmt.Errorf("peer cred: control: %w", ctlErr)
|
||||
}
|
||||
if inner != nil {
|
||||
return fmt.Errorf("peer cred: getsockopt LOCAL_PEERCRED: %w", inner)
|
||||
}
|
||||
if cred == nil {
|
||||
return fmt.Errorf("peer cred: nil xucred")
|
||||
}
|
||||
if cred.Uid != expectedUID {
|
||||
return fmt.Errorf("peer cred: agent uid %d does not match expected %d", cred.Uid, expectedUID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,115 +0,0 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestValidateAgentPeerAcceptsOwnUID confirms the happy path: a Unix
|
||||
// socket whose peer is the current process must validate when the
|
||||
// expected uid matches the process's own. Both sides of a unix-socket
|
||||
// pair share the same kernel cred, so this exercises the real getsockopt
|
||||
// LOCAL_PEERCRED path.
|
||||
func TestValidateAgentPeerAcceptsOwnUID(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
sockPath := filepath.Join(dir, "test.sock")
|
||||
ln, err := net.Listen("unix", sockPath)
|
||||
if err != nil {
|
||||
t.Fatalf("listen: %v", err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
c, err := ln.Accept()
|
||||
if err == nil {
|
||||
_ = c.Close()
|
||||
}
|
||||
}()
|
||||
c, err := net.Dial("unix", sockPath)
|
||||
if err != nil {
|
||||
t.Fatalf("dial: %v", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
if err := validateAgentPeer(c, uint32(os.Getuid())); err != nil {
|
||||
t.Fatalf("validateAgentPeer rejected own uid: %v", err)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// TestValidateAgentPeerRejectsWrongUID ensures the validator fails when
|
||||
// the expected uid differs from the kernel-reported peer uid. This is
|
||||
// the path that catches a hostile process that won the listen race.
|
||||
func TestValidateAgentPeerRejectsWrongUID(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
sockPath := filepath.Join(dir, "test.sock")
|
||||
ln, err := net.Listen("unix", sockPath)
|
||||
if err != nil {
|
||||
t.Fatalf("listen: %v", err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
c, err := ln.Accept()
|
||||
if err == nil {
|
||||
_ = c.Close()
|
||||
}
|
||||
}()
|
||||
c, err := net.Dial("unix", sockPath)
|
||||
if err != nil {
|
||||
t.Fatalf("dial: %v", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
// Pick a uid the test process certainly isn't running as.
|
||||
wrongUID := uint32(os.Getuid()) + 1
|
||||
err = validateAgentPeer(c, wrongUID)
|
||||
if err == nil {
|
||||
t.Fatal("expected mismatch error, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "does not match expected") {
|
||||
t.Fatalf("error should mention uid mismatch, got: %v", err)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// TestValidateAgentPeerRejectsNonUnix protects against being handed a
|
||||
// non-Unix-socket connection (the validator can't enforce anything on
|
||||
// e.g. a *net.TCPConn so it must refuse rather than silently pass).
|
||||
func TestValidateAgentPeerRejectsNonUnix(t *testing.T) {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen tcp: %v", err)
|
||||
}
|
||||
defer ln.Close()
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
c, err := ln.Accept()
|
||||
if err == nil {
|
||||
_ = c.Close()
|
||||
}
|
||||
}()
|
||||
c, err := net.Dial("tcp", ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("dial tcp: %v", err)
|
||||
}
|
||||
defer c.Close()
|
||||
if err := validateAgentPeer(c, 0); err == nil {
|
||||
t.Fatal("expected refusal on non-unix conn, got nil")
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
// validateAgentPeer is a best-effort no-op on Windows: AF_UNIX sockets on
|
||||
// Windows do not expose SO_PEERCRED equivalents, and both the daemon and
|
||||
// the spawned agent run as SYSTEM in distinct sessions. The remaining
|
||||
// trust comes from the location of the socket file (under
|
||||
// C:\Windows\Temp, writable only by SYSTEM/Administrators) and from the
|
||||
// per-spawn auth token preamble that follows this call. Documented as a
|
||||
// known gap; a future hardening pass could interrogate the connected
|
||||
// pipe's PID via process-token APIs.
|
||||
func validateAgentPeer(_ net.Conn, _ uint32) error {
|
||||
return nil
|
||||
}
|
||||
@@ -1,628 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
const (
|
||||
stillActive = 259
|
||||
|
||||
tokenPrimary = 1
|
||||
securityImpersonation = 2
|
||||
tokenSessionID = 12
|
||||
|
||||
createUnicodeEnvironment = 0x00000400
|
||||
createNoWindow = 0x08000000
|
||||
createSuspended = 0x00000004
|
||||
createBreakawayFromJob = 0x01000000
|
||||
)
|
||||
|
||||
var (
|
||||
kernel32 = windows.NewLazySystemDLL("kernel32.dll")
|
||||
advapi32 = windows.NewLazySystemDLL("advapi32.dll")
|
||||
userenv = windows.NewLazySystemDLL("userenv.dll")
|
||||
|
||||
procWTSGetActiveConsoleSessionId = kernel32.NewProc("WTSGetActiveConsoleSessionId")
|
||||
procCreateJobObjectW = kernel32.NewProc("CreateJobObjectW")
|
||||
procSetInformationJobObject = kernel32.NewProc("SetInformationJobObject")
|
||||
procAssignProcessToJobObject = kernel32.NewProc("AssignProcessToJobObject")
|
||||
procSetTokenInformation = advapi32.NewProc("SetTokenInformation")
|
||||
procCreateEnvironmentBlock = userenv.NewProc("CreateEnvironmentBlock")
|
||||
procDestroyEnvironmentBlock = userenv.NewProc("DestroyEnvironmentBlock")
|
||||
|
||||
wtsapi32 = windows.NewLazySystemDLL("wtsapi32.dll")
|
||||
procWTSEnumerateSessionsW = wtsapi32.NewProc("WTSEnumerateSessionsW")
|
||||
procWTSFreeMemory = wtsapi32.NewProc("WTSFreeMemory")
|
||||
procWTSQuerySessionInformation = wtsapi32.NewProc("WTSQuerySessionInformationW")
|
||||
)
|
||||
|
||||
// GetCurrentSessionID returns the session ID of the current process.
|
||||
func GetCurrentSessionID() uint32 {
|
||||
var token windows.Token
|
||||
if err := windows.OpenProcessToken(windows.CurrentProcess(),
|
||||
windows.TOKEN_QUERY, &token); err != nil {
|
||||
return 0
|
||||
}
|
||||
defer token.Close()
|
||||
var id uint32
|
||||
var ret uint32
|
||||
_ = windows.GetTokenInformation(token, windows.TokenSessionId,
|
||||
(*byte)(unsafe.Pointer(&id)), 4, &ret)
|
||||
return id
|
||||
}
|
||||
|
||||
func getConsoleSessionID() uint32 {
|
||||
r, _, _ := procWTSGetActiveConsoleSessionId.Call()
|
||||
return uint32(r)
|
||||
}
|
||||
|
||||
const (
|
||||
wtsActive = 0
|
||||
wtsConnected = 1
|
||||
wtsDisconnected = 4
|
||||
)
|
||||
|
||||
// getActiveSessionID returns the session ID of the best session to attach to.
|
||||
// On a Windows Server with no console display attached, session 1 still
|
||||
// reports WTSActive (login screen "owns" the console), so a naive
|
||||
// first-active-wins pick lands on a session with no actual rendering.
|
||||
// Preference order:
|
||||
// 1. Active session with a user logged in (RDP user in session ≥2)
|
||||
// 2. Active session without a user (console at login screen)
|
||||
// 3. Console session ID
|
||||
func getActiveSessionID() uint32 {
|
||||
var sessionInfo uintptr
|
||||
var count uint32
|
||||
|
||||
r, _, _ := procWTSEnumerateSessionsW.Call(
|
||||
0, // WTS_CURRENT_SERVER_HANDLE
|
||||
0, // reserved
|
||||
1, // version
|
||||
uintptr(unsafe.Pointer(&sessionInfo)),
|
||||
uintptr(unsafe.Pointer(&count)),
|
||||
)
|
||||
if r == 0 || count == 0 {
|
||||
return getConsoleSessionID()
|
||||
}
|
||||
defer func() { _, _, _ = procWTSFreeMemory.Call(sessionInfo) }()
|
||||
|
||||
type wtsSession struct {
|
||||
SessionID uint32
|
||||
Station *uint16
|
||||
State uint32
|
||||
}
|
||||
sessions := unsafe.Slice((*wtsSession)(unsafe.Pointer(sessionInfo)), count)
|
||||
|
||||
var withUser uint32
|
||||
var withUserFound bool
|
||||
var anyActive uint32
|
||||
var anyActiveFound bool
|
||||
for _, s := range sessions {
|
||||
if s.SessionID == 0 {
|
||||
continue
|
||||
}
|
||||
if s.State != wtsActive {
|
||||
continue
|
||||
}
|
||||
if !anyActiveFound {
|
||||
anyActive = s.SessionID
|
||||
anyActiveFound = true
|
||||
}
|
||||
if !withUserFound && wtsSessionHasUser(s.SessionID) {
|
||||
withUser = s.SessionID
|
||||
withUserFound = true
|
||||
}
|
||||
}
|
||||
if withUserFound {
|
||||
return withUser
|
||||
}
|
||||
if anyActiveFound {
|
||||
return anyActive
|
||||
}
|
||||
return getConsoleSessionID()
|
||||
}
|
||||
|
||||
// wtsSessionHasUser returns true if the session has a non-empty user name,
|
||||
// i.e. someone is logged in (vs. the login/Welcome screen). The console
|
||||
// session at the lock screen has WTSUserName == "".
|
||||
const wtsUserName = 5
|
||||
|
||||
func wtsSessionHasUser(sessionID uint32) bool {
|
||||
var buf uintptr
|
||||
var bytesReturned uint32
|
||||
r, _, _ := procWTSQuerySessionInformation.Call(
|
||||
0, // WTS_CURRENT_SERVER_HANDLE
|
||||
uintptr(sessionID),
|
||||
uintptr(wtsUserName),
|
||||
uintptr(unsafe.Pointer(&buf)),
|
||||
uintptr(unsafe.Pointer(&bytesReturned)),
|
||||
)
|
||||
if r == 0 || buf == 0 {
|
||||
return false
|
||||
}
|
||||
defer func() { _, _, _ = procWTSFreeMemory.Call(buf) }()
|
||||
// First UTF-16 code unit non-zero ⇒ non-empty username.
|
||||
return *(*uint16)(unsafe.Pointer(buf)) != 0
|
||||
}
|
||||
|
||||
// getSystemTokenForSession duplicates the current SYSTEM token and sets its
|
||||
// session ID so the spawned process runs in the target session. Using a SYSTEM
|
||||
// token gives access to both Default and Winlogon desktops plus UIPI bypass.
|
||||
func getSystemTokenForSession(sessionID uint32) (windows.Token, error) {
|
||||
var cur windows.Token
|
||||
if err := windows.OpenProcessToken(windows.CurrentProcess(),
|
||||
windows.MAXIMUM_ALLOWED, &cur); err != nil {
|
||||
return 0, fmt.Errorf("OpenProcessToken: %w", err)
|
||||
}
|
||||
defer cur.Close()
|
||||
|
||||
var dup windows.Token
|
||||
if err := windows.DuplicateTokenEx(cur, windows.MAXIMUM_ALLOWED, nil,
|
||||
securityImpersonation, tokenPrimary, &dup); err != nil {
|
||||
return 0, fmt.Errorf("DuplicateTokenEx: %w", err)
|
||||
}
|
||||
|
||||
sid := sessionID
|
||||
r, _, err := procSetTokenInformation.Call(
|
||||
uintptr(dup),
|
||||
uintptr(tokenSessionID),
|
||||
uintptr(unsafe.Pointer(&sid)),
|
||||
unsafe.Sizeof(sid),
|
||||
)
|
||||
if r == 0 {
|
||||
dup.Close()
|
||||
return 0, fmt.Errorf("SetTokenInformation(SessionId=%d): %w", sessionID, err)
|
||||
}
|
||||
return dup, nil
|
||||
}
|
||||
|
||||
// injectEnvVar appends a KEY=VALUE entry to a Unicode environment block.
|
||||
// The block is a sequence of null-terminated UTF-16 strings, terminated by
|
||||
// an extra null. Returns the new []uint16 backing slice; the caller must
|
||||
// hold the returned slice alive until CreateProcessAsUser completes.
|
||||
func injectEnvVar(envBlock uintptr, key, value string) []uint16 {
|
||||
entry := key + "=" + value
|
||||
|
||||
// Walk the existing block to find its total length.
|
||||
ptr := (*uint16)(unsafe.Pointer(envBlock))
|
||||
var totalChars int
|
||||
for {
|
||||
ch := *(*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + uintptr(totalChars)*2))
|
||||
if ch == 0 {
|
||||
// Check for double-null terminator.
|
||||
next := *(*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + uintptr(totalChars+1)*2))
|
||||
totalChars++
|
||||
if next == 0 {
|
||||
// End of block (don't count the final null yet, we'll rebuild).
|
||||
break
|
||||
}
|
||||
} else {
|
||||
totalChars++
|
||||
}
|
||||
}
|
||||
|
||||
entryUTF16, _ := windows.UTF16FromString(entry)
|
||||
// New block: existing entries + new entry (null-terminated) + final null.
|
||||
newLen := totalChars + len(entryUTF16) + 1
|
||||
newBlock := make([]uint16, newLen)
|
||||
// Copy existing entries (up to but not including the final null).
|
||||
for i := range totalChars {
|
||||
newBlock[i] = *(*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + uintptr(i)*2))
|
||||
}
|
||||
copy(newBlock[totalChars:], entryUTF16)
|
||||
newBlock[newLen-1] = 0 // final null terminator
|
||||
|
||||
return newBlock
|
||||
}
|
||||
|
||||
func spawnAgentInSession(sessionID uint32, socketPath, authToken string, jobHandle windows.Handle) (windows.Handle, error) {
|
||||
token, err := getSystemTokenForSession(sessionID)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("get SYSTEM token for session %d: %w", sessionID, err)
|
||||
}
|
||||
defer token.Close()
|
||||
|
||||
var envBlock uintptr
|
||||
r, _, e := procCreateEnvironmentBlock.Call(
|
||||
uintptr(unsafe.Pointer(&envBlock)),
|
||||
uintptr(token),
|
||||
0,
|
||||
)
|
||||
if r == 0 {
|
||||
// Without an environment block we cannot inject NB_VNC_AGENT_TOKEN;
|
||||
// the agent would start unauthenticated. Abort instead of launching.
|
||||
return 0, fmt.Errorf("CreateEnvironmentBlock: %w", e)
|
||||
}
|
||||
defer func() { _, _, _ = procDestroyEnvironmentBlock.Call(envBlock) }()
|
||||
|
||||
// Inject the auth token into the environment block so it doesn't appear
|
||||
// in the process command line (visible via tasklist/wmic). injectedBlock
|
||||
// must stay alive until CreateProcessAsUser returns.
|
||||
injectedBlock := injectEnvVar(envBlock, agentTokenEnvVar, authToken)
|
||||
|
||||
exePath, err := os.Executable()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("get executable path: %w", err)
|
||||
}
|
||||
|
||||
cmdLine := fmt.Sprintf(`"%s" %s --socket %q`, exePath, vncAgentSubcommand, socketPath)
|
||||
cmdLineW, err := windows.UTF16PtrFromString(cmdLine)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("UTF16 cmdline: %w", err)
|
||||
}
|
||||
|
||||
// Create an inheritable pipe for the agent's stderr so we can relog
|
||||
// its output in the service process.
|
||||
var sa windows.SecurityAttributes
|
||||
sa.Length = uint32(unsafe.Sizeof(sa))
|
||||
sa.InheritHandle = 1
|
||||
|
||||
var stderrRead, stderrWrite windows.Handle
|
||||
if err := windows.CreatePipe(&stderrRead, &stderrWrite, &sa, 0); err != nil {
|
||||
return 0, fmt.Errorf("create stderr pipe: %w", err)
|
||||
}
|
||||
// The read end must NOT be inherited by the child.
|
||||
_ = windows.SetHandleInformation(stderrRead, windows.HANDLE_FLAG_INHERIT, 0)
|
||||
|
||||
desktop, _ := windows.UTF16PtrFromString(`WinSta0\Default`)
|
||||
si := windows.StartupInfo{
|
||||
Cb: uint32(unsafe.Sizeof(windows.StartupInfo{})),
|
||||
Desktop: desktop,
|
||||
Flags: windows.STARTF_USESHOWWINDOW | windows.STARTF_USESTDHANDLES,
|
||||
ShowWindow: 0,
|
||||
StdErr: stderrWrite,
|
||||
StdOutput: stderrWrite,
|
||||
}
|
||||
var pi windows.ProcessInformation
|
||||
|
||||
var envPtr *uint16
|
||||
if len(injectedBlock) > 0 {
|
||||
envPtr = &injectedBlock[0]
|
||||
} else if envBlock != 0 {
|
||||
envPtr = (*uint16)(unsafe.Pointer(envBlock))
|
||||
}
|
||||
|
||||
// CREATE_SUSPENDED so we can assign the process to our Job Object
|
||||
// before it executes. Without this the agent could spawn its own child
|
||||
// processes and have them inherit the SCM service-job (not ours), or
|
||||
// briefly listen on the agent port before we tear it down on rollback.
|
||||
// CREATE_BREAKAWAY_FROM_JOB lets the child leave the SCM-managed
|
||||
// service job; harmless if that job allows breakaway, and is required
|
||||
// before AssignProcessToJobObject can succeed in the no-nested-jobs case.
|
||||
err = windows.CreateProcessAsUser(
|
||||
token, nil, cmdLineW,
|
||||
nil, nil, true, // inheritHandles=true for the pipe
|
||||
createUnicodeEnvironment|createNoWindow|createSuspended|createBreakawayFromJob,
|
||||
envPtr, nil, &si, &pi,
|
||||
)
|
||||
runtime.KeepAlive(injectedBlock)
|
||||
// Close the write end in the parent so reads will get EOF when the child exits.
|
||||
_ = windows.CloseHandle(stderrWrite)
|
||||
if err != nil {
|
||||
_ = windows.CloseHandle(stderrRead)
|
||||
return 0, fmt.Errorf("CreateProcessAsUser: %w", err)
|
||||
}
|
||||
|
||||
if jobHandle != 0 {
|
||||
r, _, e := procAssignProcessToJobObject.Call(uintptr(jobHandle), uintptr(pi.Process))
|
||||
if r == 0 {
|
||||
log.Warnf("assign agent to job object: %v (orphan possible on service crash)", e)
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := windows.ResumeThread(pi.Thread); err != nil {
|
||||
_ = windows.CloseHandle(pi.Thread)
|
||||
_ = windows.TerminateProcess(pi.Process, 1)
|
||||
_ = windows.CloseHandle(pi.Process)
|
||||
_ = windows.CloseHandle(stderrRead)
|
||||
return 0, fmt.Errorf("ResumeThread: %w", err)
|
||||
}
|
||||
_ = windows.CloseHandle(pi.Thread)
|
||||
|
||||
// Relog agent output in the service with a [vnc-agent] prefix.
|
||||
go relogAgentOutput(stderrRead)
|
||||
|
||||
log.Infof("spawned agent PID=%d in session %d on %s", pi.ProcessId, sessionID, socketPath)
|
||||
return pi.Process, nil
|
||||
}
|
||||
|
||||
// sessionManager monitors the active console session and ensures a VNC agent
|
||||
// process is running in it. When the session changes (e.g., user switch, RDP
|
||||
// connect/disconnect), it kills the old agent and spawns a new one. Each
|
||||
// spawn picks a per-session Unix-socket path the agent binds and the
|
||||
// daemon dials over local IPC.
|
||||
type sessionManager struct {
|
||||
mu sync.Mutex
|
||||
agentProc windows.Handle
|
||||
everSpawned bool
|
||||
agentStartedAt time.Time
|
||||
spawnFailures int
|
||||
nextSpawnAt time.Time
|
||||
sessionID uint32
|
||||
authToken string
|
||||
socketPath string
|
||||
done chan struct{}
|
||||
// jobHandle owns the agent processes via a Windows Job Object with
|
||||
// JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE. When the service exits or crashes,
|
||||
// the OS closes the handle and terminates every assigned agent: no
|
||||
// orphaned agent processes holding a socket across restarts.
|
||||
jobHandle windows.Handle
|
||||
}
|
||||
|
||||
// agentSocketPathFmt parameterizes the per-session agent socket path by
|
||||
// the Windows session id. C:\Windows\Temp is writable to both the daemon
|
||||
// (SYSTEM) and the spawned agent (SYSTEM token impersonating the session).
|
||||
const agentSocketPathFmt = `C:\Windows\Temp\netbird-vnc-%d.sock`
|
||||
|
||||
func newSessionManager() *sessionManager {
|
||||
m := &sessionManager{sessionID: ^uint32(0), done: make(chan struct{})}
|
||||
if h, err := createKillOnCloseJob(); err != nil {
|
||||
log.Warnf("create job object for vnc-agent (orphan agents possible after crash): %v", err)
|
||||
} else {
|
||||
m.jobHandle = h
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// createKillOnCloseJob returns a Job Object configured so that closing its
|
||||
// handle (process exit or explicit Close) terminates every process assigned
|
||||
// to it. Used to keep orphaned vnc-agent processes from outliving the service.
|
||||
func createKillOnCloseJob() (windows.Handle, error) {
|
||||
r, _, e := procCreateJobObjectW.Call(0, 0)
|
||||
if r == 0 {
|
||||
return 0, fmt.Errorf("CreateJobObject: %w", e)
|
||||
}
|
||||
job := windows.Handle(r)
|
||||
|
||||
// JOBOBJECT_EXTENDED_LIMIT_INFORMATION on amd64 = 144 bytes.
|
||||
//
|
||||
// JOBOBJECT_BASIC_LIMIT_INFORMATION (64 bytes with alignment padding)
|
||||
// PerProcessUserTimeLimit LARGE_INTEGER off 0
|
||||
// PerJobUserTimeLimit LARGE_INTEGER off 8
|
||||
// LimitFlags DWORD off 16
|
||||
// [4 byte pad to align SIZE_T]
|
||||
// MinimumWorkingSetSize SIZE_T off 24
|
||||
// MaximumWorkingSetSize SIZE_T off 32
|
||||
// ActiveProcessLimit DWORD off 40
|
||||
// [4 byte pad to align ULONG_PTR]
|
||||
// Affinity ULONG_PTR off 48
|
||||
// PriorityClass DWORD off 56
|
||||
// SchedulingClass DWORD off 60
|
||||
// IO_COUNTERS (48) + 4 * SIZE_T (32) = 144 total.
|
||||
//
|
||||
// We only set LimitFlags; the rest stays zero.
|
||||
const sizeofExtended = 144
|
||||
const offsetLimitFlags = 16
|
||||
const jobObjectExtendedLimitInformation = 9
|
||||
const jobObjectLimitKillOnJobClose = 0x00002000
|
||||
|
||||
var info [sizeofExtended]byte
|
||||
binary.LittleEndian.PutUint32(info[offsetLimitFlags:offsetLimitFlags+4], jobObjectLimitKillOnJobClose)
|
||||
|
||||
r, _, e = procSetInformationJobObject.Call(
|
||||
uintptr(job),
|
||||
uintptr(jobObjectExtendedLimitInformation),
|
||||
uintptr(unsafe.Pointer(&info[0])),
|
||||
uintptr(sizeofExtended),
|
||||
)
|
||||
if r == 0 {
|
||||
_ = windows.CloseHandle(job)
|
||||
return 0, fmt.Errorf("SetInformationJobObject(KILL_ON_JOB_CLOSE): %w", e)
|
||||
}
|
||||
return job, nil
|
||||
}
|
||||
|
||||
// Resolve returns the current agent socket path, shared token, and the
|
||||
// uid the agent runs under (0 on Windows since the agent runs as
|
||||
// SYSTEM in the interactive session; validateAgentPeer is a no-op
|
||||
// there). When no agent is spawned yet (initial boot, between session
|
||||
// switches, or permanently disabled when SE_TCB_NAME is missing) it
|
||||
// surfaces a distinct error so the daemon can reject the connection
|
||||
// with a meaningful message instead of timing out the proxy dial.
|
||||
func (m *sessionManager) Resolve(_ context.Context) (string, string, uint32, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if m.socketPath == "" {
|
||||
return "", "", 0, errAgentNotReady
|
||||
}
|
||||
return m.socketPath, m.authToken, 0, nil
|
||||
}
|
||||
|
||||
var errAgentNotReady = errors.New("VNC agent not running yet")
|
||||
|
||||
// Stop signals the session manager to exit its polling loop and closes the
|
||||
// Job Object handle, which Windows uses as the trigger to terminate every
|
||||
// agent process this manager spawned.
|
||||
func (m *sessionManager) Stop() {
|
||||
select {
|
||||
case <-m.done:
|
||||
default:
|
||||
close(m.done)
|
||||
}
|
||||
m.mu.Lock()
|
||||
if m.jobHandle != 0 {
|
||||
_ = windows.CloseHandle(m.jobHandle)
|
||||
m.jobHandle = 0
|
||||
}
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
func (m *sessionManager) run() {
|
||||
ticker := time.NewTicker(2 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
if !m.tick() {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-m.done:
|
||||
m.mu.Lock()
|
||||
m.killAgent()
|
||||
m.mu.Unlock()
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// tick performs one session/agent-state update. Returns false if the manager
|
||||
// should permanently stop (e.g. missing SYSTEM privileges).
|
||||
func (m *sessionManager) tick() bool {
|
||||
sid := getActiveSessionID()
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.handleSessionChange(sid)
|
||||
m.reapExitedAgent()
|
||||
return m.maybeSpawnAgent(sid)
|
||||
}
|
||||
|
||||
func (m *sessionManager) handleSessionChange(sid uint32) {
|
||||
if sid == m.sessionID {
|
||||
return
|
||||
}
|
||||
log.Infof("active session changed: %d -> %d", m.sessionID, sid)
|
||||
m.killAgent()
|
||||
m.sessionID = sid
|
||||
}
|
||||
|
||||
func (m *sessionManager) reapExitedAgent() {
|
||||
if m.agentProc == 0 {
|
||||
return
|
||||
}
|
||||
var code uint32
|
||||
if err := windows.GetExitCodeProcess(m.agentProc, &code); err != nil {
|
||||
log.Debugf("GetExitCodeProcess: %v", err)
|
||||
return
|
||||
}
|
||||
if code == stillActive {
|
||||
return
|
||||
}
|
||||
m.scheduleNextSpawn(code, time.Since(m.agentStartedAt))
|
||||
if err := windows.CloseHandle(m.agentProc); err != nil {
|
||||
log.Debugf("close agent handle: %v", err)
|
||||
}
|
||||
m.agentProc = 0
|
||||
m.authToken = ""
|
||||
m.socketPath = ""
|
||||
}
|
||||
|
||||
// scheduleNextSpawn applies an exponential backoff on fast crashes (<5s) and
|
||||
// resets immediately otherwise.
|
||||
func (m *sessionManager) scheduleNextSpawn(exitCode uint32, lifetime time.Duration) {
|
||||
if lifetime < 5*time.Second {
|
||||
m.spawnFailures++
|
||||
backoff := time.Duration(1<<min(m.spawnFailures, 5)) * time.Second
|
||||
if backoff > 30*time.Second {
|
||||
backoff = 30 * time.Second
|
||||
}
|
||||
m.nextSpawnAt = time.Now().Add(backoff)
|
||||
log.Warnf("agent exited (code=%d) after %v, retrying in %v (failures=%d)", exitCode, lifetime.Round(time.Millisecond), backoff, m.spawnFailures)
|
||||
return
|
||||
}
|
||||
m.spawnFailures = 0
|
||||
m.nextSpawnAt = time.Time{}
|
||||
log.Infof("agent exited (code=%d) after %v, respawning", exitCode, lifetime.Round(time.Second))
|
||||
}
|
||||
|
||||
// maybeSpawnAgent spawns a new agent if there's no current one and the backoff
|
||||
// window has elapsed. Returns false to permanently stop the manager when the
|
||||
// service lacks the privileges needed to spawn cross-session.
|
||||
func (m *sessionManager) maybeSpawnAgent(sid uint32) bool {
|
||||
if m.agentProc != 0 || sid == 0xFFFFFFFF || !time.Now().After(m.nextSpawnAt) {
|
||||
return true
|
||||
}
|
||||
// Reap any orphan still holding the agent port from a previous
|
||||
// service instance, only on our very first spawn. Once we own
|
||||
// an agent, we manage its lifecycle ourselves and never need to
|
||||
// kill an unknown listener; if a kill+respawn races on port
|
||||
// release, the spawn-failure backoff handles it without forcing
|
||||
// a synchronous wait or duplicate kill.
|
||||
socketPath := fmt.Sprintf(agentSocketPathFmt, sid)
|
||||
// Covers a previous-run crash that escaped Job Object kill-on-close.
|
||||
if err := os.Remove(socketPath); err != nil && !os.IsNotExist(err) {
|
||||
log.Debugf("clear stale agent socket %s: %v", socketPath, err)
|
||||
}
|
||||
token, err := generateAuthToken()
|
||||
if err != nil {
|
||||
log.Warnf("generate agent auth token: %v", err)
|
||||
return true
|
||||
}
|
||||
m.authToken = token
|
||||
m.socketPath = socketPath
|
||||
h, err := spawnAgentInSession(sid, socketPath, m.authToken, m.jobHandle)
|
||||
if err != nil {
|
||||
m.authToken = ""
|
||||
m.socketPath = ""
|
||||
if errors.Is(err, windows.ERROR_PRIVILEGE_NOT_HELD) {
|
||||
// SE_TCB_NAME (token-impersonation across sessions) is only
|
||||
// granted to SYSTEM. Without it spawnAgent will fail every 2
|
||||
// seconds forever: log once and give up.
|
||||
log.Warnf("VNC service mode disabled: agent spawn requires SYSTEM privileges (got: %v)", err)
|
||||
return false
|
||||
}
|
||||
log.Warnf("spawn agent in session %d: %v", sid, err)
|
||||
return true
|
||||
}
|
||||
m.agentProc = h
|
||||
m.agentStartedAt = time.Now()
|
||||
m.everSpawned = true
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *sessionManager) killAgent() {
|
||||
if m.agentProc == 0 {
|
||||
return
|
||||
}
|
||||
_ = windows.TerminateProcess(m.agentProc, 0)
|
||||
_ = windows.CloseHandle(m.agentProc)
|
||||
m.agentProc = 0
|
||||
m.authToken = ""
|
||||
m.socketPath = ""
|
||||
log.Info("killed old agent")
|
||||
}
|
||||
|
||||
// relogAgentOutput reads log lines from the agent's stderr pipe and
|
||||
// relogs them with the service's formatter. The *os.File owns the
|
||||
// underlying handle, so closing it suffices.
|
||||
func relogAgentOutput(pipe windows.Handle) {
|
||||
f := os.NewFile(uintptr(pipe), "vnc-agent-stderr")
|
||||
defer func() { _ = f.Close() }()
|
||||
relogAgentStream(f)
|
||||
}
|
||||
|
||||
// logCleanupCall invokes a Windows syscall used solely as a cleanup primitive
|
||||
// (CloseClipboard, ReleaseDC, etc.) and logs failures at trace level. The
|
||||
// indirection lets us satisfy errcheck without scattering ignored returns at
|
||||
// each call site, while still capturing diagnostic info when the OS reports
|
||||
// a failure.
|
||||
func logCleanupCall(name string, proc *windows.LazyProc) {
|
||||
r, _, err := proc.Call()
|
||||
if r == 0 && err != nil && err != windows.NTE_OP_OK {
|
||||
log.Tracef("%s: %v", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
// logCleanupCallArgs is logCleanupCall with one argument; common pattern for
|
||||
// release-by-handle syscalls.
|
||||
func logCleanupCallArgs(name string, proc *windows.LazyProc, args ...uintptr) {
|
||||
r, _, err := proc.Call(args...)
|
||||
if r == 0 && err != nil && err != windows.NTE_OP_OK {
|
||||
log.Tracef("%s: %v", name, err)
|
||||
}
|
||||
}
|
||||
@@ -1,643 +0,0 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/maphash"
|
||||
"image"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var darwinCaptureOnce sync.Once
|
||||
|
||||
var (
|
||||
cgMainDisplayID func() uint32
|
||||
cgDisplayPixelsWide func(uint32) uintptr
|
||||
cgDisplayPixelsHigh func(uint32) uintptr
|
||||
cgDisplayCreateImage func(uint32) uintptr
|
||||
cgImageGetWidth func(uintptr) uintptr
|
||||
cgImageGetHeight func(uintptr) uintptr
|
||||
cgImageGetBytesPerRow func(uintptr) uintptr
|
||||
cgImageGetBitsPerPixel func(uintptr) uintptr
|
||||
cgImageGetDataProvider func(uintptr) uintptr
|
||||
cgDataProviderCopyData func(uintptr) uintptr
|
||||
cgImageRelease func(uintptr)
|
||||
cfDataGetLength func(uintptr) int64
|
||||
cfDataGetBytePtr func(uintptr) uintptr
|
||||
cfRelease func(uintptr)
|
||||
cgRequestScreenCaptureAccess func() bool
|
||||
cgEventCreate func(uintptr) uintptr
|
||||
cgEventGetLocation func(uintptr) cgPoint
|
||||
darwinCaptureReady bool
|
||||
)
|
||||
|
||||
// cgPoint mirrors CoreGraphics CGPoint: two doubles, 16 bytes, returned
|
||||
// in registers on Darwin amd64/arm64. Used to receive cursor coordinates
|
||||
// from CGEventGetLocation via purego.
|
||||
type cgPoint struct {
|
||||
X, Y float64
|
||||
}
|
||||
|
||||
func initDarwinCapture() {
|
||||
darwinCaptureOnce.Do(func() {
|
||||
cg, err := purego.Dlopen("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
log.Debugf("load CoreGraphics: %v", err)
|
||||
return
|
||||
}
|
||||
cf, err := purego.Dlopen("/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
log.Debugf("load CoreFoundation: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
purego.RegisterLibFunc(&cgMainDisplayID, cg, "CGMainDisplayID")
|
||||
purego.RegisterLibFunc(&cgDisplayPixelsWide, cg, "CGDisplayPixelsWide")
|
||||
purego.RegisterLibFunc(&cgDisplayPixelsHigh, cg, "CGDisplayPixelsHigh")
|
||||
purego.RegisterLibFunc(&cgDisplayCreateImage, cg, "CGDisplayCreateImage")
|
||||
purego.RegisterLibFunc(&cgImageGetWidth, cg, "CGImageGetWidth")
|
||||
purego.RegisterLibFunc(&cgImageGetHeight, cg, "CGImageGetHeight")
|
||||
purego.RegisterLibFunc(&cgImageGetBytesPerRow, cg, "CGImageGetBytesPerRow")
|
||||
purego.RegisterLibFunc(&cgImageGetBitsPerPixel, cg, "CGImageGetBitsPerPixel")
|
||||
purego.RegisterLibFunc(&cgImageGetDataProvider, cg, "CGImageGetDataProvider")
|
||||
purego.RegisterLibFunc(&cgDataProviderCopyData, cg, "CGDataProviderCopyData")
|
||||
purego.RegisterLibFunc(&cgImageRelease, cg, "CGImageRelease")
|
||||
purego.RegisterLibFunc(&cfDataGetLength, cf, "CFDataGetLength")
|
||||
purego.RegisterLibFunc(&cfDataGetBytePtr, cf, "CFDataGetBytePtr")
|
||||
purego.RegisterLibFunc(&cfRelease, cf, "CFRelease")
|
||||
|
||||
// CGRequestScreenCaptureAccess (macOS 11+) prompts on first call and
|
||||
// is a cheap no-op once granted. The Preflight companion is unreliable
|
||||
// on Sequoia (returns false even when access is granted), so we drive
|
||||
// the permission flow from actual capture failures instead.
|
||||
if sym, err := purego.Dlsym(cg, "CGRequestScreenCaptureAccess"); err == nil {
|
||||
purego.RegisterFunc(&cgRequestScreenCaptureAccess, sym)
|
||||
}
|
||||
// CGEventCreate / CGEventGetLocation feed the cursor position used
|
||||
// by remote-cursor compositing. Optional; absence reports as a
|
||||
// position-source error and disables that feature on this host.
|
||||
if sym, err := purego.Dlsym(cg, "CGEventCreate"); err == nil {
|
||||
purego.RegisterFunc(&cgEventCreate, sym)
|
||||
}
|
||||
if sym, err := purego.Dlsym(cg, "CGEventGetLocation"); err == nil {
|
||||
purego.RegisterFunc(&cgEventGetLocation, sym)
|
||||
}
|
||||
|
||||
darwinCaptureReady = true
|
||||
})
|
||||
}
|
||||
|
||||
// CGCapturer captures the macOS main display using Core Graphics.
|
||||
type CGCapturer struct {
|
||||
displayID uint32
|
||||
w, h int
|
||||
// downscale is 1 for pixel-perfect, 2 for Retina 2:1 box-filter downscale.
|
||||
downscale int
|
||||
hashSeed maphash.Seed
|
||||
lastHash uint64
|
||||
hasHash bool
|
||||
// cursor lazily binds the private CGSCreateCurrentCursorImage symbol
|
||||
// so we can emit the Cursor pseudo-encoding without a per-frame cost
|
||||
// on builds that never query it.
|
||||
cursorOnce sync.Once
|
||||
cursor *cgCursor
|
||||
}
|
||||
|
||||
// PrimeScreenCapturePermission triggers the macOS Screen Recording
|
||||
// permission prompt without creating a full capturer. The platform wiring
|
||||
// calls this at VNC-server enable time so the user sees the prompt the
|
||||
// moment they turn the feature on. CGRequestScreenCaptureAccess is a
|
||||
// no-op when the grant already exists, so calling it on every enable is
|
||||
// cheap and safe.
|
||||
func PrimeScreenCapturePermission() {
|
||||
initDarwinCapture()
|
||||
if !darwinCaptureReady {
|
||||
return
|
||||
}
|
||||
if cgRequestScreenCaptureAccess != nil {
|
||||
cgRequestScreenCaptureAccess()
|
||||
}
|
||||
}
|
||||
|
||||
// notifyScreenRecordingMissing nudges the user once per agent process to
|
||||
// approve Screen Recording. The capturer init retries on backoff when the
|
||||
// grant is missing; without the sync.Once we would reopen System Settings
|
||||
// every tick and flood the daemon log with the same warning.
|
||||
var screenRecordingNotifyOnce sync.Once
|
||||
|
||||
func notifyScreenRecordingMissing() {
|
||||
screenRecordingNotifyOnce.Do(func() {
|
||||
if cgRequestScreenCaptureAccess != nil {
|
||||
cgRequestScreenCaptureAccess()
|
||||
}
|
||||
openPrivacyPane("Privacy_ScreenCapture")
|
||||
log.Warn("Screen Recording permission not granted. " +
|
||||
"Opened System Settings > Privacy & Security > Screen Recording; enable netbird and restart.")
|
||||
})
|
||||
}
|
||||
|
||||
// NewCGCapturer creates a screen capturer for the main display.
|
||||
func NewCGCapturer() (*CGCapturer, error) {
|
||||
initDarwinCapture()
|
||||
if !darwinCaptureReady {
|
||||
return nil, fmt.Errorf("CoreGraphics not available")
|
||||
}
|
||||
|
||||
displayID := cgMainDisplayID()
|
||||
c := &CGCapturer{displayID: displayID, downscale: 1, hashSeed: maphash.MakeSeed()}
|
||||
|
||||
img, err := c.Capture()
|
||||
if err != nil {
|
||||
notifyScreenRecordingMissing()
|
||||
return nil, fmt.Errorf("probe capture: %w", err)
|
||||
}
|
||||
nativeW := img.Rect.Dx()
|
||||
nativeH := img.Rect.Dy()
|
||||
c.hasHash = false
|
||||
if nativeW == 0 || nativeH == 0 {
|
||||
return nil, errors.New("display dimensions are zero")
|
||||
}
|
||||
|
||||
logicalW := int(cgDisplayPixelsWide(displayID))
|
||||
logicalH := int(cgDisplayPixelsHigh(displayID))
|
||||
|
||||
// Enable 2:1 downscale on Retina unless explicitly disabled. Cuts pixel
|
||||
// count 4x, shrinking convert, diff, and wire data proportionally.
|
||||
if !retinaDownscaleDisabled() && nativeW >= 2*logicalW && nativeH >= 2*logicalH && nativeW%2 == 0 && nativeH%2 == 0 {
|
||||
c.downscale = 2
|
||||
}
|
||||
c.w = nativeW / c.downscale
|
||||
c.h = nativeH / c.downscale
|
||||
|
||||
log.Infof("macOS capturer ready: %dx%d (native %dx%d, logical %dx%d, downscale=%d, display=%d)",
|
||||
c.w, c.h, nativeW, nativeH, logicalW, logicalH, c.downscale, displayID)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func retinaDownscaleDisabled() bool {
|
||||
v := os.Getenv(EnvVNCDisableDownscale)
|
||||
if v == "" {
|
||||
return false
|
||||
}
|
||||
disabled, err := strconv.ParseBool(v)
|
||||
if err != nil {
|
||||
log.Warnf("parse %s: %v", EnvVNCDisableDownscale, err)
|
||||
return false
|
||||
}
|
||||
return disabled
|
||||
}
|
||||
|
||||
// Width returns the screen width.
|
||||
func (c *CGCapturer) Width() int { return c.w }
|
||||
|
||||
// Height returns the screen height.
|
||||
func (c *CGCapturer) Height() int { return c.h }
|
||||
|
||||
// Capture returns the current screen as an RGBA image.
|
||||
// CaptureInto writes a fresh frame directly into dst, skipping the
|
||||
// per-frame image.RGBA allocation that Capture() does. Returns
|
||||
// errFrameUnchanged when the screen hash matches the prior call.
|
||||
func (c *CGCapturer) CaptureInto(dst *image.RGBA) error {
|
||||
cgImage := cgDisplayCreateImage(c.displayID)
|
||||
if cgImage == 0 {
|
||||
return fmt.Errorf("CGDisplayCreateImage returned nil (screen recording permission?)")
|
||||
}
|
||||
defer cgImageRelease(cgImage)
|
||||
w := int(cgImageGetWidth(cgImage))
|
||||
h := int(cgImageGetHeight(cgImage))
|
||||
bytesPerRow := int(cgImageGetBytesPerRow(cgImage))
|
||||
bpp := int(cgImageGetBitsPerPixel(cgImage))
|
||||
provider := cgImageGetDataProvider(cgImage)
|
||||
if provider == 0 {
|
||||
return fmt.Errorf("CGImageGetDataProvider returned nil")
|
||||
}
|
||||
cfData := cgDataProviderCopyData(provider)
|
||||
if cfData == 0 {
|
||||
return fmt.Errorf("CGDataProviderCopyData returned nil")
|
||||
}
|
||||
defer cfRelease(cfData)
|
||||
dataLen := int(cfDataGetLength(cfData))
|
||||
dataPtr := cfDataGetBytePtr(cfData)
|
||||
if dataPtr == 0 || dataLen == 0 {
|
||||
return fmt.Errorf("empty image data")
|
||||
}
|
||||
src := unsafe.Slice((*byte)(unsafe.Pointer(dataPtr)), dataLen)
|
||||
hash := maphash.Bytes(c.hashSeed, src)
|
||||
if c.hasHash && hash == c.lastHash {
|
||||
return errFrameUnchanged
|
||||
}
|
||||
c.lastHash = hash
|
||||
c.hasHash = true
|
||||
|
||||
ds := c.downscale
|
||||
if ds < 1 {
|
||||
ds = 1
|
||||
}
|
||||
outW := w / ds
|
||||
outH := h / ds
|
||||
if dst.Rect.Dx() != outW || dst.Rect.Dy() != outH {
|
||||
return fmt.Errorf("dst size mismatch: dst=%dx%d capturer=%dx%d",
|
||||
dst.Rect.Dx(), dst.Rect.Dy(), outW, outH)
|
||||
}
|
||||
bytesPerPixel := bpp / 8
|
||||
if bytesPerPixel == 4 && ds == 1 {
|
||||
convertBGRAToRGBA(dst.Pix, dst.Stride, src, bytesPerRow, w, h)
|
||||
return nil
|
||||
}
|
||||
if bytesPerPixel == 4 && ds == 2 {
|
||||
convertBGRAToRGBADownscale2(dst.Pix, dst.Stride, src, bytesPerRow, outW, outH)
|
||||
return nil
|
||||
}
|
||||
for row := 0; row < outH; row++ {
|
||||
srcOff := row * ds * bytesPerRow
|
||||
dstOff := row * dst.Stride
|
||||
for col := 0; col < outW; col++ {
|
||||
si := srcOff + col*ds*bytesPerPixel
|
||||
di := dstOff + col*4
|
||||
dst.Pix[di+0] = src[si+2]
|
||||
dst.Pix[di+1] = src[si+1]
|
||||
dst.Pix[di+2] = src[si+0]
|
||||
dst.Pix[di+3] = 0xff
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CGCapturer) Capture() (*image.RGBA, error) {
|
||||
cgImage := cgDisplayCreateImage(c.displayID)
|
||||
if cgImage == 0 {
|
||||
return nil, fmt.Errorf("CGDisplayCreateImage returned nil (screen recording permission?)")
|
||||
}
|
||||
defer cgImageRelease(cgImage)
|
||||
|
||||
w := int(cgImageGetWidth(cgImage))
|
||||
h := int(cgImageGetHeight(cgImage))
|
||||
bytesPerRow := int(cgImageGetBytesPerRow(cgImage))
|
||||
bpp := int(cgImageGetBitsPerPixel(cgImage))
|
||||
|
||||
provider := cgImageGetDataProvider(cgImage)
|
||||
if provider == 0 {
|
||||
return nil, fmt.Errorf("CGImageGetDataProvider returned nil")
|
||||
}
|
||||
|
||||
cfData := cgDataProviderCopyData(provider)
|
||||
if cfData == 0 {
|
||||
return nil, fmt.Errorf("CGDataProviderCopyData returned nil")
|
||||
}
|
||||
defer cfRelease(cfData)
|
||||
|
||||
dataLen := int(cfDataGetLength(cfData))
|
||||
dataPtr := cfDataGetBytePtr(cfData)
|
||||
if dataPtr == 0 || dataLen == 0 {
|
||||
return nil, fmt.Errorf("empty image data")
|
||||
}
|
||||
|
||||
src := unsafe.Slice((*byte)(unsafe.Pointer(dataPtr)), dataLen)
|
||||
|
||||
hash := maphash.Bytes(c.hashSeed, src)
|
||||
if c.hasHash && hash == c.lastHash {
|
||||
return nil, errFrameUnchanged
|
||||
}
|
||||
c.lastHash = hash
|
||||
c.hasHash = true
|
||||
|
||||
ds := c.downscale
|
||||
if ds < 1 {
|
||||
ds = 1
|
||||
}
|
||||
outW := w / ds
|
||||
outH := h / ds
|
||||
img := image.NewRGBA(image.Rect(0, 0, outW, outH))
|
||||
|
||||
bytesPerPixel := bpp / 8
|
||||
switch {
|
||||
case bytesPerPixel == 4 && ds == 1:
|
||||
convertBGRAToRGBA(img.Pix, img.Stride, src, bytesPerRow, w, h)
|
||||
case bytesPerPixel == 4 && ds == 2:
|
||||
convertBGRAToRGBADownscale2(img.Pix, img.Stride, src, bytesPerRow, outW, outH)
|
||||
default:
|
||||
convertBGRAToRGBAGeneric(img.Pix, img.Stride, src, bytesPerRow, bgraDownscaleParams{outW: outW, outH: outH, bytesPerPixel: bytesPerPixel, ds: ds})
|
||||
}
|
||||
|
||||
return img, nil
|
||||
}
|
||||
|
||||
type bgraDownscaleParams struct {
|
||||
outW, outH, bytesPerPixel, ds int
|
||||
}
|
||||
|
||||
// convertBGRAToRGBAGeneric is the slow per-pixel fallback for non-4-bytes
|
||||
// or non-1/2 downscale formats. Always available regardless of the source
|
||||
// format quirks the fast paths optimize for.
|
||||
func convertBGRAToRGBAGeneric(dst []byte, dstStride int, src []byte, srcStride int, p bgraDownscaleParams) {
|
||||
for row := 0; row < p.outH; row++ {
|
||||
srcOff := row * p.ds * srcStride
|
||||
dstOff := row * dstStride
|
||||
for col := 0; col < p.outW; col++ {
|
||||
si := srcOff + col*p.ds*p.bytesPerPixel
|
||||
di := dstOff + col*4
|
||||
dst[di+0] = src[si+2]
|
||||
dst[di+1] = src[si+1]
|
||||
dst[di+2] = src[si+0]
|
||||
dst[di+3] = 0xff
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// convertBGRAToRGBADownscale2 averages every 2x2 BGRA block into one RGBA
|
||||
// output pixel, parallelised across GOMAXPROCS cores. outW and outH are the
|
||||
// destination dimensions (source is 2*outW by 2*outH).
|
||||
func convertBGRAToRGBADownscale2(dst []byte, dstStride int, src []byte, srcStride, outW, outH int) {
|
||||
workers := runtime.GOMAXPROCS(0)
|
||||
if workers > outH {
|
||||
workers = outH
|
||||
}
|
||||
if workers < 1 || outH < 32 {
|
||||
workers = 1
|
||||
}
|
||||
|
||||
convertRows := func(y0, y1 int) {
|
||||
for row := y0; row < y1; row++ {
|
||||
srcRow0 := 2 * row * srcStride
|
||||
srcRow1 := srcRow0 + srcStride
|
||||
dstOff := row * dstStride
|
||||
for col := 0; col < outW; col++ {
|
||||
s0 := srcRow0 + col*8
|
||||
s1 := srcRow1 + col*8
|
||||
b := (uint32(src[s0]) + uint32(src[s0+4]) + uint32(src[s1]) + uint32(src[s1+4])) >> 2
|
||||
g := (uint32(src[s0+1]) + uint32(src[s0+5]) + uint32(src[s1+1]) + uint32(src[s1+5])) >> 2
|
||||
r := (uint32(src[s0+2]) + uint32(src[s0+6]) + uint32(src[s1+2]) + uint32(src[s1+6])) >> 2
|
||||
di := dstOff + col*4
|
||||
dst[di+0] = byte(r)
|
||||
dst[di+1] = byte(g)
|
||||
dst[di+2] = byte(b)
|
||||
dst[di+3] = 0xff
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if workers == 1 {
|
||||
convertRows(0, outH)
|
||||
return
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
chunk := (outH + workers - 1) / workers
|
||||
for i := 0; i < workers; i++ {
|
||||
y0 := i * chunk
|
||||
y1 := y0 + chunk
|
||||
if y1 > outH {
|
||||
y1 = outH
|
||||
}
|
||||
if y0 >= y1 {
|
||||
break
|
||||
}
|
||||
wg.Add(1)
|
||||
go func(y0, y1 int) {
|
||||
defer wg.Done()
|
||||
convertRows(y0, y1)
|
||||
}(y0, y1)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// convertBGRAToRGBA swaps R/B channels using uint32 word operations, and
|
||||
// parallelises across GOMAXPROCS cores for large images.
|
||||
func convertBGRAToRGBA(dst []byte, dstStride int, src []byte, srcStride, w, h int) {
|
||||
workers := runtime.GOMAXPROCS(0)
|
||||
if workers > h {
|
||||
workers = h
|
||||
}
|
||||
if workers < 1 || h < 64 {
|
||||
workers = 1
|
||||
}
|
||||
|
||||
convertRows := func(y0, y1 int) {
|
||||
rowBytes := w * 4
|
||||
for row := y0; row < y1; row++ {
|
||||
dstRow := dst[row*dstStride : row*dstStride+rowBytes]
|
||||
srcRow := src[row*srcStride : row*srcStride+rowBytes]
|
||||
dstU := unsafe.Slice((*uint32)(unsafe.Pointer(&dstRow[0])), w)
|
||||
srcU := unsafe.Slice((*uint32)(unsafe.Pointer(&srcRow[0])), w)
|
||||
for i, p := range srcU {
|
||||
dstU[i] = (p & 0xff00ff00) | ((p & 0x000000ff) << 16) | ((p & 0x00ff0000) >> 16) | 0xff000000
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if workers == 1 {
|
||||
convertRows(0, h)
|
||||
return
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
chunk := (h + workers - 1) / workers
|
||||
for i := 0; i < workers; i++ {
|
||||
y0 := i * chunk
|
||||
y1 := y0 + chunk
|
||||
if y1 > h {
|
||||
y1 = h
|
||||
}
|
||||
if y0 >= y1 {
|
||||
break
|
||||
}
|
||||
wg.Add(1)
|
||||
go func(y0, y1 int) {
|
||||
defer wg.Done()
|
||||
convertRows(y0, y1)
|
||||
}(y0, y1)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// MacPoller wraps CGCapturer with a staleness-cached on-demand Capture:
|
||||
// sessions drive captures themselves from their encoder goroutine, so we
|
||||
// don't need a background ticker. The last result is cached for a short
|
||||
// window so concurrent sessions coalesce into one capture.
|
||||
//
|
||||
// The capturer is allocated lazily on first use and released when all
|
||||
// clients disconnect. Init is retried with backoff because the user may
|
||||
// grant Screen Recording permission while the server is already running.
|
||||
type MacPoller struct {
|
||||
mu sync.Mutex
|
||||
|
||||
capturer *CGCapturer
|
||||
w, h int
|
||||
|
||||
lastFrame *image.RGBA
|
||||
lastAt time.Time
|
||||
|
||||
clients atomic.Int32
|
||||
initFails int
|
||||
initBackoffUntil time.Time
|
||||
closed bool
|
||||
}
|
||||
|
||||
// macInitRetryBackoffFor returns the delay we wait between init attempts
|
||||
// after consecutive failures. Screen Recording permission is a one-shot
|
||||
// user grant, so after several failures we back off aggressively.
|
||||
func macInitRetryBackoffFor(fails int) time.Duration {
|
||||
switch {
|
||||
case fails > 15:
|
||||
return 30 * time.Second
|
||||
case fails > 5:
|
||||
return 10 * time.Second
|
||||
default:
|
||||
return 2 * time.Second
|
||||
}
|
||||
}
|
||||
|
||||
// NewMacPoller creates a lazy on-demand capturer for the macOS display.
|
||||
func NewMacPoller() *MacPoller {
|
||||
return &MacPoller{}
|
||||
}
|
||||
|
||||
// Wake is a no-op retained for API compatibility. With on-demand capture
|
||||
// there is no background retry loop to kick: init happens on the next
|
||||
// Capture/ClientConnect call.
|
||||
func (p *MacPoller) Wake() {
|
||||
// intentional no-op
|
||||
}
|
||||
|
||||
// ClientConnect increments the active client count and eagerly initialises
|
||||
// the capturer so the first FBUpdateRequest doesn't pay the init cost.
|
||||
func (p *MacPoller) ClientConnect() {
|
||||
if p.clients.Add(1) == 1 {
|
||||
p.mu.Lock()
|
||||
_ = p.ensureCapturerLocked()
|
||||
p.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// ClientDisconnect decrements the active client count. On the last
|
||||
// disconnect the capturer is released.
|
||||
func (p *MacPoller) ClientDisconnect() {
|
||||
if p.clients.Add(-1) == 0 {
|
||||
p.mu.Lock()
|
||||
p.capturer = nil
|
||||
p.lastFrame = nil
|
||||
p.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Close releases all resources.
|
||||
func (p *MacPoller) Close() {
|
||||
p.mu.Lock()
|
||||
p.closed = true
|
||||
p.capturer = nil
|
||||
p.lastFrame = nil
|
||||
p.mu.Unlock()
|
||||
}
|
||||
|
||||
// Width returns the screen width. Triggers lazy init if needed.
|
||||
func (p *MacPoller) Width() int {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
_ = p.ensureCapturerLocked()
|
||||
return p.w
|
||||
}
|
||||
|
||||
// Height returns the screen height. Triggers lazy init if needed.
|
||||
func (p *MacPoller) Height() int {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
_ = p.ensureCapturerLocked()
|
||||
return p.h
|
||||
}
|
||||
|
||||
// CaptureInto fills dst directly via the underlying capturer, bypassing
|
||||
// the freshness cache.
|
||||
func (p *MacPoller) CaptureInto(dst *image.RGBA) error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if err := p.ensureCapturerLocked(); err != nil {
|
||||
return err
|
||||
}
|
||||
err := p.capturer.CaptureInto(dst)
|
||||
if errors.Is(err, errFrameUnchanged) {
|
||||
// Caller (session) treats this as "no change"; the dst buffer
|
||||
// keeps its prior contents from the previous capture cycle so
|
||||
// the diff stays meaningful.
|
||||
return err
|
||||
}
|
||||
if err != nil {
|
||||
p.capturer = nil
|
||||
return fmt.Errorf("macos capture: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Capture returns a fresh frame, serving from the short-lived cache if a
|
||||
// previous caller captured within freshWindow. Handles the
|
||||
// errFrameUnchanged return from CGCapturer by reusing the cached frame.
|
||||
func (p *MacPoller) Capture() (*image.RGBA, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.lastFrame != nil && time.Since(p.lastAt) < freshWindow {
|
||||
return p.lastFrame, nil
|
||||
}
|
||||
if err := p.ensureCapturerLocked(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
img, err := p.capturer.Capture()
|
||||
if errors.Is(err, errFrameUnchanged) {
|
||||
if p.lastFrame != nil {
|
||||
p.lastAt = time.Now()
|
||||
return p.lastFrame, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if err != nil {
|
||||
// Drop the capturer so the next call retries init; the display stream
|
||||
// can die if the session changes or permissions are revoked.
|
||||
p.capturer = nil
|
||||
return nil, fmt.Errorf("macos capture: %w", err)
|
||||
}
|
||||
p.lastFrame = img
|
||||
p.lastAt = time.Now()
|
||||
return img, nil
|
||||
}
|
||||
|
||||
// ensureCapturerLocked initialises the underlying CGCapturer if needed.
|
||||
// Caller must hold p.mu.
|
||||
func (p *MacPoller) ensureCapturerLocked() error {
|
||||
if p.closed {
|
||||
return fmt.Errorf("poller closed")
|
||||
}
|
||||
if p.capturer != nil {
|
||||
return nil
|
||||
}
|
||||
if time.Now().Before(p.initBackoffUntil) {
|
||||
return fmt.Errorf("macOS capturer unavailable (retry scheduled)")
|
||||
}
|
||||
c, err := NewCGCapturer()
|
||||
if err != nil {
|
||||
p.initFails++
|
||||
p.initBackoffUntil = time.Now().Add(macInitRetryBackoffFor(p.initFails))
|
||||
if p.initFails == 1 || p.initFails%10 == 0 {
|
||||
log.Warnf("macOS capturer: %v (attempt %d)", err, p.initFails)
|
||||
} else {
|
||||
log.Debugf("macOS capturer: %v (attempt %d)", err, p.initFails)
|
||||
}
|
||||
return err
|
||||
}
|
||||
p.initFails = 0
|
||||
p.capturer = c
|
||||
p.w, p.h = c.Width(), c.Height()
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ ScreenCapturer = (*MacPoller)(nil)
|
||||
@@ -1,99 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"image"
|
||||
|
||||
"github.com/kirides/go-d3d/d3d11"
|
||||
"github.com/kirides/go-d3d/outputduplication"
|
||||
)
|
||||
|
||||
// dxgiCapturer captures the desktop using DXGI Desktop Duplication.
|
||||
// Provides GPU-accelerated capture with native dirty rect tracking.
|
||||
// Only works from the interactive user session, not Session 0.
|
||||
//
|
||||
// Uses a double-buffer: DXGI writes into img, then we copy to the current
|
||||
// output buffer and hand it out. Alternating between two output buffers
|
||||
// avoids allocating a new image.RGBA per frame (~8MB at 1080p, 30fps).
|
||||
type dxgiCapturer struct {
|
||||
dup *outputduplication.OutputDuplicator
|
||||
device *d3d11.ID3D11Device
|
||||
ctx *d3d11.ID3D11DeviceContext
|
||||
img *image.RGBA
|
||||
out [2]*image.RGBA
|
||||
outIdx int
|
||||
width int
|
||||
height int
|
||||
}
|
||||
|
||||
func newDXGICapturer() (*dxgiCapturer, error) {
|
||||
device, deviceCtx, err := d3d11.NewD3D11Device()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create D3D11 device: %w", err)
|
||||
}
|
||||
|
||||
dup, err := outputduplication.NewIDXGIOutputDuplication(device, deviceCtx, 0)
|
||||
if err != nil {
|
||||
device.Release()
|
||||
deviceCtx.Release()
|
||||
return nil, fmt.Errorf("create output duplication: %w", err)
|
||||
}
|
||||
|
||||
w, h := screenSize()
|
||||
if w == 0 || h == 0 {
|
||||
dup.Release()
|
||||
device.Release()
|
||||
deviceCtx.Release()
|
||||
return nil, fmt.Errorf("screen dimensions are zero")
|
||||
}
|
||||
|
||||
rect := image.Rect(0, 0, w, h)
|
||||
c := &dxgiCapturer{
|
||||
dup: dup,
|
||||
device: device,
|
||||
ctx: deviceCtx,
|
||||
img: image.NewRGBA(rect),
|
||||
out: [2]*image.RGBA{image.NewRGBA(rect), image.NewRGBA(rect)},
|
||||
width: w,
|
||||
height: h,
|
||||
}
|
||||
|
||||
// Grab the initial frame with a longer timeout to ensure we have
|
||||
// a valid image before returning.
|
||||
_ = dup.GetImage(c.img, 2000)
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *dxgiCapturer) capture() (*image.RGBA, error) {
|
||||
err := c.dup.GetImage(c.img, 100)
|
||||
if err != nil && !errors.Is(err, outputduplication.ErrNoImageYet) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Copy into the next output buffer. The DesktopCapturer hands out the
|
||||
// returned pointer to VNC sessions that read pixels concurrently, so we
|
||||
// alternate between two pre-allocated buffers instead of allocating per frame.
|
||||
out := c.out[c.outIdx]
|
||||
c.outIdx ^= 1
|
||||
copy(out.Pix, c.img.Pix)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *dxgiCapturer) close() {
|
||||
if c.dup != nil {
|
||||
c.dup.Release()
|
||||
c.dup = nil
|
||||
}
|
||||
if c.ctx != nil {
|
||||
c.ctx.Release()
|
||||
c.ctx = nil
|
||||
}
|
||||
if c.device != nil {
|
||||
c.device.Release()
|
||||
c.device = nil
|
||||
}
|
||||
}
|
||||
@@ -1,148 +0,0 @@
|
||||
//go:build freebsd
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// FreeBSD vt(4) framebuffer ioctl numbers from sys/fbio.h.
|
||||
//
|
||||
// #define FBIOGTYPE _IOR('F', 0, struct fbtype)
|
||||
//
|
||||
// _IOR(g, n, t) on FreeBSD: dir=2 (read) <<30 | (sizeof(t) & 0x1fff)<<16
|
||||
// | (g<<8) | n. sizeof(struct fbtype)=24 → 0x40184600.
|
||||
const fbioGType = 0x40184600
|
||||
|
||||
func defaultFBPath() string { return "/dev/ttyv0" }
|
||||
|
||||
// fbType mirrors FreeBSD's struct fbtype.
|
||||
type fbType struct {
|
||||
FbType int32
|
||||
FbHeight int32
|
||||
FbWidth int32
|
||||
FbDepth int32
|
||||
FbCMSize int32
|
||||
FbSize int32
|
||||
}
|
||||
|
||||
// FBCapturer reads pixels from FreeBSD's vt(4) framebuffer device. The
|
||||
// vt(4) console exposes the active framebuffer via ttyv0 with FBIOGTYPE
|
||||
// for geometry and mmap for backing memory. Pixel layout is assumed to
|
||||
// be 32bpp BGRA (the common case for KMS-backed vt); fbtype doesn't
|
||||
// expose channel offsets, so we don't try to handle exotic layouts here.
|
||||
type FBCapturer struct {
|
||||
mu sync.Mutex
|
||||
path string
|
||||
fd int
|
||||
mmap []byte
|
||||
w, h int
|
||||
bpp int
|
||||
stride int
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
// NewFBCapturer opens the given vt(4) device and queries its geometry.
|
||||
func NewFBCapturer(path string) (*FBCapturer, error) {
|
||||
if path == "" {
|
||||
path = defaultFBPath()
|
||||
}
|
||||
fd, err := unix.Open(path, unix.O_RDWR, 0)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open %s: %w", path, err)
|
||||
}
|
||||
|
||||
var fbt fbType
|
||||
if _, _, e := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), fbioGType, uintptr(unsafe.Pointer(&fbt))); e != 0 {
|
||||
unix.Close(fd)
|
||||
return nil, fmt.Errorf("FBIOGTYPE: %v", e)
|
||||
}
|
||||
if fbt.FbDepth != 16 && fbt.FbDepth != 24 && fbt.FbDepth != 32 {
|
||||
unix.Close(fd)
|
||||
return nil, fmt.Errorf("unsupported framebuffer depth: %d", fbt.FbDepth)
|
||||
}
|
||||
if fbt.FbWidth <= 0 || fbt.FbHeight <= 0 || fbt.FbSize <= 0 {
|
||||
unix.Close(fd)
|
||||
return nil, fmt.Errorf("invalid framebuffer geometry: %dx%d size=%d", fbt.FbWidth, fbt.FbHeight, fbt.FbSize)
|
||||
}
|
||||
|
||||
mm, err := unix.Mmap(fd, 0, int(fbt.FbSize), unix.PROT_READ, unix.MAP_SHARED)
|
||||
if err != nil {
|
||||
unix.Close(fd)
|
||||
return nil, fmt.Errorf("mmap %s: %w (vt may not support mmap on this driver, e.g. virtio_gpu)", path, err)
|
||||
}
|
||||
|
||||
bpp := int(fbt.FbDepth)
|
||||
stride := int(fbt.FbWidth) * (bpp / 8)
|
||||
c := &FBCapturer{
|
||||
path: path,
|
||||
fd: fd, // valid fd >= 0; we use -1 as the closed sentinel
|
||||
mmap: mm,
|
||||
w: int(fbt.FbWidth),
|
||||
h: int(fbt.FbHeight),
|
||||
bpp: bpp,
|
||||
stride: stride,
|
||||
}
|
||||
log.Infof("framebuffer capturer ready: %s %dx%d bpp=%d (freebsd vt)", path, c.w, c.h, c.bpp)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Width returns the framebuffer width.
|
||||
func (c *FBCapturer) Width() int { return c.w }
|
||||
|
||||
// Height returns the framebuffer height.
|
||||
func (c *FBCapturer) Height() int { return c.h }
|
||||
|
||||
// Capture allocates a fresh image and fills it with the current
|
||||
// framebuffer contents.
|
||||
func (c *FBCapturer) Capture() (*image.RGBA, error) {
|
||||
img := image.NewRGBA(image.Rect(0, 0, c.w, c.h))
|
||||
if err := c.CaptureInto(img); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return img, nil
|
||||
}
|
||||
|
||||
// CaptureInto reads the framebuffer directly into dst.Pix. Assumes BGRA
|
||||
// for 32bpp; the FreeBSD fbtype struct doesn't expose channel offsets.
|
||||
func (c *FBCapturer) CaptureInto(dst *image.RGBA) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if dst.Rect.Dx() != c.w || dst.Rect.Dy() != c.h {
|
||||
return fmt.Errorf("dst size mismatch: dst=%dx%d fb=%dx%d",
|
||||
dst.Rect.Dx(), dst.Rect.Dy(), c.w, c.h)
|
||||
}
|
||||
switch c.bpp {
|
||||
case 32:
|
||||
// vt(4) on KMS framebuffers is BGRA: byte 0=B, 1=G, 2=R.
|
||||
swizzleBGRAtoRGBA(dst.Pix, c.mmap[:c.h*c.stride])
|
||||
case 24:
|
||||
swizzleFB24(dst.Pix, dst.Stride, c.mmap, c.stride, c.w, c.h)
|
||||
case 16:
|
||||
swizzleFB16RGB565(dst.Pix, dst.Stride, c.mmap, c.stride, c.w, c.h)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close releases the framebuffer mmap and file descriptor. Serialized with
|
||||
// CaptureInto via c.mu so an in-flight capture can't read freed memory.
|
||||
func (c *FBCapturer) Close() {
|
||||
c.closeOnce.Do(func() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.mmap != nil {
|
||||
_ = unix.Munmap(c.mmap)
|
||||
c.mmap = nil
|
||||
}
|
||||
if c.fd >= 0 {
|
||||
_ = unix.Close(c.fd)
|
||||
c.fd = -1
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,229 +0,0 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"image"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// Linux framebuffer ioctls (linux/fb.h).
|
||||
const (
|
||||
fbioGetVScreenInfo = 0x4600
|
||||
fbioGetFScreenInfo = 0x4602
|
||||
)
|
||||
|
||||
func defaultFBPath() string { return "/dev/fb0" }
|
||||
|
||||
// fbVarScreenInfo mirrors the kernel's fb_var_screeninfo. Only the
|
||||
// fields we use are mapped; the rest are absorbed into _padN.
|
||||
type fbVarScreenInfo struct {
|
||||
Xres, Yres uint32
|
||||
XresVirtual, YresVirtual uint32
|
||||
XOffset, YOffset uint32
|
||||
BitsPerPixel uint32
|
||||
Grayscale uint32
|
||||
RedOffset, RedLen, RedMSBR uint32
|
||||
GreenOffset, GreenLen, GreenMSBR uint32
|
||||
BlueOffset, BlueLen, BlueMSBR uint32
|
||||
TranspOffset, TranspLen, TranspM uint32
|
||||
NonStd uint32
|
||||
Activate uint32
|
||||
Height, Width uint32
|
||||
AccelFlags uint32
|
||||
PixClock uint32
|
||||
LeftMargin, RightMargin uint32
|
||||
UpperMargin, LowerMargin uint32
|
||||
HsyncLen, VsyncLen uint32
|
||||
Sync uint32
|
||||
Vmode uint32
|
||||
Rotate uint32
|
||||
Colorspace uint32
|
||||
_pad [4]uint32
|
||||
}
|
||||
|
||||
// fbFixScreenInfo mirrors fb_fix_screeninfo. We only need LineLength.
|
||||
type fbFixScreenInfo struct {
|
||||
IDStr [16]byte
|
||||
SmemStart uint64
|
||||
SmemLen uint32
|
||||
Type uint32
|
||||
TypeAux uint32
|
||||
Visual uint32
|
||||
XPanStep uint16
|
||||
YPanStep uint16
|
||||
YWrapStep uint16
|
||||
_pad0 uint16
|
||||
LineLength uint32
|
||||
MmioStart uint64
|
||||
MmioLen uint32
|
||||
Accel uint32
|
||||
Capabilities uint16
|
||||
_reserved [2]uint16
|
||||
}
|
||||
|
||||
// FBCapturer reads pixels straight from the Linux framebuffer device.
|
||||
// Used as a fallback when X11 isn't available, e.g. on a headless box at
|
||||
// the kernel console or the display manager's pre-login screen on machines
|
||||
// without an Xorg server. The framebuffer must be mmap()-able under our
|
||||
// process privileges (typically the netbird service runs as root).
|
||||
type FBCapturer struct {
|
||||
mu sync.Mutex
|
||||
path string
|
||||
fd int
|
||||
mmap []byte
|
||||
w, h int
|
||||
bpp int
|
||||
stride int
|
||||
rOff uint32
|
||||
gOff uint32
|
||||
bOff uint32
|
||||
rLen uint32
|
||||
gLen uint32
|
||||
bLen uint32
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
// NewFBCapturer opens the given framebuffer device (/dev/fbN) and
|
||||
// queries its current geometry + pixel format.
|
||||
func NewFBCapturer(path string) (*FBCapturer, error) {
|
||||
if path == "" {
|
||||
path = "/dev/fb0"
|
||||
}
|
||||
fd, err := unix.Open(path, unix.O_RDONLY, 0)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open %s: %w", path, err)
|
||||
}
|
||||
|
||||
var vinfo fbVarScreenInfo
|
||||
if _, _, e := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), fbioGetVScreenInfo, uintptr(unsafe.Pointer(&vinfo))); e != 0 {
|
||||
unix.Close(fd)
|
||||
return nil, fmt.Errorf("FBIOGET_VSCREENINFO: %v", e)
|
||||
}
|
||||
var finfo fbFixScreenInfo
|
||||
if _, _, e := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), fbioGetFScreenInfo, uintptr(unsafe.Pointer(&finfo))); e != 0 {
|
||||
unix.Close(fd)
|
||||
return nil, fmt.Errorf("FBIOGET_FSCREENINFO: %v", e)
|
||||
}
|
||||
|
||||
bpp := int(vinfo.BitsPerPixel)
|
||||
if bpp != 16 && bpp != 24 && bpp != 32 {
|
||||
unix.Close(fd)
|
||||
return nil, fmt.Errorf("unsupported framebuffer bpp: %d", bpp)
|
||||
}
|
||||
|
||||
size := int(finfo.LineLength) * int(vinfo.Yres)
|
||||
if size <= 0 {
|
||||
unix.Close(fd)
|
||||
return nil, fmt.Errorf("invalid framebuffer dimensions: stride=%d h=%d", finfo.LineLength, vinfo.Yres)
|
||||
}
|
||||
|
||||
mm, err := unix.Mmap(fd, 0, size, unix.PROT_READ, unix.MAP_SHARED)
|
||||
if err != nil {
|
||||
unix.Close(fd)
|
||||
return nil, fmt.Errorf("mmap %s: %w", path, err)
|
||||
}
|
||||
|
||||
c := &FBCapturer{
|
||||
path: path,
|
||||
fd: fd,
|
||||
mmap: mm,
|
||||
w: int(vinfo.Xres),
|
||||
h: int(vinfo.Yres),
|
||||
bpp: bpp,
|
||||
stride: int(finfo.LineLength),
|
||||
rOff: vinfo.RedOffset,
|
||||
gOff: vinfo.GreenOffset,
|
||||
bOff: vinfo.BlueOffset,
|
||||
rLen: vinfo.RedLen,
|
||||
gLen: vinfo.GreenLen,
|
||||
bLen: vinfo.BlueLen,
|
||||
}
|
||||
log.Infof("framebuffer capturer ready: %s %dx%d bpp=%d r=%d/%d g=%d/%d b=%d/%d",
|
||||
path, c.w, c.h, c.bpp, c.rOff, c.rLen, c.gOff, c.gLen, c.bOff, c.bLen)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Width returns the framebuffer width in pixels.
|
||||
func (c *FBCapturer) Width() int { return c.w }
|
||||
|
||||
// Height returns the framebuffer height in pixels.
|
||||
func (c *FBCapturer) Height() int { return c.h }
|
||||
|
||||
// Capture allocates a fresh image and fills it with the current
|
||||
// framebuffer contents.
|
||||
func (c *FBCapturer) Capture() (*image.RGBA, error) {
|
||||
img := image.NewRGBA(image.Rect(0, 0, c.w, c.h))
|
||||
if err := c.CaptureInto(img); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return img, nil
|
||||
}
|
||||
|
||||
// CaptureInto reads the framebuffer directly into dst.Pix.
|
||||
func (c *FBCapturer) CaptureInto(dst *image.RGBA) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if dst.Rect.Dx() != c.w || dst.Rect.Dy() != c.h {
|
||||
return fmt.Errorf("dst size mismatch: dst=%dx%d fb=%dx%d",
|
||||
dst.Rect.Dx(), dst.Rect.Dy(), c.w, c.h)
|
||||
}
|
||||
|
||||
switch c.bpp {
|
||||
case 32:
|
||||
swizzleFB32(dst.Pix, dst.Stride, c.mmap, c.stride, c.w, c.h, channelShifts{R: c.rOff, G: c.gOff, B: c.bOff})
|
||||
case 24:
|
||||
swizzleFB24(dst.Pix, dst.Stride, c.mmap, c.stride, c.w, c.h)
|
||||
case 16:
|
||||
swizzleFB16RGB565(dst.Pix, dst.Stride, c.mmap, c.stride, c.w, c.h)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close releases the framebuffer mmap and file descriptor. Serialized with
|
||||
// CaptureInto via c.mu so an in-flight capture can't read freed memory.
|
||||
func (c *FBCapturer) Close() {
|
||||
c.closeOnce.Do(func() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.mmap != nil {
|
||||
_ = unix.Munmap(c.mmap)
|
||||
c.mmap = nil
|
||||
}
|
||||
if c.fd >= 0 {
|
||||
_ = unix.Close(c.fd)
|
||||
c.fd = -1
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// channelShifts groups the bit offsets for the R/G/B channels in a packed
|
||||
// uint32 framebuffer pixel. Bundling avoids drowning per-row callers in a
|
||||
// 9-parameter signature.
|
||||
type channelShifts struct {
|
||||
R, G, B uint32
|
||||
}
|
||||
|
||||
// swizzleFB32 handles 32-bit framebuffers with arbitrary R/G/B channel
|
||||
// offsets. Pulls one pixel per uint32, then masks each channel into the
|
||||
// destination RGBA byte order.
|
||||
func swizzleFB32(dst []byte, dstStride int, src []byte, srcStride, w, h int, shifts channelShifts) {
|
||||
for y := 0; y < h; y++ {
|
||||
srcRow := src[y*srcStride : y*srcStride+w*4]
|
||||
dstRow := dst[y*dstStride:]
|
||||
for x := 0; x < w; x++ {
|
||||
pix := binary.LittleEndian.Uint32(srcRow[x*4 : x*4+4])
|
||||
dstRow[x*4+0] = byte(pix >> shifts.R)
|
||||
dstRow[x*4+1] = byte(pix >> shifts.G)
|
||||
dstRow[x*4+2] = byte(pix >> shifts.B)
|
||||
dstRow[x*4+3] = 0xff
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,149 +0,0 @@
|
||||
//go:build unix && !darwin && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"image"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// FBPoller wraps FBCapturer with the same lifecycle (ClientConnect /
|
||||
// ClientDisconnect, lazy init) as X11Poller, so it slots into the same
|
||||
// session plumbing without code changes upstream. The concrete
|
||||
// FBCapturer is platform-specific (capture_fb_linux.go / _freebsd.go);
|
||||
// this file owns the cross-platform glue.
|
||||
type FBPoller struct {
|
||||
mu sync.Mutex
|
||||
path string
|
||||
capturer *FBCapturer
|
||||
w, h int
|
||||
clients int32
|
||||
}
|
||||
|
||||
// NewFBPoller returns a poller that opens path on first use. Empty path
|
||||
// defaults to /dev/fb0 on Linux and /dev/ttyv0 on FreeBSD.
|
||||
func NewFBPoller(path string) *FBPoller {
|
||||
if path == "" {
|
||||
path = defaultFBPath()
|
||||
}
|
||||
return &FBPoller{path: path}
|
||||
}
|
||||
|
||||
// ClientConnect eagerly initialises the capturer on first connect.
|
||||
func (p *FBPoller) ClientConnect() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.clients++
|
||||
if p.clients == 1 {
|
||||
_ = p.ensureCapturerLocked()
|
||||
}
|
||||
}
|
||||
|
||||
// ClientDisconnect closes the capturer when the last client leaves.
|
||||
func (p *FBPoller) ClientDisconnect() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.clients--
|
||||
if p.clients <= 0 && p.capturer != nil {
|
||||
p.capturer.Close()
|
||||
p.capturer = nil
|
||||
}
|
||||
}
|
||||
|
||||
// Width returns the framebuffer width, doing lazy init if needed.
|
||||
func (p *FBPoller) Width() int {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
_ = p.ensureCapturerLocked()
|
||||
return p.w
|
||||
}
|
||||
|
||||
// Height returns the framebuffer height, doing lazy init if needed.
|
||||
func (p *FBPoller) Height() int {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
_ = p.ensureCapturerLocked()
|
||||
return p.h
|
||||
}
|
||||
|
||||
// Capture takes a fresh frame.
|
||||
func (p *FBPoller) Capture() (*image.RGBA, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if err := p.ensureCapturerLocked(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return p.capturer.Capture()
|
||||
}
|
||||
|
||||
// CaptureInto fills dst directly.
|
||||
func (p *FBPoller) CaptureInto(dst *image.RGBA) error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if err := p.ensureCapturerLocked(); err != nil {
|
||||
return err
|
||||
}
|
||||
return p.capturer.CaptureInto(dst)
|
||||
}
|
||||
|
||||
// Close releases all framebuffer resources.
|
||||
func (p *FBPoller) Close() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if p.capturer != nil {
|
||||
p.capturer.Close()
|
||||
p.capturer = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (p *FBPoller) ensureCapturerLocked() error {
|
||||
if p.capturer != nil {
|
||||
return nil
|
||||
}
|
||||
c, err := NewFBCapturer(p.path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.capturer = c
|
||||
p.w, p.h = c.Width(), c.Height()
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ ScreenCapturer = (*FBPoller)(nil)
|
||||
var _ captureIntoer = (*FBPoller)(nil)
|
||||
|
||||
// swizzleFB24 handles 24-bit packed framebuffers (B,G,R triplets).
|
||||
// Shared between Linux and FreeBSD framebuffer paths.
|
||||
func swizzleFB24(dst []byte, dstStride int, src []byte, srcStride, w, h int) {
|
||||
for y := 0; y < h; y++ {
|
||||
srcRow := src[y*srcStride : y*srcStride+w*3]
|
||||
dstRow := dst[y*dstStride:]
|
||||
for x := 0; x < w; x++ {
|
||||
b := srcRow[x*3+0]
|
||||
g := srcRow[x*3+1]
|
||||
r := srcRow[x*3+2]
|
||||
dstRow[x*4+0] = r
|
||||
dstRow[x*4+1] = g
|
||||
dstRow[x*4+2] = b
|
||||
dstRow[x*4+3] = 0xff
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// swizzleFB16RGB565 handles 16bpp RGB 565 framebuffers.
|
||||
func swizzleFB16RGB565(dst []byte, dstStride int, src []byte, srcStride, w, h int) {
|
||||
for y := 0; y < h; y++ {
|
||||
srcRow := src[y*srcStride : y*srcStride+w*2]
|
||||
dstRow := dst[y*dstStride:]
|
||||
for x := 0; x < w; x++ {
|
||||
pix := uint16(srcRow[x*2]) | uint16(srcRow[x*2+1])<<8
|
||||
r := byte((pix >> 11) & 0x1f)
|
||||
g := byte((pix >> 5) & 0x3f)
|
||||
b := byte(pix & 0x1f)
|
||||
dstRow[x*4+0] = (r << 3) | (r >> 2)
|
||||
dstRow[x*4+1] = (g << 2) | (g >> 4)
|
||||
dstRow[x*4+2] = (b << 3) | (b >> 2)
|
||||
dstRow[x*4+3] = 0xff
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,586 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var (
|
||||
gdi32 = windows.NewLazySystemDLL("gdi32.dll")
|
||||
user32 = windows.NewLazySystemDLL("user32.dll")
|
||||
|
||||
procGetDC = user32.NewProc("GetDC")
|
||||
procReleaseDC = user32.NewProc("ReleaseDC")
|
||||
procCreateCompatDC = gdi32.NewProc("CreateCompatibleDC")
|
||||
procCreateDIBSection = gdi32.NewProc("CreateDIBSection")
|
||||
procSelectObject = gdi32.NewProc("SelectObject")
|
||||
procDeleteObject = gdi32.NewProc("DeleteObject")
|
||||
procDeleteDC = gdi32.NewProc("DeleteDC")
|
||||
procBitBlt = gdi32.NewProc("BitBlt")
|
||||
procGetSystemMetrics = user32.NewProc("GetSystemMetrics")
|
||||
|
||||
// Desktop switching for service/Session 0 capture.
|
||||
procOpenInputDesktop = user32.NewProc("OpenInputDesktop")
|
||||
procSetThreadDesktop = user32.NewProc("SetThreadDesktop")
|
||||
procCloseDesktop = user32.NewProc("CloseDesktop")
|
||||
procOpenWindowStation = user32.NewProc("OpenWindowStationW")
|
||||
procSetProcessWindowStation = user32.NewProc("SetProcessWindowStation")
|
||||
procCloseWindowStation = user32.NewProc("CloseWindowStation")
|
||||
procGetUserObjectInformationW = user32.NewProc("GetUserObjectInformationW")
|
||||
)
|
||||
|
||||
const uoiName = 2
|
||||
|
||||
const (
|
||||
smCxScreen = 0
|
||||
smCyScreen = 1
|
||||
srccopy = 0x00CC0020
|
||||
captureBlt = 0x40000000
|
||||
dibRgbColors = 0
|
||||
)
|
||||
|
||||
type bitmapInfoHeader struct {
|
||||
Size uint32
|
||||
Width int32
|
||||
Height int32
|
||||
Planes uint16
|
||||
BitCount uint16
|
||||
Compression uint32
|
||||
SizeImage uint32
|
||||
XPelsPerMeter int32
|
||||
YPelsPerMeter int32
|
||||
ClrUsed uint32
|
||||
ClrImportant uint32
|
||||
}
|
||||
|
||||
type bitmapInfo struct {
|
||||
Header bitmapInfoHeader
|
||||
}
|
||||
|
||||
// setupInteractiveWindowStation associates the current process with WinSta0,
|
||||
// the interactive window station. This is required for a SYSTEM service in
|
||||
// Session 0 to call OpenInputDesktop for screen capture and input injection.
|
||||
func setupInteractiveWindowStation() error {
|
||||
name, err := windows.UTF16PtrFromString("WinSta0")
|
||||
if err != nil {
|
||||
return fmt.Errorf("UTF16 WinSta0: %w", err)
|
||||
}
|
||||
hWinSta, _, err := procOpenWindowStation.Call(
|
||||
uintptr(unsafe.Pointer(name)),
|
||||
0,
|
||||
uintptr(windows.MAXIMUM_ALLOWED),
|
||||
)
|
||||
if hWinSta == 0 {
|
||||
return fmt.Errorf("OpenWindowStation(WinSta0): %w", err)
|
||||
}
|
||||
r, _, err := procSetProcessWindowStation.Call(hWinSta)
|
||||
if r == 0 {
|
||||
_, _, _ = procCloseWindowStation.Call(hWinSta)
|
||||
return fmt.Errorf("SetProcessWindowStation: %w", err)
|
||||
}
|
||||
log.Info("process window station set to WinSta0 (interactive)")
|
||||
return nil
|
||||
}
|
||||
|
||||
func screenSize() (int, int) {
|
||||
w, _, _ := procGetSystemMetrics.Call(uintptr(smCxScreen))
|
||||
h, _, _ := procGetSystemMetrics.Call(uintptr(smCyScreen))
|
||||
return int(w), int(h)
|
||||
}
|
||||
|
||||
func getDesktopName(hDesk uintptr) string {
|
||||
var buf [256]uint16
|
||||
var needed uint32
|
||||
_, _, _ = procGetUserObjectInformationW.Call(hDesk, uoiName,
|
||||
uintptr(unsafe.Pointer(&buf[0])), 512,
|
||||
uintptr(unsafe.Pointer(&needed)))
|
||||
return windows.UTF16ToString(buf[:])
|
||||
}
|
||||
|
||||
// switchToInputDesktop opens the desktop currently receiving user input
|
||||
// and sets it as the calling OS thread's desktop. Must be called from a
|
||||
// goroutine locked to its OS thread via runtime.LockOSThread().
|
||||
func switchToInputDesktop() (bool, string) {
|
||||
hDesk, _, _ := procOpenInputDesktop.Call(0, 0, uintptr(windows.MAXIMUM_ALLOWED))
|
||||
if hDesk == 0 {
|
||||
return false, ""
|
||||
}
|
||||
name := getDesktopName(hDesk)
|
||||
ret, _, _ := procSetThreadDesktop.Call(hDesk)
|
||||
_, _, _ = procCloseDesktop.Call(hDesk)
|
||||
return ret != 0, name
|
||||
}
|
||||
|
||||
// gdiCapturer captures the desktop screen using GDI BitBlt.
|
||||
// GDI objects (DC, DIBSection) are allocated once and reused across frames.
|
||||
type gdiCapturer struct {
|
||||
mu sync.Mutex
|
||||
width int
|
||||
height int
|
||||
|
||||
// Pre-allocated GDI resources, reused across captures.
|
||||
memDC uintptr
|
||||
bmp uintptr
|
||||
bits uintptr
|
||||
}
|
||||
|
||||
func newGDICapturer() (*gdiCapturer, error) {
|
||||
w, h := screenSize()
|
||||
if w == 0 || h == 0 {
|
||||
return nil, fmt.Errorf("screen dimensions are zero")
|
||||
}
|
||||
c := &gdiCapturer{width: w, height: h}
|
||||
if err := c.allocGDI(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// allocGDI pre-allocates the compatible DC and DIB section for reuse.
|
||||
func (c *gdiCapturer) allocGDI() error {
|
||||
screenDC, _, _ := procGetDC.Call(0)
|
||||
if screenDC == 0 {
|
||||
return fmt.Errorf("GetDC returned 0")
|
||||
}
|
||||
defer func() { _, _, _ = procReleaseDC.Call(0, screenDC) }()
|
||||
|
||||
memDC, _, _ := procCreateCompatDC.Call(screenDC)
|
||||
if memDC == 0 {
|
||||
return fmt.Errorf("CreateCompatibleDC returned 0")
|
||||
}
|
||||
|
||||
bi := bitmapInfo{
|
||||
Header: bitmapInfoHeader{
|
||||
Size: uint32(unsafe.Sizeof(bitmapInfoHeader{})),
|
||||
Width: int32(c.width),
|
||||
Height: -int32(c.height), // negative = top-down DIB
|
||||
Planes: 1,
|
||||
BitCount: 32,
|
||||
},
|
||||
}
|
||||
|
||||
var bits uintptr
|
||||
bmp, _, _ := procCreateDIBSection.Call(
|
||||
screenDC,
|
||||
uintptr(unsafe.Pointer(&bi)),
|
||||
dibRgbColors,
|
||||
uintptr(unsafe.Pointer(&bits)),
|
||||
0, 0,
|
||||
)
|
||||
if bmp == 0 || bits == 0 {
|
||||
_, _, _ = procDeleteDC.Call(memDC)
|
||||
return fmt.Errorf("CreateDIBSection returned 0")
|
||||
}
|
||||
|
||||
_, _, _ = procSelectObject.Call(memDC, bmp)
|
||||
|
||||
c.memDC = memDC
|
||||
c.bmp = bmp
|
||||
c.bits = bits
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *gdiCapturer) close() { c.freeGDI() }
|
||||
|
||||
// freeGDI releases pre-allocated GDI resources.
|
||||
func (c *gdiCapturer) freeGDI() {
|
||||
if c.bmp != 0 {
|
||||
_, _, _ = procDeleteObject.Call(c.bmp)
|
||||
c.bmp = 0
|
||||
}
|
||||
if c.memDC != 0 {
|
||||
_, _, _ = procDeleteDC.Call(c.memDC)
|
||||
c.memDC = 0
|
||||
}
|
||||
c.bits = 0
|
||||
}
|
||||
|
||||
func (c *gdiCapturer) capture() (*image.RGBA, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.memDC == 0 {
|
||||
return nil, fmt.Errorf("GDI resources not allocated")
|
||||
}
|
||||
|
||||
screenDC, _, _ := procGetDC.Call(0)
|
||||
if screenDC == 0 {
|
||||
return nil, fmt.Errorf("GetDC returned 0")
|
||||
}
|
||||
defer func() { _, _, _ = procReleaseDC.Call(0, screenDC) }()
|
||||
|
||||
// SRCCOPY|CAPTUREBLT: CAPTUREBLT forces inclusion of layered/topmost
|
||||
// windows in the capture and is required for GDI BitBlt to return live
|
||||
// pixels when the session is rendered through RDP / DWM-composited
|
||||
// surfaces. Without it BitBlt reads the backing-store DIB which is
|
||||
// often empty (all-black) on RDP and headless sessions.
|
||||
ret, _, _ := procBitBlt.Call(c.memDC, 0, 0, uintptr(c.width), uintptr(c.height),
|
||||
screenDC, 0, 0, srccopy|captureBlt)
|
||||
if ret == 0 {
|
||||
return nil, fmt.Errorf("BitBlt returned 0")
|
||||
}
|
||||
|
||||
n := c.width * c.height * 4
|
||||
raw := unsafe.Slice((*byte)(unsafe.Pointer(c.bits)), n)
|
||||
|
||||
// GDI gives BGRA, the RFB encoder expects RGBA (img.Pix layout).
|
||||
// Swap R and B in bulk using uint32 operations (one load + mask + shift
|
||||
// per pixel instead of three separate byte assignments).
|
||||
img := image.NewRGBA(image.Rect(0, 0, c.width, c.height))
|
||||
swizzleBGRAtoRGBA(img.Pix, raw)
|
||||
return img, nil
|
||||
}
|
||||
|
||||
// DesktopCapturer captures the interactive desktop, handling desktop transitions
|
||||
// (login screen, UAC prompts). A dedicated OS-locked goroutine continuously
|
||||
// captures frames on demand via a dedicated OS-locked goroutine (required
|
||||
// because DXGI's D3D11 device context is not thread-safe). Sessions drive
|
||||
// timing by calling Capture(); a short staleness cache coalesces concurrent
|
||||
// requests. Capture pauses automatically when no clients are connected.
|
||||
type DesktopCapturer struct {
|
||||
mu sync.Mutex
|
||||
w, h int
|
||||
|
||||
// lastFrame/lastAt implement a small staleness cache so multiple
|
||||
// near-simultaneous Capture calls share one DXGI round-trip.
|
||||
lastFrame *image.RGBA
|
||||
lastAt time.Time
|
||||
|
||||
// clients tracks the number of active VNC sessions. When zero, the
|
||||
// worker goroutine releases the underlying capturer.
|
||||
clients atomic.Int32
|
||||
|
||||
// reqCh carries capture requests from sessions to the OS-locked worker.
|
||||
reqCh chan captureReq
|
||||
// wake is signaled when a client connects and the worker should resume.
|
||||
wake chan struct{}
|
||||
// done is closed when Close is called, terminating the worker.
|
||||
done chan struct{}
|
||||
|
||||
// cursorState holds the latest cursor sprite sampled by the worker.
|
||||
// The worker calls GetCursorInfo every capture and decodes a new
|
||||
// sprite only when the HCURSOR changes.
|
||||
cursorState cursorState
|
||||
}
|
||||
|
||||
// captureReq is a single capture request awaiting a reply. Reply channel is
|
||||
// buffered to size 1 so the worker never blocks on a sender that's gone.
|
||||
type captureReq struct {
|
||||
reply chan captureReply
|
||||
}
|
||||
|
||||
type captureReply struct {
|
||||
img *image.RGBA
|
||||
err error
|
||||
}
|
||||
|
||||
// NewDesktopCapturer creates an on-demand capturer for the active desktop.
|
||||
func NewDesktopCapturer() *DesktopCapturer {
|
||||
c := &DesktopCapturer{
|
||||
wake: make(chan struct{}, 1),
|
||||
done: make(chan struct{}),
|
||||
reqCh: make(chan captureReq),
|
||||
}
|
||||
go c.worker()
|
||||
return c
|
||||
}
|
||||
|
||||
// ClientConnect increments the active client count, resuming capture if needed.
|
||||
func (c *DesktopCapturer) ClientConnect() {
|
||||
c.clients.Add(1)
|
||||
select {
|
||||
case c.wake <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// ClientDisconnect decrements the active client count.
|
||||
func (c *DesktopCapturer) ClientDisconnect() {
|
||||
c.clients.Add(-1)
|
||||
}
|
||||
|
||||
// Close stops the capture loop and releases resources.
|
||||
func (c *DesktopCapturer) Close() {
|
||||
select {
|
||||
case <-c.done:
|
||||
default:
|
||||
close(c.done)
|
||||
}
|
||||
}
|
||||
|
||||
// Width returns the current screen width, triggering a capture if the
|
||||
// worker hasn't initialised yet. validateCapturer depends on Width/Height
|
||||
// becoming non-zero promptly after ClientConnect so it doesn't reject
|
||||
// brand-new sessions.
|
||||
func (c *DesktopCapturer) Width() int {
|
||||
c.mu.Lock()
|
||||
w := c.w
|
||||
c.mu.Unlock()
|
||||
if w == 0 && c.clients.Load() > 0 {
|
||||
_, _ = c.Capture()
|
||||
c.mu.Lock()
|
||||
w = c.w
|
||||
c.mu.Unlock()
|
||||
}
|
||||
return w
|
||||
}
|
||||
|
||||
// Height returns the current screen height, triggering a capture if the
|
||||
// worker hasn't initialised yet (see Width). Returns 0 while no client is
|
||||
// connected so callers don't deadlock against a parked worker.
|
||||
func (c *DesktopCapturer) Height() int {
|
||||
c.mu.Lock()
|
||||
h := c.h
|
||||
c.mu.Unlock()
|
||||
if h == 0 && c.clients.Load() > 0 {
|
||||
_, _ = c.Capture()
|
||||
c.mu.Lock()
|
||||
h = c.h
|
||||
c.mu.Unlock()
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
// Capture returns a freshly captured frame, serving from a short staleness
|
||||
// cache when multiple sessions ask within freshWindow of each other. All
|
||||
// real DXGI/GDI work happens on the OS-locked worker goroutine.
|
||||
func (c *DesktopCapturer) Capture() (*image.RGBA, error) {
|
||||
c.mu.Lock()
|
||||
if c.lastFrame != nil && time.Since(c.lastAt) < freshWindow {
|
||||
img := c.lastFrame
|
||||
c.mu.Unlock()
|
||||
return img, nil
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
reply := make(chan captureReply, 1)
|
||||
select {
|
||||
case c.reqCh <- captureReq{reply: reply}:
|
||||
case <-c.done:
|
||||
return nil, fmt.Errorf("capturer closed")
|
||||
}
|
||||
select {
|
||||
case r := <-reply:
|
||||
if r.err != nil {
|
||||
return nil, r.err
|
||||
}
|
||||
c.mu.Lock()
|
||||
c.lastFrame = r.img
|
||||
c.lastAt = time.Now()
|
||||
c.mu.Unlock()
|
||||
return r.img, nil
|
||||
case <-c.done:
|
||||
return nil, fmt.Errorf("capturer closed")
|
||||
}
|
||||
}
|
||||
|
||||
// waitForClient blocks until a client connects or the capturer is closed.
|
||||
func (c *DesktopCapturer) waitForClient() bool {
|
||||
if c.clients.Load() > 0 {
|
||||
return true
|
||||
}
|
||||
select {
|
||||
case <-c.wake:
|
||||
return true
|
||||
case <-c.done:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// worker owns DXGI/GDI state on its OS-locked thread and services capture
|
||||
// requests from sessions. No background ticker: a capture happens only when
|
||||
// a session asks for one (throttled by Capture()'s staleness cache).
|
||||
func (c *DesktopCapturer) worker() {
|
||||
runtime.LockOSThread()
|
||||
|
||||
// When running as a Windows service (Session 0), we need to attach to the
|
||||
// interactive window station before OpenInputDesktop will succeed.
|
||||
if err := setupInteractiveWindowStation(); err != nil {
|
||||
log.Warnf("attach to interactive window station: %v", err)
|
||||
}
|
||||
|
||||
w := &captureWorker{c: c}
|
||||
defer w.closeCapturer()
|
||||
|
||||
for {
|
||||
if !c.waitForClient() {
|
||||
return
|
||||
}
|
||||
// Drop the capturer when all clients have disconnected so we don't
|
||||
// hold the DXGI duplication or GDI DC on an idle peer.
|
||||
if c.clients.Load() <= 0 {
|
||||
w.closeCapturer()
|
||||
continue
|
||||
}
|
||||
if !w.handleNextRequest() {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// frameCapturer is the per-backend interface used by the worker. DXGI and
|
||||
// GDI implementations both satisfy it.
|
||||
type frameCapturer interface {
|
||||
capture() (*image.RGBA, error)
|
||||
close()
|
||||
}
|
||||
|
||||
// captureWorker owns the worker goroutine's mutable state. Extracted into a
|
||||
// struct so the request/desktop/init logic can live on small methods and the
|
||||
// outer worker() stays a thin loop.
|
||||
type captureWorker struct {
|
||||
c *DesktopCapturer
|
||||
cap frameCapturer
|
||||
desktopFails int
|
||||
lastDesktop string
|
||||
nextInitRetry time.Time
|
||||
cursor cursorSampler
|
||||
// lastBackend records the last capturer kind that came out of
|
||||
// createCapturer ("dxgi" or "gdi"); used to demote repeat "using X"
|
||||
// and DXGI-unavailable logs to debug when nothing changed.
|
||||
lastBackend string
|
||||
// lastDXGIErr is the textual DXGI failure printed in the most recent
|
||||
// fallback warning; suppresses repeat warns when DXGI keeps failing
|
||||
// the same way across desktop changes (login -> lock -> login).
|
||||
lastDXGIErr string
|
||||
}
|
||||
|
||||
// handleNextRequest waits for either shutdown or a capture request and runs
|
||||
// the request through prepCapturer/capture. Returns false when the worker
|
||||
// should exit.
|
||||
func (w *captureWorker) handleNextRequest() bool {
|
||||
select {
|
||||
case <-w.c.done:
|
||||
return false
|
||||
case req := <-w.c.reqCh:
|
||||
w.serveRequest(req)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (w *captureWorker) serveRequest(req captureReq) {
|
||||
fc, err := w.prepCapturer()
|
||||
if err != nil {
|
||||
req.reply <- captureReply{err: err}
|
||||
return
|
||||
}
|
||||
img, err := fc.capture()
|
||||
if err != nil {
|
||||
log.Debugf("capture: %v", err)
|
||||
w.closeCapturer()
|
||||
w.nextInitRetry = time.Now().Add(100 * time.Millisecond)
|
||||
req.reply <- captureReply{err: err}
|
||||
return
|
||||
}
|
||||
if snap, err := w.cursor.sample(); err != nil {
|
||||
w.c.cursorState.store(&cursorSnapshot{err: err})
|
||||
} else {
|
||||
w.c.cursorState.store(snap)
|
||||
}
|
||||
req.reply <- captureReply{img: img}
|
||||
}
|
||||
|
||||
// prepCapturer switches to the input desktop, handles desktop-change
|
||||
// teardown, and creates the underlying capturer on demand. Backoff state is
|
||||
// tracked across calls via w.nextInitRetry.
|
||||
func (w *captureWorker) prepCapturer() (frameCapturer, error) {
|
||||
if err := w.refreshDesktop(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if w.cap != nil {
|
||||
return w.cap, nil
|
||||
}
|
||||
if time.Now().Before(w.nextInitRetry) {
|
||||
return nil, fmt.Errorf("capturer init backing off")
|
||||
}
|
||||
fc, err := w.createCapturer()
|
||||
if err != nil {
|
||||
w.nextInitRetry = time.Now().Add(500 * time.Millisecond)
|
||||
return nil, err
|
||||
}
|
||||
w.cap = fc
|
||||
sw, sh := screenSize()
|
||||
w.c.mu.Lock()
|
||||
sizeChanged := w.c.w != sw || w.c.h != sh
|
||||
w.c.w, w.c.h = sw, sh
|
||||
w.c.mu.Unlock()
|
||||
if sizeChanged {
|
||||
log.Infof("screen capturer ready: %dx%d", sw, sh)
|
||||
} else {
|
||||
log.Debugf("screen capturer ready: %dx%d", sw, sh)
|
||||
}
|
||||
return w.cap, nil
|
||||
}
|
||||
|
||||
// refreshDesktop tracks the active input desktop. When it changes (lock
|
||||
// screen, fast-user-switch) the existing capturer is dropped so the next
|
||||
// call rebuilds one against the new desktop.
|
||||
func (w *captureWorker) refreshDesktop() error {
|
||||
ok, desk := switchToInputDesktop()
|
||||
if !ok {
|
||||
w.desktopFails++
|
||||
if w.desktopFails == 1 || w.desktopFails%100 == 0 {
|
||||
log.Warnf("switchToInputDesktop failed (count=%d), no interactive desktop session?", w.desktopFails)
|
||||
}
|
||||
return fmt.Errorf("no interactive desktop")
|
||||
}
|
||||
if w.desktopFails > 0 {
|
||||
log.Infof("switchToInputDesktop recovered after %d failures, desktop=%q", w.desktopFails, desk)
|
||||
w.desktopFails = 0
|
||||
}
|
||||
if desk != w.lastDesktop {
|
||||
log.Infof("desktop changed: %q -> %q", w.lastDesktop, desk)
|
||||
w.lastDesktop = desk
|
||||
w.closeCapturer()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *captureWorker) createCapturer() (frameCapturer, error) {
|
||||
dc, err := newDXGICapturer()
|
||||
if err == nil {
|
||||
if w.lastBackend != "dxgi" {
|
||||
log.Info("using DXGI Desktop Duplication for capture")
|
||||
} else {
|
||||
log.Debug("using DXGI Desktop Duplication for capture")
|
||||
}
|
||||
w.lastBackend = "dxgi"
|
||||
w.lastDXGIErr = ""
|
||||
return dc, nil
|
||||
}
|
||||
errStr := err.Error()
|
||||
if errStr != w.lastDXGIErr {
|
||||
log.Warnf("DXGI Desktop Duplication unavailable, falling back to slower GDI BitBlt: %v", err)
|
||||
w.lastDXGIErr = errStr
|
||||
} else {
|
||||
log.Debugf("DXGI Desktop Duplication still unavailable, falling back to slower GDI BitBlt: %v", err)
|
||||
}
|
||||
gc, err := newGDICapturer()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if w.lastBackend != "gdi" {
|
||||
log.Info("using GDI BitBlt for capture")
|
||||
} else {
|
||||
log.Debug("using GDI BitBlt for capture")
|
||||
}
|
||||
w.lastBackend = "gdi"
|
||||
return gc, nil
|
||||
}
|
||||
|
||||
func (w *captureWorker) closeCapturer() {
|
||||
if w.cap != nil {
|
||||
w.cap.close()
|
||||
w.cap = nil
|
||||
}
|
||||
}
|
||||
@@ -1,544 +0,0 @@
|
||||
//go:build unix && !darwin && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/jezek/xgb"
|
||||
"github.com/jezek/xgb/xproto"
|
||||
)
|
||||
|
||||
const (
|
||||
// x11SocketDir is the well-known directory where X servers create
|
||||
// their abstract UNIX-domain sockets, named "X<display>". Used both
|
||||
// for auto-detecting an existing display and for placing/probing
|
||||
// sockets of virtual sessions we spawn.
|
||||
x11SocketDir = "/tmp/.X11-unix"
|
||||
|
||||
// envDisplay is the X11 display selector environment variable.
|
||||
envDisplay = "DISPLAY"
|
||||
// envXAuthority points X clients at the cookie file used to
|
||||
// authenticate against the running X server.
|
||||
envXAuthority = "XAUTHORITY"
|
||||
)
|
||||
|
||||
// X11Capturer captures the screen from an X11 display using the MIT-SHM extension.
|
||||
type X11Capturer struct {
|
||||
mu sync.Mutex
|
||||
conn *xgb.Conn
|
||||
screen *xproto.ScreenInfo
|
||||
w, h int
|
||||
shmID int
|
||||
shmAddr []byte
|
||||
shmSeg uint32
|
||||
useSHM bool
|
||||
// bufs double-buffers output images so the X11Poller's capture loop can
|
||||
// overwrite one while the session is still encoding the other. Before
|
||||
// this, a single reused buffer would race with the reader. Allocation
|
||||
// happens on first use and on geometry change.
|
||||
bufs [2]*image.RGBA
|
||||
cur int
|
||||
// cursor is the XFixes binding used to report the current sprite.
|
||||
// Allocated lazily on the first Cursor call. cursorInitErr latches
|
||||
// a permanent init failure so we stop retrying every frame.
|
||||
cursor *xfixesCursor
|
||||
cursorInitErr error
|
||||
}
|
||||
|
||||
// detectX11Display finds the active X11 display and sets DISPLAY/XAUTHORITY
|
||||
// environment variables if needed. This is required when running as a system
|
||||
// service where these vars aren't set.
|
||||
func detectX11Display() {
|
||||
if os.Getenv(envDisplay) != "" {
|
||||
return
|
||||
}
|
||||
|
||||
// Try /proc first (Linux), then ps fallback (FreeBSD and others).
|
||||
if detectX11FromProc() {
|
||||
return
|
||||
}
|
||||
if detectX11FromSockets() {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// detectX11FromProc scans /proc/*/cmdline for Xorg (Linux).
|
||||
func detectX11FromProc() bool {
|
||||
entries, err := os.ReadDir("/proc")
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
for _, e := range entries {
|
||||
if !e.IsDir() {
|
||||
continue
|
||||
}
|
||||
cmdline, err := os.ReadFile("/proc/" + e.Name() + "/cmdline")
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if display, auth := parseXorgArgs(splitCmdline(cmdline)); display != "" {
|
||||
setDisplayEnv(display, auth)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// detectX11FromSockets checks /tmp/.X11-unix/ for X sockets and uses ps
|
||||
// to find the auth file. Works on FreeBSD and other systems without /proc.
|
||||
func detectX11FromSockets() bool {
|
||||
entries, err := os.ReadDir(x11SocketDir)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Pick the lowest numeric display rather than the lexically first
|
||||
// entry, so X10 doesn't win over X2.
|
||||
minDisplay := -1
|
||||
for _, e := range entries {
|
||||
name := e.Name()
|
||||
if len(name) < 2 || name[0] != 'X' {
|
||||
continue
|
||||
}
|
||||
n, err := strconv.Atoi(name[1:])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if minDisplay < 0 || n < minDisplay {
|
||||
minDisplay = n
|
||||
}
|
||||
}
|
||||
if minDisplay < 0 {
|
||||
return false
|
||||
}
|
||||
display := ":" + strconv.Itoa(minDisplay)
|
||||
os.Setenv(envDisplay, display)
|
||||
auth := findXorgAuthFromPS()
|
||||
if auth != "" {
|
||||
os.Setenv(envXAuthority, auth)
|
||||
log.Infof("auto-detected DISPLAY=%s (from socket) XAUTHORITY=%s (from ps)", display, auth)
|
||||
} else {
|
||||
log.Infof("auto-detected DISPLAY=%s (from socket)", display)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// findXorgAuthFromPS runs ps to find Xorg and extract its -auth argument.
|
||||
func findXorgAuthFromPS() string {
|
||||
out, err := exec.Command("ps", "auxww").Output()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
for _, line := range strings.Split(string(out), "\n") {
|
||||
if !strings.Contains(line, "Xorg") && !strings.Contains(line, "/X ") {
|
||||
continue
|
||||
}
|
||||
fields := strings.Fields(line)
|
||||
for i, f := range fields {
|
||||
if f == "-auth" && i+1 < len(fields) {
|
||||
return fields[i+1]
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func parseXorgArgs(args []string) (display, auth string) {
|
||||
if len(args) == 0 {
|
||||
return "", ""
|
||||
}
|
||||
base := args[0]
|
||||
if !(base == "Xorg" || base == "X" || len(base) > 0 && base[len(base)-1] == 'X' ||
|
||||
strings.Contains(base, "/Xorg") || strings.Contains(base, "/X")) {
|
||||
return "", ""
|
||||
}
|
||||
for i, arg := range args[1:] {
|
||||
if len(arg) > 0 && arg[0] == ':' {
|
||||
display = arg
|
||||
}
|
||||
if arg == "-auth" && i+2 < len(args) {
|
||||
auth = args[i+2]
|
||||
}
|
||||
}
|
||||
return display, auth
|
||||
}
|
||||
|
||||
func setDisplayEnv(display, auth string) {
|
||||
os.Setenv(envDisplay, display)
|
||||
if auth != "" {
|
||||
os.Setenv(envXAuthority, auth)
|
||||
log.Infof("auto-detected DISPLAY=%s XAUTHORITY=%s", display, auth)
|
||||
return
|
||||
}
|
||||
log.Infof("auto-detected DISPLAY=%s", display)
|
||||
}
|
||||
|
||||
func splitCmdline(data []byte) []string {
|
||||
var args []string
|
||||
for _, b := range splitNull(data) {
|
||||
if len(b) > 0 {
|
||||
args = append(args, string(b))
|
||||
}
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
func splitNull(data []byte) [][]byte {
|
||||
var parts [][]byte
|
||||
start := 0
|
||||
for i, b := range data {
|
||||
if b == 0 {
|
||||
parts = append(parts, data[start:i])
|
||||
start = i + 1
|
||||
}
|
||||
}
|
||||
if start < len(data) {
|
||||
parts = append(parts, data[start:])
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
// NewX11Capturer connects to the X11 display and sets up shared memory capture.
|
||||
// Empty cookieHex falls back to XAUTHORITY env lookup.
|
||||
func NewX11Capturer(display, cookieHex string) (*X11Capturer, error) {
|
||||
if display == "" {
|
||||
detectX11Display()
|
||||
display = os.Getenv(envDisplay)
|
||||
}
|
||||
if display == "" {
|
||||
return nil, fmt.Errorf("DISPLAY not set and no Xorg process found")
|
||||
}
|
||||
|
||||
var conn *xgb.Conn
|
||||
var err error
|
||||
if cookieHex != "" {
|
||||
conn, err = dialXUnixWithCookie(display, cookieHex)
|
||||
} else {
|
||||
conn, err = xgb.NewConnDisplay(display)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connect to X11 display %s: %w", display, err)
|
||||
}
|
||||
|
||||
setup := xproto.Setup(conn)
|
||||
if len(setup.Roots) == 0 {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("no X11 screens")
|
||||
}
|
||||
screen := setup.Roots[0]
|
||||
|
||||
c := &X11Capturer{
|
||||
conn: conn,
|
||||
screen: &screen,
|
||||
w: int(screen.WidthInPixels),
|
||||
h: int(screen.HeightInPixels),
|
||||
}
|
||||
|
||||
if err := c.initSHM(); err != nil {
|
||||
log.Debugf("X11 SHM not available, using slow GetImage: %v", err)
|
||||
}
|
||||
|
||||
log.Infof("X11 capturer ready: %dx%d (display=%s, shm=%v)", c.w, c.h, display, c.useSHM)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// initSHM is implemented in capture_x11_shm_linux.go (requires SysV SHM).
|
||||
// On platforms without SysV SHM (FreeBSD), a stub returns an error and
|
||||
// the capturer falls back to GetImage.
|
||||
|
||||
// Width returns the screen width.
|
||||
func (c *X11Capturer) Width() int { return c.w }
|
||||
|
||||
// Height returns the screen height.
|
||||
func (c *X11Capturer) Height() int { return c.h }
|
||||
|
||||
// Capture returns the current screen as an RGBA image.
|
||||
func (c *X11Capturer) Capture() (*image.RGBA, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.useSHM {
|
||||
return c.captureSHM()
|
||||
}
|
||||
return c.captureGetImage()
|
||||
}
|
||||
|
||||
// CaptureInto fills the caller's destination buffer in one pass. The
|
||||
// source path (SHM or fallback GetImage) writes directly into dst.Pix
|
||||
// instead of going through the X11Capturer's internal double-buffer,
|
||||
// saving one full-frame memcpy per capture.
|
||||
func (c *X11Capturer) CaptureInto(dst *image.RGBA) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if dst.Rect.Dx() != c.w || dst.Rect.Dy() != c.h {
|
||||
return fmt.Errorf("dst size mismatch: dst=%dx%d capturer=%dx%d",
|
||||
dst.Rect.Dx(), dst.Rect.Dy(), c.w, c.h)
|
||||
}
|
||||
if c.useSHM {
|
||||
return c.captureSHMInto(dst)
|
||||
}
|
||||
return c.captureGetImageInto(dst)
|
||||
}
|
||||
|
||||
func (c *X11Capturer) captureGetImageInto(dst *image.RGBA) error {
|
||||
cookie := xproto.GetImage(c.conn, xproto.ImageFormatZPixmap,
|
||||
xproto.Drawable(c.screen.Root),
|
||||
0, 0, uint16(c.w), uint16(c.h), 0xFFFFFFFF)
|
||||
reply, err := cookie.Reply()
|
||||
if err != nil {
|
||||
return fmt.Errorf("GetImage: %w", err)
|
||||
}
|
||||
n := c.w * c.h * 4
|
||||
if len(reply.Data) < n {
|
||||
return fmt.Errorf("GetImage returned %d bytes, expected %d", len(reply.Data), n)
|
||||
}
|
||||
swizzleBGRAtoRGBA(dst.Pix, reply.Data)
|
||||
return nil
|
||||
}
|
||||
|
||||
// captureSHM is implemented in capture_x11_shm_linux.go.
|
||||
|
||||
func (c *X11Capturer) captureGetImage() (*image.RGBA, error) {
|
||||
cookie := xproto.GetImage(c.conn, xproto.ImageFormatZPixmap,
|
||||
xproto.Drawable(c.screen.Root),
|
||||
0, 0, uint16(c.w), uint16(c.h), 0xFFFFFFFF)
|
||||
|
||||
reply, err := cookie.Reply()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GetImage: %w", err)
|
||||
}
|
||||
|
||||
data := reply.Data
|
||||
n := c.w * c.h * 4
|
||||
if len(data) < n {
|
||||
return nil, fmt.Errorf("GetImage returned %d bytes, expected %d", len(data), n)
|
||||
}
|
||||
|
||||
img := c.nextBuffer()
|
||||
swizzleBGRAtoRGBA(img.Pix, data)
|
||||
return img, nil
|
||||
}
|
||||
|
||||
// nextBuffer returns the *image.RGBA the next capture should fill, advancing
|
||||
// the double-buffer index. Reallocates on geometry change.
|
||||
func (c *X11Capturer) nextBuffer() *image.RGBA {
|
||||
c.cur ^= 1
|
||||
b := c.bufs[c.cur]
|
||||
if b == nil || b.Rect.Dx() != c.w || b.Rect.Dy() != c.h {
|
||||
b = image.NewRGBA(image.Rect(0, 0, c.w, c.h))
|
||||
c.bufs[c.cur] = b
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// Close releases X11 resources.
|
||||
func (c *X11Capturer) Close() {
|
||||
c.closeSHM()
|
||||
c.conn.Close()
|
||||
}
|
||||
|
||||
// closeSHM is implemented in capture_x11_shm_linux.go.
|
||||
|
||||
// X11Poller wraps X11Capturer with a staleness-cached on-demand Capture:
|
||||
// sessions drive captures themselves through the encoder goroutine, so we
|
||||
// don't need a background ticker. The last result is cached for a short
|
||||
// window so concurrent sessions coalesce into one capture.
|
||||
//
|
||||
// The capturer is allocated lazily on first use and released when all
|
||||
// clients disconnect, so an idle peer holds no X connection or SHM segment.
|
||||
type X11Poller struct {
|
||||
mu sync.Mutex
|
||||
|
||||
capturer *X11Capturer
|
||||
w, h int
|
||||
// closed at Close so callers can stop waiting on retry backoff.
|
||||
done chan struct{}
|
||||
|
||||
// lastFrame/lastAt implement a small cache: multiple near-simultaneous
|
||||
// Capture calls (multi-client, or input-coalesced) return the same
|
||||
// frame instead of hammering the X server.
|
||||
lastFrame *image.RGBA
|
||||
lastAt time.Time
|
||||
|
||||
// initBackoffUntil throttles capturer re-init when the X server is
|
||||
// unavailable or flapping.
|
||||
initBackoffUntil time.Time
|
||||
|
||||
clients atomic.Int32
|
||||
display string
|
||||
// cookieHex authenticates the X11 connection; empty falls back to XAUTHORITY env.
|
||||
cookieHex string
|
||||
}
|
||||
|
||||
// initRetryBackoff gates capturer re-init attempts after a failure so we
|
||||
// don't spin on X server errors.
|
||||
const initRetryBackoff = 2 * time.Second
|
||||
|
||||
// NewX11Poller creates a lazy on-demand capturer for the given X display.
|
||||
// Empty cookieHex falls back to XAUTHORITY env lookup.
|
||||
func NewX11Poller(display, cookieHex string) *X11Poller {
|
||||
return &X11Poller{
|
||||
display: display,
|
||||
cookieHex: cookieHex,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// ClientConnect increments the active client count. The first client triggers
|
||||
// eager capturer initialisation so that the first FBUpdateRequest doesn't
|
||||
// pay the X11 connect + SHM attach latency.
|
||||
func (p *X11Poller) ClientConnect() {
|
||||
if p.clients.Add(1) == 1 {
|
||||
p.mu.Lock()
|
||||
_ = p.ensureCapturerLocked()
|
||||
p.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// ClientDisconnect decrements the active client count. On the last
|
||||
// disconnect we close the underlying capturer so idle peers cost nothing.
|
||||
func (p *X11Poller) ClientDisconnect() {
|
||||
if p.clients.Add(-1) == 0 {
|
||||
p.mu.Lock()
|
||||
if p.capturer != nil {
|
||||
p.capturer.Close()
|
||||
p.capturer = nil
|
||||
p.lastFrame = nil
|
||||
}
|
||||
p.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Close releases all resources. Subsequent Capture calls will fail.
|
||||
func (p *X11Poller) Close() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
select {
|
||||
case <-p.done:
|
||||
default:
|
||||
close(p.done)
|
||||
}
|
||||
if p.capturer != nil {
|
||||
p.capturer.Close()
|
||||
p.capturer = nil
|
||||
}
|
||||
}
|
||||
|
||||
// Width returns the screen width. Triggers lazy init if needed.
|
||||
func (p *X11Poller) Width() int {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
_ = p.ensureCapturerLocked()
|
||||
return p.w
|
||||
}
|
||||
|
||||
// Height returns the screen height. Triggers lazy init if needed.
|
||||
func (p *X11Poller) Height() int {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
_ = p.ensureCapturerLocked()
|
||||
return p.h
|
||||
}
|
||||
|
||||
// Cursor satisfies cursorSource by forwarding to the lazily-initialised
|
||||
// X11Capturer. Asking for the cursor on an idle poller triggers the same
|
||||
// lazy X11 connection setup as a capture would.
|
||||
func (p *X11Poller) Cursor() (*image.RGBA, int, int, uint64, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if err := p.ensureCapturerLocked(); err != nil {
|
||||
return nil, 0, 0, 0, err
|
||||
}
|
||||
return p.capturer.Cursor()
|
||||
}
|
||||
|
||||
// CursorPos satisfies cursorPositionSource by forwarding to the X11Capturer.
|
||||
func (p *X11Poller) CursorPos() (int, int, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if err := p.ensureCapturerLocked(); err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
return p.capturer.CursorPos()
|
||||
}
|
||||
|
||||
// Capture returns a fresh frame, serving from the short-lived cache if a
|
||||
// previous caller captured within freshWindow.
|
||||
func (p *X11Poller) Capture() (*image.RGBA, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.lastFrame != nil && time.Since(p.lastAt) < freshWindow {
|
||||
return p.lastFrame, nil
|
||||
}
|
||||
if err := p.ensureCapturerLocked(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
img, err := p.capturer.Capture()
|
||||
if err != nil {
|
||||
// Drop the capturer so the next call re-inits; the X connection may
|
||||
// have died (e.g. Xorg restart).
|
||||
p.capturer.Close()
|
||||
p.capturer = nil
|
||||
p.initBackoffUntil = time.Now().Add(initRetryBackoff)
|
||||
return nil, fmt.Errorf("x11 capture: %w", err)
|
||||
}
|
||||
p.lastFrame = img
|
||||
p.lastAt = time.Now()
|
||||
return img, nil
|
||||
}
|
||||
|
||||
// CaptureInto fills dst directly via the underlying capturer, bypassing
|
||||
// the freshness cache. The session's prevFrame/curFrame swap means each
|
||||
// session needs its own buffer anyway, so caching wouldn't help.
|
||||
func (p *X11Poller) CaptureInto(dst *image.RGBA) error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if err := p.ensureCapturerLocked(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := p.capturer.CaptureInto(dst); err != nil {
|
||||
p.capturer.Close()
|
||||
p.capturer = nil
|
||||
p.initBackoffUntil = time.Now().Add(initRetryBackoff)
|
||||
return fmt.Errorf("x11 capture: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureCapturerLocked initialises the underlying X11Capturer if not
|
||||
// already open. Caller must hold p.mu.
|
||||
func (p *X11Poller) ensureCapturerLocked() error {
|
||||
if p.capturer != nil {
|
||||
return nil
|
||||
}
|
||||
select {
|
||||
case <-p.done:
|
||||
return fmt.Errorf("x11 capturer closed")
|
||||
default:
|
||||
}
|
||||
if time.Now().Before(p.initBackoffUntil) {
|
||||
return fmt.Errorf("x11 capturer unavailable (retry scheduled)")
|
||||
}
|
||||
c, err := NewX11Capturer(p.display, p.cookieHex)
|
||||
if err != nil {
|
||||
p.initBackoffUntil = time.Now().Add(initRetryBackoff)
|
||||
log.Debugf("X11 capturer: %v", err)
|
||||
return err
|
||||
}
|
||||
p.capturer = c
|
||||
p.w, p.h = c.Width(), c.Height()
|
||||
return nil
|
||||
}
|
||||
@@ -1,96 +0,0 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
|
||||
"github.com/jezek/xgb/shm"
|
||||
"github.com/jezek/xgb/xproto"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func (c *X11Capturer) initSHM() error {
|
||||
if err := shm.Init(c.conn); err != nil {
|
||||
return fmt.Errorf("init SHM extension: %w", err)
|
||||
}
|
||||
|
||||
size := c.w * c.h * 4
|
||||
id, err := unix.SysvShmGet(unix.IPC_PRIVATE, size, unix.IPC_CREAT|0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("shmget: %w", err)
|
||||
}
|
||||
|
||||
addr, err := unix.SysvShmAttach(id, 0, 0)
|
||||
if err != nil {
|
||||
if _, ctlErr := unix.SysvShmCtl(id, unix.IPC_RMID, nil); ctlErr != nil {
|
||||
log.Debugf("shmctl IPC_RMID on attach failure: %v", ctlErr)
|
||||
}
|
||||
return fmt.Errorf("shmat: %w", err)
|
||||
}
|
||||
|
||||
if _, err := unix.SysvShmCtl(id, unix.IPC_RMID, nil); err != nil {
|
||||
log.Debugf("shmctl IPC_RMID: %v", err)
|
||||
}
|
||||
|
||||
seg, err := shm.NewSegId(c.conn)
|
||||
if err != nil {
|
||||
if detachErr := unix.SysvShmDetach(addr); detachErr != nil {
|
||||
log.Debugf("shmdt on new-seg failure: %v", detachErr)
|
||||
}
|
||||
return fmt.Errorf("new SHM seg: %w", err)
|
||||
}
|
||||
|
||||
if err := shm.AttachChecked(c.conn, seg, uint32(id), false).Check(); err != nil {
|
||||
if detachErr := unix.SysvShmDetach(addr); detachErr != nil {
|
||||
log.Debugf("shmdt on attach-checked failure: %v", detachErr)
|
||||
}
|
||||
return fmt.Errorf("SHM attach to X: %w", err)
|
||||
}
|
||||
|
||||
c.shmID = id
|
||||
c.shmAddr = addr
|
||||
c.shmSeg = uint32(seg)
|
||||
c.useSHM = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *X11Capturer) captureSHM() (*image.RGBA, error) {
|
||||
if err := c.fillSHM(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
img := c.nextBuffer()
|
||||
swizzleBGRAtoRGBA(img.Pix, c.shmAddr[:c.w*c.h*4])
|
||||
return img, nil
|
||||
}
|
||||
|
||||
// captureSHMInto runs a single SHM GetImage and swizzles directly into the
|
||||
// caller-provided destination, skipping the internal double-buffer.
|
||||
func (c *X11Capturer) captureSHMInto(dst *image.RGBA) error {
|
||||
if err := c.fillSHM(); err != nil {
|
||||
return err
|
||||
}
|
||||
swizzleBGRAtoRGBA(dst.Pix, c.shmAddr[:c.w*c.h*4])
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *X11Capturer) fillSHM() error {
|
||||
cookie := shm.GetImage(c.conn, xproto.Drawable(c.screen.Root),
|
||||
0, 0, uint16(c.w), uint16(c.h), 0xFFFFFFFF,
|
||||
xproto.ImageFormatZPixmap, shm.Seg(c.shmSeg), 0)
|
||||
if _, err := cookie.Reply(); err != nil {
|
||||
return fmt.Errorf("SHM GetImage: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *X11Capturer) closeSHM() {
|
||||
if c.useSHM {
|
||||
shm.Detach(c.conn, shm.Seg(c.shmSeg))
|
||||
if err := unix.SysvShmDetach(c.shmAddr); err != nil {
|
||||
log.Debugf("shmdt on close: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,24 +0,0 @@
|
||||
//go:build freebsd
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
)
|
||||
|
||||
func (c *X11Capturer) initSHM() error {
|
||||
return fmt.Errorf("SysV SHM not available on this platform")
|
||||
}
|
||||
|
||||
func (c *X11Capturer) captureSHM() (*image.RGBA, error) {
|
||||
return nil, fmt.Errorf("SHM capture not available on this platform")
|
||||
}
|
||||
|
||||
func (c *X11Capturer) captureSHMInto(_ *image.RGBA) error {
|
||||
return fmt.Errorf("SHM capture not available on this platform")
|
||||
}
|
||||
|
||||
func (c *X11Capturer) closeSHM() {
|
||||
// no SHM to close on this platform
|
||||
}
|
||||
@@ -1,77 +0,0 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCoalesceRects(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
in [][4]int
|
||||
want [][4]int
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
in: nil,
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "single",
|
||||
in: [][4]int{{0, 0, 64, 64}},
|
||||
want: [][4]int{{0, 0, 64, 64}},
|
||||
},
|
||||
{
|
||||
name: "horizontal_run",
|
||||
in: [][4]int{{0, 0, 64, 64}, {64, 0, 64, 64}, {128, 0, 64, 64}},
|
||||
want: [][4]int{{0, 0, 192, 64}},
|
||||
},
|
||||
{
|
||||
name: "vertical_run",
|
||||
in: [][4]int{{0, 0, 64, 64}, {0, 64, 64, 64}, {0, 128, 64, 64}},
|
||||
want: [][4]int{{0, 0, 64, 192}},
|
||||
},
|
||||
{
|
||||
name: "block_2x2",
|
||||
in: [][4]int{
|
||||
{0, 0, 64, 64}, {64, 0, 64, 64},
|
||||
{0, 64, 64, 64}, {64, 64, 64, 64},
|
||||
},
|
||||
want: [][4]int{{0, 0, 128, 128}},
|
||||
},
|
||||
{
|
||||
name: "no_merge_gap",
|
||||
in: [][4]int{{0, 0, 64, 64}, {192, 0, 64, 64}},
|
||||
want: [][4]int{{0, 0, 64, 64}, {192, 0, 64, 64}},
|
||||
},
|
||||
{
|
||||
name: "two_disjoint_columns",
|
||||
in: [][4]int{
|
||||
{0, 0, 64, 64}, {192, 0, 64, 64},
|
||||
{0, 64, 64, 64}, {192, 64, 64, 64},
|
||||
},
|
||||
want: [][4]int{{0, 0, 64, 128}, {192, 0, 64, 128}},
|
||||
},
|
||||
{
|
||||
name: "misaligned_widths_no_vertical_merge",
|
||||
in: [][4]int{
|
||||
{0, 0, 128, 64},
|
||||
{0, 64, 64, 64},
|
||||
},
|
||||
want: [][4]int{{0, 0, 128, 64}, {0, 64, 64, 64}},
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := coalesceRects(tc.in)
|
||||
if len(got) == 0 && len(tc.want) == 0 {
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tc.want) {
|
||||
t.Fatalf("got %v want %v", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,10 +0,0 @@
|
||||
package server
|
||||
|
||||
// interactiveUserError returns nil when a user is logged into the console
|
||||
// (i.e. an Aqua session is active). At the loginwindow there is nobody to
|
||||
// display an approval prompt to, so callers can decline without waiting on
|
||||
// the broker. Any error (including errNoConsoleUser) is treated as decline.
|
||||
func interactiveUserError() error {
|
||||
_, err := consoleUserID()
|
||||
return err
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
//go:build !darwin && !windows
|
||||
|
||||
package server
|
||||
|
||||
// interactiveUserError is unused outside service mode (darwin/windows) but
|
||||
// the symbol must exist so gateApproval compiles on all platforms.
|
||||
func interactiveUserError() error { return nil }
|
||||
@@ -1,15 +0,0 @@
|
||||
package server
|
||||
|
||||
// interactiveUserError returns nil when there is a logged-in user session
|
||||
// on the box. At the lock/login screen WTSQueryUserName is empty, which
|
||||
// means there is nobody to display an approval prompt to.
|
||||
func interactiveUserError() error {
|
||||
sid := getActiveSessionID()
|
||||
if sid == 0 {
|
||||
return errNoConsoleUser
|
||||
}
|
||||
if !wtsSessionHasUser(sid) {
|
||||
return errNoConsoleUser
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,203 +0,0 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"hash/maphash"
|
||||
"image"
|
||||
)
|
||||
|
||||
// copyRectDetector finds tiles in the current frame that match the content
|
||||
// of some tile-aligned region of the previous frame, so we can emit them as
|
||||
// CopyRect rectangles (16 wire bytes) instead of re-encoding the pixels.
|
||||
//
|
||||
// The detector keeps two structures:
|
||||
// - tileHash, a flat slice of one hash per tile-aligned position, used as
|
||||
// the source of truth for the previous frame's tile content.
|
||||
// - prevTiles, a hash → position lookup used during findTileMatch.
|
||||
//
|
||||
// updateDirty rehashes only the tiles that changed this frame, so the
|
||||
// steady-state cost is proportional to the dirty set, not the framebuffer.
|
||||
// A full rebuild from scratch is only done on the first frame or when the
|
||||
// detector has not yet been initialized for the current resolution.
|
||||
//
|
||||
// Limitations:
|
||||
// - Only tile-aligned source positions are considered. Sub-tile-aligned
|
||||
// moves (e.g. window dragged by 7 pixels) are not detected. This still
|
||||
// covers the common case of vertical/horizontal scrolling, which always
|
||||
// produces tile-aligned matches at the tile granularity.
|
||||
// - 64-bit maphash collisions are assumed not to happen. The probability
|
||||
// for any single frame's hash universe is ~2^-32 * tileCount² which is
|
||||
// vanishingly small at typical resolutions; if we ever observe one we
|
||||
// can fall back to a full memcmp verification.
|
||||
type copyRectDetector struct {
|
||||
seed maphash.Seed
|
||||
tileSize int
|
||||
w, h int
|
||||
cols, rows int
|
||||
// tileHash[ty*cols + tx] is the current hash of the tile at (tx, ty)
|
||||
// in the previous frame. Lookup uses this to detect stale prevTiles
|
||||
// entries: incremental updates may leave hash→pos entries pointing
|
||||
// at a tile whose content has since changed.
|
||||
tileHash []uint64
|
||||
// prevTiles maps a tile hash to a (x, y) origin in the previous frame.
|
||||
prevTiles map[uint64][2]int
|
||||
// hash is reused across hash computations to keep the per-tile lookup
|
||||
// path allocation-free.
|
||||
hash maphash.Hash
|
||||
}
|
||||
|
||||
func newCopyRectDetector(tileSize int) *copyRectDetector {
|
||||
d := ©RectDetector{
|
||||
seed: maphash.MakeSeed(),
|
||||
tileSize: tileSize,
|
||||
prevTiles: make(map[uint64][2]int),
|
||||
}
|
||||
d.hash.SetSeed(d.seed)
|
||||
return d
|
||||
}
|
||||
|
||||
// resize ensures the per-tile tables match the given framebuffer size.
|
||||
// Called from rebuild before each full hash sweep.
|
||||
func (d *copyRectDetector) resize(w, h int) {
|
||||
if d.w == w && d.h == h && d.tileHash != nil {
|
||||
return
|
||||
}
|
||||
d.w, d.h = w, h
|
||||
d.cols = w / d.tileSize
|
||||
d.rows = h / d.tileSize
|
||||
d.tileHash = make([]uint64, d.cols*d.rows)
|
||||
}
|
||||
|
||||
// hashTile computes the 64-bit maphash of one tile-aligned tile of frame.
|
||||
func (d *copyRectDetector) hashTile(frame *image.RGBA, tx, ty int) uint64 {
|
||||
d.hash.Reset()
|
||||
ts := d.tileSize
|
||||
stride := frame.Stride
|
||||
rowBytes := ts * 4
|
||||
base := ty*stride + tx*4
|
||||
for row := 0; row < ts; row++ {
|
||||
off := base + row*stride
|
||||
_, _ = d.hash.Write(frame.Pix[off : off+rowBytes])
|
||||
}
|
||||
return d.hash.Sum64()
|
||||
}
|
||||
|
||||
// rebuild discards everything and rehashes the whole frame. O(w*h). Use
|
||||
// for the first frame or after the detector has been resized. Steady-state
|
||||
// updates should go through updateDirty instead.
|
||||
func (d *copyRectDetector) rebuild(frame *image.RGBA, w, h int) {
|
||||
d.resize(w, h)
|
||||
if d.prevTiles == nil {
|
||||
d.prevTiles = make(map[uint64][2]int)
|
||||
} else {
|
||||
clear(d.prevTiles)
|
||||
}
|
||||
ts := d.tileSize
|
||||
for ty := 0; ty+ts <= h; ty += ts {
|
||||
for tx := 0; tx+ts <= w; tx += ts {
|
||||
sum := d.hashTile(frame, tx, ty)
|
||||
d.tileHash[(ty/ts)*d.cols+(tx/ts)] = sum
|
||||
if _, exists := d.prevTiles[sum]; !exists {
|
||||
d.prevTiles[sum] = [2]int{tx, ty}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// updateDirty rehashes only the tiles named in dirty (each entry is
|
||||
// [x, y, w, h] with w and h equal to tileSize). O(len(dirty)) work, which
|
||||
// in the common case is a tiny fraction of the whole framebuffer.
|
||||
//
|
||||
// The prevTiles map is replaced on collision rather than first-wins so a
|
||||
// newly-hashed tile claims the slot. Old, stale entries pointing at tiles
|
||||
// that no longer carry that hash are filtered at lookup time via tileHash.
|
||||
func (d *copyRectDetector) updateDirty(frame *image.RGBA, w, h int, dirty [][4]int) {
|
||||
if d.w != w || d.h != h || d.tileHash == nil {
|
||||
d.rebuild(frame, w, h)
|
||||
return
|
||||
}
|
||||
ts := d.tileSize
|
||||
for _, r := range dirty {
|
||||
if r[2] != ts || r[3] != ts {
|
||||
continue
|
||||
}
|
||||
tx, ty := r[0], r[1]
|
||||
if tx+ts > w || ty+ts > h {
|
||||
continue
|
||||
}
|
||||
sum := d.hashTile(frame, tx, ty)
|
||||
d.tileHash[(ty/ts)*d.cols+(tx/ts)] = sum
|
||||
// Latest-wins on collision: ensures the most recent owner of this
|
||||
// hash is the one we'll return on lookup. The previous owner's
|
||||
// entry, if any, gets shadowed; if its content has changed it's
|
||||
// stale anyway and findTileMatch's verification will skip it.
|
||||
d.prevTiles[sum] = [2]int{tx, ty}
|
||||
}
|
||||
}
|
||||
|
||||
// findTileMatch hashes the current-frame tile at (dstX, dstY) and looks up
|
||||
// its hash in the previous-frame map. Returns (srcX, srcY, true) when a
|
||||
// matching tile-aligned tile exists at a different position whose stored
|
||||
// hash still equals the requested hash (so the result is not stale).
|
||||
func (d *copyRectDetector) findTileMatch(cur *image.RGBA, dstX, dstY int) (int, int, bool) {
|
||||
if len(d.prevTiles) == 0 || d.tileHash == nil {
|
||||
return 0, 0, false
|
||||
}
|
||||
ts := d.tileSize
|
||||
if dstX+ts > cur.Rect.Dx() || dstY+ts > cur.Rect.Dy() {
|
||||
return 0, 0, false
|
||||
}
|
||||
sum := d.hashTile(cur, dstX, dstY)
|
||||
pos, ok := d.prevTiles[sum]
|
||||
if !ok {
|
||||
return 0, 0, false
|
||||
}
|
||||
if pos[0] == dstX && pos[1] == dstY {
|
||||
return 0, 0, false
|
||||
}
|
||||
// Reject source coords that fall outside the current framebuffer
|
||||
// (frame may have shrunk since the source position was recorded). A
|
||||
// CopyRect with an out-of-range source would have the client copy
|
||||
// from undefined pixels, so drop the match and let the encoder send
|
||||
// the rect normally.
|
||||
if pos[0] < 0 || pos[1] < 0 || pos[0]+ts > cur.Rect.Dx() || pos[1]+ts > cur.Rect.Dy() {
|
||||
return 0, 0, false
|
||||
}
|
||||
// Reject stale entries: the position the map points at must still
|
||||
// carry the same hash according to our per-tile array.
|
||||
hashIdx := (pos[1]/ts)*d.cols + pos[0]/ts
|
||||
if hashIdx < 0 || hashIdx >= len(d.tileHash) {
|
||||
return 0, 0, false
|
||||
}
|
||||
if d.tileHash[hashIdx] != sum {
|
||||
return 0, 0, false
|
||||
}
|
||||
return pos[0], pos[1], true
|
||||
}
|
||||
|
||||
// extractCopyRectTiles examines the diff-produced (per-tile) dirty list and
|
||||
// pulls out any tiles whose current-frame content matches a prev-frame tile
|
||||
// at a different position. Returns the CopyRect candidates and the residual
|
||||
// dirty tiles that still need pixel encoding.
|
||||
type copyRectMove struct {
|
||||
srcX, srcY int
|
||||
dstX, dstY int
|
||||
}
|
||||
|
||||
func (d *copyRectDetector) extractCopyRectTiles(cur *image.RGBA, dirtyTiles [][4]int) (moves []copyRectMove, remaining [][4]int) {
|
||||
ts := d.tileSize
|
||||
remaining = dirtyTiles[:0:cap(dirtyTiles)]
|
||||
for _, r := range dirtyTiles {
|
||||
if r[2] == ts && r[3] == ts {
|
||||
if sx, sy, ok := d.findTileMatch(cur, r[0], r[1]); ok {
|
||||
moves = append(moves, copyRectMove{
|
||||
srcX: sx, srcY: sy, dstX: r[0], dstY: r[1],
|
||||
})
|
||||
continue
|
||||
}
|
||||
}
|
||||
remaining = append(remaining, r)
|
||||
}
|
||||
return moves, remaining
|
||||
}
|
||||
@@ -1,162 +0,0 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"image"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// fillTile paints a tileSize×tileSize block of img at (x,y) with the colour
|
||||
// derived from (r,g,b) so the test can construct distinct-content tiles.
|
||||
func fillTile(img *image.RGBA, x, y, ts int, r, g, b byte) {
|
||||
for row := 0; row < ts; row++ {
|
||||
off := (y+row)*img.Stride + x*4
|
||||
for col := 0; col < ts; col++ {
|
||||
img.Pix[off+col*4+0] = r
|
||||
img.Pix[off+col*4+1] = g
|
||||
img.Pix[off+col*4+2] = b
|
||||
img.Pix[off+col*4+3] = 0xff
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// copyTile copies a tileSize×tileSize block from src(sx,sy) to dst(dx,dy).
|
||||
func copyTile(dst, src *image.RGBA, sx, sy, dx, dy, ts int) {
|
||||
for row := 0; row < ts; row++ {
|
||||
srcOff := (sy+row)*src.Stride + sx*4
|
||||
dstOff := (dy+row)*dst.Stride + dx*4
|
||||
copy(dst.Pix[dstOff:dstOff+ts*4], src.Pix[srcOff:srcOff+ts*4])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyRectDetector_DetectsVerticalScroll(t *testing.T) {
|
||||
const w, h = 256, 192 // 4×3 tiles at 64px
|
||||
const ts = 64
|
||||
|
||||
prev := image.NewRGBA(image.Rect(0, 0, w, h))
|
||||
cur := image.NewRGBA(image.Rect(0, 0, w, h))
|
||||
|
||||
// prev: 12 tiles each with a unique colour.
|
||||
for ty := 0; ty < 3; ty++ {
|
||||
for tx := 0; tx < 4; tx++ {
|
||||
fillTile(prev, tx*ts, ty*ts, ts, byte(tx*40), byte(ty*60), 0x80)
|
||||
}
|
||||
}
|
||||
// cur: simulate a single-tile-row scroll upward, every tile copied from
|
||||
// the row below in prev, top row is new content.
|
||||
for ty := 0; ty < 2; ty++ {
|
||||
for tx := 0; tx < 4; tx++ {
|
||||
copyTile(cur, prev, tx*ts, (ty+1)*ts, tx*ts, ty*ts, ts)
|
||||
}
|
||||
}
|
||||
// Bottom row of cur: new colour, not a match.
|
||||
for tx := 0; tx < 4; tx++ {
|
||||
fillTile(cur, tx*ts, 2*ts, ts, 0xff, 0xff, 0xff)
|
||||
}
|
||||
|
||||
d := newCopyRectDetector(ts)
|
||||
d.rebuild(prev, w, h)
|
||||
|
||||
tiles := diffTiles(prev, cur, w, h, ts)
|
||||
moves, remaining := d.extractCopyRectTiles(cur, tiles)
|
||||
|
||||
// Expect 8 CopyRect moves (top two rows) and 4 residual tiles (bottom row).
|
||||
if len(moves) != 8 {
|
||||
t.Fatalf("moves: want 8, got %d", len(moves))
|
||||
}
|
||||
if len(remaining) != 4 {
|
||||
t.Fatalf("remaining: want 4, got %d", len(remaining))
|
||||
}
|
||||
// Spot-check one move: cur (0, 0) should map to prev (0, 64).
|
||||
var found bool
|
||||
for _, m := range moves {
|
||||
if m.dstX == 0 && m.dstY == 0 {
|
||||
if m.srcX != 0 || m.srcY != ts {
|
||||
t.Fatalf("move at (0,0): src=(%d,%d), want (0,%d)", m.srcX, m.srcY, ts)
|
||||
}
|
||||
found = true
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatalf("no move for dst (0,0)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyRectDetector_RejectsSelfMatch(t *testing.T) {
|
||||
const w, h = 128, 128
|
||||
const ts = 64
|
||||
|
||||
prev := image.NewRGBA(image.Rect(0, 0, w, h))
|
||||
cur := image.NewRGBA(image.Rect(0, 0, w, h))
|
||||
|
||||
// prev: 4 tiles, all unique
|
||||
fillTile(prev, 0, 0, ts, 0x10, 0x20, 0x30)
|
||||
fillTile(prev, ts, 0, ts, 0x40, 0x50, 0x60)
|
||||
fillTile(prev, 0, ts, ts, 0x70, 0x80, 0x90)
|
||||
fillTile(prev, ts, ts, ts, 0xa0, 0xb0, 0xc0)
|
||||
|
||||
// cur: tile (0,0) unchanged, others changed but content same as prev's (0,0).
|
||||
fillTile(cur, 0, 0, ts, 0x10, 0x20, 0x30) // self-match
|
||||
fillTile(cur, ts, 0, ts, 0xff, 0xff, 0xff)
|
||||
fillTile(cur, 0, ts, ts, 0xff, 0xff, 0xff)
|
||||
fillTile(cur, ts, ts, ts, 0xff, 0xff, 0xff)
|
||||
|
||||
d := newCopyRectDetector(ts)
|
||||
d.rebuild(prev, w, h)
|
||||
|
||||
// Tile (0,0) is not in the dirty list (it's unchanged) so it should not
|
||||
// produce a move even though its hash matches prev (0,0).
|
||||
tiles := diffTiles(prev, cur, w, h, ts)
|
||||
moves, _ := d.extractCopyRectTiles(cur, tiles)
|
||||
for _, m := range moves {
|
||||
if m.dstX == 0 && m.dstY == 0 {
|
||||
t.Fatalf("unexpected move at (0,0)")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyRectDetector_PassThroughWhenNoMatch(t *testing.T) {
|
||||
const w, h = 64, 64
|
||||
const ts = 64
|
||||
|
||||
prev := image.NewRGBA(image.Rect(0, 0, w, h))
|
||||
cur := image.NewRGBA(image.Rect(0, 0, w, h))
|
||||
fillTile(prev, 0, 0, ts, 0x11, 0x22, 0x33)
|
||||
fillTile(cur, 0, 0, ts, 0xaa, 0xbb, 0xcc) // wholly different
|
||||
|
||||
d := newCopyRectDetector(ts)
|
||||
d.rebuild(prev, w, h)
|
||||
tiles := diffTiles(prev, cur, w, h, ts)
|
||||
moves, remaining := d.extractCopyRectTiles(cur, tiles)
|
||||
|
||||
if len(moves) != 0 {
|
||||
t.Fatalf("expected 0 moves, got %d", len(moves))
|
||||
}
|
||||
if len(remaining) != 1 {
|
||||
t.Fatalf("expected 1 residual tile, got %d", len(remaining))
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeCopyRectBody_Layout(t *testing.T) {
|
||||
got := encodeCopyRectBody(100, 200, 300, 400, 64, 48)
|
||||
if len(got) != 16 {
|
||||
t.Fatalf("CopyRect body length: want 16, got %d", len(got))
|
||||
}
|
||||
// Dest position
|
||||
if got[0] != 0x01 || got[1] != 0x2c || got[2] != 0x01 || got[3] != 0x90 {
|
||||
t.Fatalf("bad dest bytes: % x", got[0:4])
|
||||
}
|
||||
// Width, height
|
||||
if got[4] != 0 || got[5] != 64 || got[6] != 0 || got[7] != 48 {
|
||||
t.Fatalf("bad size bytes: % x", got[4:8])
|
||||
}
|
||||
// Encoding = 1
|
||||
if got[11] != 0x01 {
|
||||
t.Fatalf("bad encoding byte: 0x%02x", got[11])
|
||||
}
|
||||
// Source position
|
||||
if got[12] != 0 || got[13] != 100 || got[14] != 0 || got[15] != 200 {
|
||||
t.Fatalf("bad src bytes: % x", got[12:16])
|
||||
}
|
||||
}
|
||||
@@ -1,194 +0,0 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"hash/maphash"
|
||||
"image"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
darwinCursorOnce sync.Once
|
||||
cgsCreateCursor func() uintptr
|
||||
darwinCursorErr error
|
||||
)
|
||||
|
||||
// initDarwinCursor binds a private symbol that returns the current
|
||||
// system cursor image. The classic CGSCreateCurrentCursorImage moved
|
||||
// from CoreGraphics to SkyLight around macOS 13 and is gone entirely
|
||||
// in Sequoia; we probe both frameworks for any of the historical
|
||||
// names so this keeps working on whichever release the binding still
|
||||
// exists. Without a hit the remote-cursor compositing path becomes a
|
||||
// no-op and we log the candidates we tried.
|
||||
func initDarwinCursor() {
|
||||
darwinCursorOnce.Do(func() {
|
||||
libs := []string{
|
||||
"/System/Library/PrivateFrameworks/SkyLight.framework/SkyLight",
|
||||
"/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics",
|
||||
}
|
||||
names := []string{
|
||||
"CGSCreateCurrentCursorImage",
|
||||
"CGSCopyCurrentCursorImage",
|
||||
"CGSCurrentCursorImage",
|
||||
"CGSHardwareCursorActiveImage",
|
||||
}
|
||||
var tried []string
|
||||
for _, path := range libs {
|
||||
h, err := purego.Dlopen(path, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||
if err != nil {
|
||||
tried = append(tried, fmt.Sprintf("dlopen %s: %v", path, err))
|
||||
continue
|
||||
}
|
||||
for _, name := range names {
|
||||
sym, err := purego.Dlsym(h, name)
|
||||
if err != nil {
|
||||
tried = append(tried, fmt.Sprintf("%s!%s missing", path, name))
|
||||
continue
|
||||
}
|
||||
purego.RegisterFunc(&cgsCreateCursor, sym)
|
||||
log.Infof("macOS cursor: bound %s from %s", name, path)
|
||||
return
|
||||
}
|
||||
}
|
||||
darwinCursorErr = fmt.Errorf("no cursor image symbol available; tried: %v", tried)
|
||||
})
|
||||
}
|
||||
|
||||
// cgCursor holds the cached macOS cursor sprite and bumps a serial when
|
||||
// the bytes change. Hotspot is left at (0, 0): the public Cocoa hot-spot
|
||||
// query lives on NSCursor which is process-local and not reachable from
|
||||
// our purego-based bindings; the visual cost is a small misalignment for
|
||||
// non-arrow cursors (I-beam, crosshair, etc.).
|
||||
type cgCursor struct {
|
||||
mu sync.Mutex
|
||||
hashSeed maphash.Seed
|
||||
lastSum uint64
|
||||
cached *image.RGBA
|
||||
serial uint64
|
||||
}
|
||||
|
||||
func newCGCursor() *cgCursor {
|
||||
initDarwinCursor()
|
||||
return &cgCursor{hashSeed: maphash.MakeSeed()}
|
||||
}
|
||||
|
||||
// Cursor returns the current cursor sprite as RGBA. Errors that come from
|
||||
// missing private symbols are sticky; transient empty-image responses are
|
||||
// reported as such so the encoder skips this cycle.
|
||||
func (c *cgCursor) Cursor() (*image.RGBA, int, int, uint64, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if darwinCursorErr != nil {
|
||||
return nil, 0, 0, 0, darwinCursorErr
|
||||
}
|
||||
if cgsCreateCursor == nil {
|
||||
return nil, 0, 0, 0, fmt.Errorf("CGSCreateCurrentCursorImage unavailable")
|
||||
}
|
||||
cgImage := cgsCreateCursor()
|
||||
if cgImage == 0 {
|
||||
return nil, 0, 0, 0, fmt.Errorf("no cursor image available")
|
||||
}
|
||||
defer cgImageRelease(cgImage)
|
||||
|
||||
w := int(cgImageGetWidth(cgImage))
|
||||
h := int(cgImageGetHeight(cgImage))
|
||||
if w <= 0 || h <= 0 {
|
||||
return nil, 0, 0, 0, fmt.Errorf("cursor has zero extent")
|
||||
}
|
||||
bytesPerRow := int(cgImageGetBytesPerRow(cgImage))
|
||||
bpp := int(cgImageGetBitsPerPixel(cgImage))
|
||||
if bpp != 32 {
|
||||
return nil, 0, 0, 0, fmt.Errorf("unsupported cursor bpp: %d", bpp)
|
||||
}
|
||||
provider := cgImageGetDataProvider(cgImage)
|
||||
if provider == 0 {
|
||||
return nil, 0, 0, 0, fmt.Errorf("cursor data provider missing")
|
||||
}
|
||||
cfData := cgDataProviderCopyData(provider)
|
||||
if cfData == 0 {
|
||||
return nil, 0, 0, 0, fmt.Errorf("cursor data copy failed")
|
||||
}
|
||||
defer cfRelease(cfData)
|
||||
dataLen := int(cfDataGetLength(cfData))
|
||||
dataPtr := cfDataGetBytePtr(cfData)
|
||||
if dataPtr == 0 || dataLen == 0 {
|
||||
return nil, 0, 0, 0, fmt.Errorf("cursor data empty")
|
||||
}
|
||||
src := unsafe.Slice((*byte)(unsafe.Pointer(dataPtr)), dataLen)
|
||||
|
||||
sum := maphash.Bytes(c.hashSeed, src)
|
||||
if c.cached != nil && sum == c.lastSum {
|
||||
return c.cached, 0, 0, c.serial, nil
|
||||
}
|
||||
|
||||
img := image.NewRGBA(image.Rect(0, 0, w, h))
|
||||
for y := 0; y < h; y++ {
|
||||
srcOff := y * bytesPerRow
|
||||
dstOff := y * w * 4
|
||||
for x := 0; x < w; x++ {
|
||||
si := srcOff + x*4
|
||||
di := dstOff + x*4
|
||||
img.Pix[di+0] = src[si+2]
|
||||
img.Pix[di+1] = src[si+1]
|
||||
img.Pix[di+2] = src[si+0]
|
||||
img.Pix[di+3] = src[si+3]
|
||||
}
|
||||
}
|
||||
|
||||
c.lastSum = sum
|
||||
c.cached = img
|
||||
c.serial++
|
||||
return img, 0, 0, c.serial, nil
|
||||
}
|
||||
|
||||
// Cursor on CGCapturer satisfies cursorSource. The cgCursor wrapper is
|
||||
// allocated lazily so a build that never asks for the cursor pays no cost.
|
||||
func (c *CGCapturer) Cursor() (*image.RGBA, int, int, uint64, error) {
|
||||
c.cursorOnce.Do(func() {
|
||||
c.cursor = newCGCursor()
|
||||
})
|
||||
return c.cursor.Cursor()
|
||||
}
|
||||
|
||||
// CursorPos returns the current global mouse location via CGEventCreate /
|
||||
// CGEventGetLocation. Coordinates are screen pixels in the main display.
|
||||
func (c *CGCapturer) CursorPos() (int, int, error) {
|
||||
if cgEventCreate == nil || cgEventGetLocation == nil {
|
||||
return 0, 0, fmt.Errorf("CGEvent location APIs unavailable")
|
||||
}
|
||||
ev := cgEventCreate(0)
|
||||
if ev == 0 {
|
||||
return 0, 0, fmt.Errorf("CGEventCreate returned nil")
|
||||
}
|
||||
defer cfRelease(ev)
|
||||
pt := cgEventGetLocation(ev)
|
||||
return int(pt.X), int(pt.Y), nil
|
||||
}
|
||||
|
||||
// Cursor on MacPoller forwards to the lazy CGCapturer. ensureCapturerLocked
|
||||
// returns an error when Screen Recording permission has not been granted;
|
||||
// in that case there is no usable cursor source either.
|
||||
func (p *MacPoller) Cursor() (*image.RGBA, int, int, uint64, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if err := p.ensureCapturerLocked(); err != nil {
|
||||
return nil, 0, 0, 0, err
|
||||
}
|
||||
return p.capturer.Cursor()
|
||||
}
|
||||
|
||||
// CursorPos forwards to the lazy CGCapturer.
|
||||
func (p *MacPoller) CursorPos() (int, int, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if err := p.ensureCapturerLocked(); err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
return p.capturer.CursorPos()
|
||||
}
|
||||
@@ -1,410 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var (
|
||||
procGetCursorInfo = user32.NewProc("GetCursorInfo")
|
||||
procGetIconInfo = user32.NewProc("GetIconInfo")
|
||||
procGetObjectW = gdi32.NewProc("GetObjectW")
|
||||
procGetDIBits = gdi32.NewProc("GetDIBits")
|
||||
)
|
||||
|
||||
const (
|
||||
cursorShowing = 0x00000001
|
||||
diRgbColors = 0
|
||||
biRgb = 0
|
||||
dibSectionBytes = 40 // sizeof(BITMAPINFOHEADER)
|
||||
)
|
||||
|
||||
// hiddenHandle is a sentinel stored in cursorSampler.lastHandle while
|
||||
// Windows reports the cursor as hidden. It is not a valid HCURSOR value;
|
||||
// real handles never collide with this constant.
|
||||
const hiddenHandle = windows.Handle(^uintptr(0))
|
||||
|
||||
// transparentCursorImage returns a 1x1 fully transparent sprite. The
|
||||
// client renders this as "no cursor"; emitting it explicitly lets us
|
||||
// recover when an app un-hides the cursor a moment later.
|
||||
func transparentCursorImage() *image.RGBA {
|
||||
return image.NewRGBA(image.Rect(0, 0, 1, 1))
|
||||
}
|
||||
|
||||
type winPoint struct {
|
||||
X, Y int32
|
||||
}
|
||||
|
||||
type winCursorInfo struct {
|
||||
Size uint32
|
||||
Flags uint32
|
||||
Cursor windows.Handle
|
||||
PtPos winPoint
|
||||
}
|
||||
|
||||
type winIconInfo struct {
|
||||
FIcon int32
|
||||
XHotspot uint32
|
||||
YHotspot uint32
|
||||
HbmMask windows.Handle
|
||||
HbmColor windows.Handle
|
||||
}
|
||||
|
||||
type winBitmap struct {
|
||||
BmType int32
|
||||
BmWidth int32
|
||||
BmHeight int32
|
||||
BmWidthBytes int32
|
||||
BmPlanes uint16
|
||||
BmBitsPixel uint16
|
||||
BmBits uintptr
|
||||
}
|
||||
|
||||
type winBitmapInfoHeader struct {
|
||||
BiSize uint32
|
||||
BiWidth int32
|
||||
BiHeight int32
|
||||
BiPlanes uint16
|
||||
BiBitCount uint16
|
||||
BiCompression uint32
|
||||
BiSizeImage uint32
|
||||
BiXPelsPerMeter int32
|
||||
BiYPelsPerMeter int32
|
||||
BiClrUsed uint32
|
||||
BiClrImportant uint32
|
||||
}
|
||||
|
||||
// cursorSnapshot is the captured cursor state shared between the worker
|
||||
// (which polls the OS) and the session encoder (which reads it).
|
||||
type cursorSnapshot struct {
|
||||
img *image.RGBA
|
||||
hotX int
|
||||
hotY int
|
||||
posX int
|
||||
posY int
|
||||
hasPos bool
|
||||
serial uint64
|
||||
err error
|
||||
}
|
||||
|
||||
// cursorSampler captures the foreground process's cursor sprite via Win32
|
||||
// APIs. It must be called from a goroutine attached to the same window
|
||||
// station and desktop as the user session (the capture worker does this
|
||||
// via switchToInputDesktop). lastHandle dedupes per-shape work so we only
|
||||
// touch GDI when Windows hands us a new cursor.
|
||||
type cursorSampler struct {
|
||||
lastHandle windows.Handle
|
||||
serial uint64
|
||||
snapshot *cursorSnapshot
|
||||
}
|
||||
|
||||
// sample queries the current cursor and decodes a new sprite when Windows
|
||||
// reports a different HCURSOR than last time. Returns the current snapshot
|
||||
// regardless of whether anything changed; callers diff by serial.
|
||||
func (s *cursorSampler) sample() (*cursorSnapshot, error) {
|
||||
var ci winCursorInfo
|
||||
ci.Size = uint32(unsafe.Sizeof(ci))
|
||||
r, _, err := procGetCursorInfo.Call(uintptr(unsafe.Pointer(&ci)))
|
||||
if r == 0 {
|
||||
return nil, fmt.Errorf("GetCursorInfo: %w", err)
|
||||
}
|
||||
if ci.Flags&cursorShowing == 0 || ci.Cursor == 0 {
|
||||
// Cursor temporarily hidden by an app (text fields toggle it on
|
||||
// focus). Emit a 1x1 transparent sprite so the client renders no
|
||||
// cursor and stay armed for the next handle change rather than
|
||||
// treating this as a hard failure that would latch us off for
|
||||
// the session.
|
||||
if s.lastHandle == hiddenHandle {
|
||||
s.snapshot.posX = int(ci.PtPos.X)
|
||||
s.snapshot.posY = int(ci.PtPos.Y)
|
||||
s.snapshot.hasPos = true
|
||||
return s.snapshot, nil
|
||||
}
|
||||
s.lastHandle = hiddenHandle
|
||||
s.serial++
|
||||
s.snapshot = &cursorSnapshot{
|
||||
img: transparentCursorImage(),
|
||||
posX: int(ci.PtPos.X),
|
||||
posY: int(ci.PtPos.Y),
|
||||
hasPos: true,
|
||||
serial: s.serial,
|
||||
}
|
||||
return s.snapshot, nil
|
||||
}
|
||||
if ci.Cursor == s.lastHandle && s.snapshot != nil {
|
||||
s.snapshot.posX = int(ci.PtPos.X)
|
||||
s.snapshot.posY = int(ci.PtPos.Y)
|
||||
s.snapshot.hasPos = true
|
||||
return s.snapshot, nil
|
||||
}
|
||||
img, hotX, hotY, err := decodeCursor(ci.Cursor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.lastHandle = ci.Cursor
|
||||
s.serial++
|
||||
s.snapshot = &cursorSnapshot{
|
||||
img: img,
|
||||
hotX: hotX,
|
||||
hotY: hotY,
|
||||
posX: int(ci.PtPos.X),
|
||||
posY: int(ci.PtPos.Y),
|
||||
hasPos: true,
|
||||
serial: s.serial,
|
||||
}
|
||||
return s.snapshot, nil
|
||||
}
|
||||
|
||||
// decodeCursor extracts the sprite at hCur as RGBA along with the hotspot.
|
||||
// Color cursors are read from the colour bitmap with the AND mask combined
|
||||
// in for alpha. Monochrome cursors collapse the two halves of the mask
|
||||
// bitmap into a single visible sprite where the AND bit drives alpha.
|
||||
func decodeCursor(hCur windows.Handle) (*image.RGBA, int, int, error) {
|
||||
var info winIconInfo
|
||||
r, _, err := procGetIconInfo.Call(uintptr(hCur), uintptr(unsafe.Pointer(&info)))
|
||||
if r == 0 {
|
||||
return nil, 0, 0, fmt.Errorf("GetIconInfo: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if info.HbmMask != 0 {
|
||||
_, _, _ = procDeleteObject.Call(uintptr(info.HbmMask))
|
||||
}
|
||||
if info.HbmColor != 0 {
|
||||
_, _, _ = procDeleteObject.Call(uintptr(info.HbmColor))
|
||||
}
|
||||
}()
|
||||
hotX, hotY := int(info.XHotspot), int(info.YHotspot)
|
||||
if info.HbmColor != 0 {
|
||||
img, err := decodeColorCursor(info.HbmColor, info.HbmMask)
|
||||
if err != nil {
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
return img, hotX, hotY, nil
|
||||
}
|
||||
img, err := decodeMonoCursor(info.HbmMask)
|
||||
if err != nil {
|
||||
return nil, 0, 0, err
|
||||
}
|
||||
return img, hotX, hotY, nil
|
||||
}
|
||||
|
||||
// readBitmap returns the BITMAP descriptor for hbm.
|
||||
func readBitmap(hbm windows.Handle) (winBitmap, error) {
|
||||
var bm winBitmap
|
||||
r, _, err := procGetObjectW.Call(uintptr(hbm), unsafe.Sizeof(bm), uintptr(unsafe.Pointer(&bm)))
|
||||
if r == 0 {
|
||||
return winBitmap{}, fmt.Errorf("GetObject: %w", err)
|
||||
}
|
||||
return bm, nil
|
||||
}
|
||||
|
||||
// dibCopy reads hbm as 32bpp top-down BGRA into a freshly allocated slice
|
||||
// matching w*h*4 bytes. The bitmap may be selected into the screen DC so
|
||||
// we use a memory DC to keep the call cheap.
|
||||
func dibCopy(hbm windows.Handle, w, h int32) ([]byte, error) {
|
||||
hdcScreen, _, _ := procGetDC.Call(0)
|
||||
if hdcScreen == 0 {
|
||||
return nil, fmt.Errorf("GetDC: failed")
|
||||
}
|
||||
defer func() { _, _, _ = procReleaseDC.Call(0, hdcScreen) }()
|
||||
hdcMem, _, _ := procCreateCompatDC.Call(hdcScreen)
|
||||
if hdcMem == 0 {
|
||||
return nil, fmt.Errorf("CreateCompatibleDC: failed")
|
||||
}
|
||||
defer func() { _, _, _ = procDeleteDC.Call(hdcMem) }()
|
||||
|
||||
var bih winBitmapInfoHeader
|
||||
bih.BiSize = dibSectionBytes
|
||||
bih.BiWidth = w
|
||||
bih.BiHeight = -h // top-down
|
||||
bih.BiPlanes = 1
|
||||
bih.BiBitCount = 32
|
||||
bih.BiCompression = biRgb
|
||||
|
||||
if w <= 0 || h <= 0 || w > maxCursorDim || h > maxCursorDim {
|
||||
return nil, fmt.Errorf("dibCopy: cursor dims %dx%d out of range (max %d)", w, h, maxCursorDim)
|
||||
}
|
||||
buf := make([]byte, int(w)*int(h)*4)
|
||||
r, _, err := procGetDIBits.Call(
|
||||
hdcMem,
|
||||
uintptr(hbm),
|
||||
0,
|
||||
uintptr(h),
|
||||
uintptr(unsafe.Pointer(&buf[0])),
|
||||
uintptr(unsafe.Pointer(&bih)),
|
||||
diRgbColors,
|
||||
)
|
||||
if r == 0 {
|
||||
return nil, fmt.Errorf("GetDIBits: %w", err)
|
||||
}
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
// decodeColorCursor reads a 32bpp colour cursor and folds the AND mask into
|
||||
// the alpha channel when the colour bitmap leaves it zero.
|
||||
func decodeColorCursor(hbmColor, hbmMask windows.Handle) (*image.RGBA, error) {
|
||||
bm, err := readBitmap(hbmColor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
w, h := bm.BmWidth, bm.BmHeight
|
||||
color, err := dibCopy(hbmColor, w, h)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var mask []byte
|
||||
if hbmMask != 0 {
|
||||
mask, _ = dibCopy(hbmMask, w, h)
|
||||
}
|
||||
hasAlpha := colorHasAlpha(color)
|
||||
img := image.NewRGBA(image.Rect(0, 0, int(w), int(h)))
|
||||
for y := int32(0); y < h; y++ {
|
||||
for x := int32(0); x < w; x++ {
|
||||
si := (y*w + x) * 4
|
||||
b := color[si]
|
||||
g := color[si+1]
|
||||
r := color[si+2]
|
||||
a := pixelAlpha(color[si+3], si, mask, hasAlpha)
|
||||
// Premultiply so the shared compositor can use the same
|
||||
// formula on every platform (X11 XFixes and macOS CG return
|
||||
// premultiplied bytes natively).
|
||||
if a != 255 && a != 0 {
|
||||
r = byte(uint32(r) * uint32(a) / 255)
|
||||
g = byte(uint32(g) * uint32(a) / 255)
|
||||
b = byte(uint32(b) * uint32(a) / 255)
|
||||
} else if a == 0 {
|
||||
r, g, b = 0, 0, 0
|
||||
}
|
||||
img.Pix[si+0] = r
|
||||
img.Pix[si+1] = g
|
||||
img.Pix[si+2] = b
|
||||
img.Pix[si+3] = a
|
||||
}
|
||||
}
|
||||
return img, nil
|
||||
}
|
||||
|
||||
// colorHasAlpha reports whether any pixel of a 32bpp BGRA buffer has a
|
||||
// non-zero alpha. Cursors authored without alpha leave the channel at 0
|
||||
// and rely on hbmMask for transparency.
|
||||
func colorHasAlpha(color []byte) bool {
|
||||
for i := 0; i < len(color); i += 4 {
|
||||
if color[i+3] != 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// pixelAlpha returns the effective alpha for a colour-cursor pixel. When
|
||||
// the source bitmap already has alpha we trust it; otherwise the AND mask
|
||||
// decides (1 = transparent, 0 = opaque). The 32bpp DIB stores each AND
|
||||
// bit as a 4-byte entry; the first byte carries the effective value.
|
||||
func pixelAlpha(colorA byte, si int32, mask []byte, hasAlpha bool) byte {
|
||||
if hasAlpha {
|
||||
return colorA
|
||||
}
|
||||
if mask != nil && mask[si] != 0 {
|
||||
return 0
|
||||
}
|
||||
return 255
|
||||
}
|
||||
|
||||
// decodeMonoCursor handles legacy 1bpp cursors where hbmMask is twice as
|
||||
// tall as the visible sprite: rows [0..h) are the AND mask and rows [h..2h)
|
||||
// are the XOR mask. We render the visible half into RGBA, treating
|
||||
// AND-mask=1 as transparent and the XOR bit as a black/white pixel.
|
||||
func decodeMonoCursor(hbmMask windows.Handle) (*image.RGBA, error) {
|
||||
bm, err := readBitmap(hbmMask)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
w, fullH := bm.BmWidth, bm.BmHeight
|
||||
if fullH%2 != 0 {
|
||||
return nil, fmt.Errorf("unexpected mono cursor shape: %dx%d", w, fullH)
|
||||
}
|
||||
h := fullH / 2
|
||||
data, err := dibCopy(hbmMask, w, fullH)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
img := image.NewRGBA(image.Rect(0, 0, int(w), int(h)))
|
||||
for y := int32(0); y < h; y++ {
|
||||
for x := int32(0); x < w; x++ {
|
||||
and := data[(y*w+x)*4]
|
||||
xor := data[((y+h)*w+x)*4]
|
||||
di := (y*w + x) * 4
|
||||
if and != 0 {
|
||||
img.Pix[di+3] = 0
|
||||
continue
|
||||
}
|
||||
c := byte(0)
|
||||
if xor != 0 {
|
||||
c = 255
|
||||
}
|
||||
img.Pix[di+0] = c
|
||||
img.Pix[di+1] = c
|
||||
img.Pix[di+2] = c
|
||||
img.Pix[di+3] = 255
|
||||
}
|
||||
}
|
||||
return img, nil
|
||||
}
|
||||
|
||||
// cursorState is the latest snapshot shared between the worker and
|
||||
// session readers.
|
||||
type cursorState struct {
|
||||
mu sync.Mutex
|
||||
snapshot *cursorSnapshot
|
||||
}
|
||||
|
||||
func (s *cursorState) store(snap *cursorSnapshot) {
|
||||
s.mu.Lock()
|
||||
s.snapshot = snap
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
func (s *cursorState) load() *cursorSnapshot {
|
||||
s.mu.Lock()
|
||||
snap := s.snapshot
|
||||
s.mu.Unlock()
|
||||
return snap
|
||||
}
|
||||
|
||||
// Cursor satisfies cursorSource by returning the latest snapshot the
|
||||
// capture worker decoded. The "no sample yet" and "cursor hidden" cases
|
||||
// return img=nil with no error so callers skip emission this cycle
|
||||
// without latching the source off for the rest of the session.
|
||||
func (c *DesktopCapturer) Cursor() (*image.RGBA, int, int, uint64, error) {
|
||||
snap := c.cursorState.load()
|
||||
if snap == nil {
|
||||
return nil, 0, 0, 0, nil
|
||||
}
|
||||
if snap.err != nil {
|
||||
return nil, 0, 0, 0, snap.err
|
||||
}
|
||||
return snap.img, snap.hotX, snap.hotY, snap.serial, nil
|
||||
}
|
||||
|
||||
// CursorPos returns the cursor screen position observed by the worker on
|
||||
// its last sample. Errors out if the worker hasn't yet captured a frame
|
||||
// or the most recent sample failed.
|
||||
func (c *DesktopCapturer) CursorPos() (int, int, error) {
|
||||
snap := c.cursorState.load()
|
||||
if snap == nil {
|
||||
return 0, 0, fmt.Errorf("cursor position not sampled yet")
|
||||
}
|
||||
if snap.err != nil {
|
||||
return 0, 0, snap.err
|
||||
}
|
||||
if !snap.hasPos {
|
||||
return 0, 0, fmt.Errorf("cursor position unavailable")
|
||||
}
|
||||
return snap.posX, snap.posY, nil
|
||||
}
|
||||
@@ -1,127 +0,0 @@
|
||||
//go:build unix && !darwin && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
"sync"
|
||||
|
||||
"github.com/jezek/xgb"
|
||||
"github.com/jezek/xgb/xfixes"
|
||||
)
|
||||
|
||||
// xfixesCursor reports the current X cursor sprite via the XFixes extension.
|
||||
// CursorSerial changes whenever the server picks a different cursor, so
|
||||
// callers can cache by serial without comparing pixels.
|
||||
type xfixesCursor struct {
|
||||
mu sync.Mutex
|
||||
conn *xgb.Conn
|
||||
// lastPosX/lastPosY hold the cursor screen position observed on the
|
||||
// most recent successful GetCursorImage. cursorPositionSource readers
|
||||
// share this value so we do not pay a second X round-trip per frame.
|
||||
lastPosX, lastPosY int
|
||||
hasPos bool
|
||||
// lastImg, lastHotX, lastHotY, lastSerial cache the most recent good
|
||||
// GetCursorImage result so transient failures (cursor hidden, server
|
||||
// briefly unresponsive) reuse the previous sprite instead of going
|
||||
// dark. Without this the encoder's compositing path drops to no-op as
|
||||
// soon as the cursor becomes momentarily unavailable.
|
||||
lastImg *image.RGBA
|
||||
lastHotX int
|
||||
lastHotY int
|
||||
lastSerial uint64
|
||||
}
|
||||
|
||||
// newXFixesCursor initialises the XFixes extension on conn. Returns an
|
||||
// error if the extension is unavailable; callers can fall back to no
|
||||
// cursor emission instead of asking on every frame.
|
||||
func newXFixesCursor(conn *xgb.Conn) (*xfixesCursor, error) {
|
||||
if err := xfixes.Init(conn); err != nil {
|
||||
return nil, fmt.Errorf("xfixes init: %w", err)
|
||||
}
|
||||
if _, err := xfixes.QueryVersion(conn, 4, 0).Reply(); err != nil {
|
||||
return nil, fmt.Errorf("xfixes query version: %w", err)
|
||||
}
|
||||
return &xfixesCursor{conn: conn}, nil
|
||||
}
|
||||
|
||||
// Cursor returns the current cursor sprite as RGBA along with its hotspot
|
||||
// and serial. Callers should treat an unchanged serial as "no update". On
|
||||
// a transient GetCursorImage failure the last cached sprite is returned
|
||||
// so compositing keeps painting the cursor instead of disappearing.
|
||||
func (c *xfixesCursor) Cursor() (*image.RGBA, int, int, uint64, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
reply, err := xfixes.GetCursorImage(c.conn).Reply()
|
||||
if err != nil {
|
||||
if c.lastImg != nil {
|
||||
return c.lastImg, c.lastHotX, c.lastHotY, c.lastSerial, nil
|
||||
}
|
||||
return nil, 0, 0, 0, fmt.Errorf("xfixes GetCursorImage: %w", err)
|
||||
}
|
||||
c.lastPosX, c.lastPosY, c.hasPos = int(reply.X), int(reply.Y), true
|
||||
w, h := int(reply.Width), int(reply.Height)
|
||||
if w <= 0 || h <= 0 {
|
||||
if c.lastImg != nil {
|
||||
return c.lastImg, c.lastHotX, c.lastHotY, c.lastSerial, nil
|
||||
}
|
||||
return nil, 0, 0, 0, fmt.Errorf("cursor has zero extent")
|
||||
}
|
||||
if len(reply.CursorImage) < w*h {
|
||||
if c.lastImg != nil {
|
||||
return c.lastImg, c.lastHotX, c.lastHotY, c.lastSerial, nil
|
||||
}
|
||||
return nil, 0, 0, 0, fmt.Errorf("cursor pixel buffer truncated: %d < %d", len(reply.CursorImage), w*h)
|
||||
}
|
||||
img := image.NewRGBA(image.Rect(0, 0, w, h))
|
||||
// XFixes packs each pixel as a uint32 in ARGB order with premultiplied
|
||||
// alpha. Unpack into the standard RGBA byte layout.
|
||||
for i, p := range reply.CursorImage[:w*h] {
|
||||
o := i * 4
|
||||
img.Pix[o+0] = byte(p >> 16)
|
||||
img.Pix[o+1] = byte(p >> 8)
|
||||
img.Pix[o+2] = byte(p)
|
||||
img.Pix[o+3] = byte(p >> 24)
|
||||
}
|
||||
c.lastImg = img
|
||||
c.lastHotX = int(reply.Xhot)
|
||||
c.lastHotY = int(reply.Yhot)
|
||||
c.lastSerial = uint64(reply.CursorSerial)
|
||||
return img, c.lastHotX, c.lastHotY, c.lastSerial, nil
|
||||
}
|
||||
|
||||
// Cursor on X11Capturer satisfies cursorSource. The XFixes binding is
|
||||
// created lazily on the same X connection used for screen capture; the
|
||||
// first init failure is latched so we stop asking on every frame.
|
||||
func (x *X11Capturer) Cursor() (*image.RGBA, int, int, uint64, error) {
|
||||
x.mu.Lock()
|
||||
if x.cursor == nil && x.cursorInitErr == nil {
|
||||
x.cursor, x.cursorInitErr = newXFixesCursor(x.conn)
|
||||
}
|
||||
cur := x.cursor
|
||||
initErr := x.cursorInitErr
|
||||
x.mu.Unlock()
|
||||
if initErr != nil {
|
||||
return nil, 0, 0, 0, initErr
|
||||
}
|
||||
return cur.Cursor()
|
||||
}
|
||||
|
||||
// CursorPos on X11Capturer returns the screen position from the most
|
||||
// recent successful Cursor() call. Sessions call Cursor() once per encode
|
||||
// cycle, so this stays current without a second X round-trip.
|
||||
func (x *X11Capturer) CursorPos() (int, int, error) {
|
||||
x.mu.Lock()
|
||||
cur := x.cursor
|
||||
x.mu.Unlock()
|
||||
if cur == nil {
|
||||
return 0, 0, fmt.Errorf("cursor source not initialised")
|
||||
}
|
||||
cur.mu.Lock()
|
||||
defer cur.mu.Unlock()
|
||||
if !cur.hasPos {
|
||||
return 0, 0, fmt.Errorf("cursor position not sampled yet")
|
||||
}
|
||||
return cur.lastPosX, cur.lastPosY, nil
|
||||
}
|
||||
@@ -1,159 +0,0 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/zlib"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// ExtendedClipboard is an RFB community extension (pseudo-encoding
|
||||
// 0xC0A1E5CE) that replaces legacy CutText with a Caps/Notify/Request/
|
||||
// Provide/Peek handshake. Wins versus legacy CutText:
|
||||
// - UTF-8 text format (legacy is Latin-1).
|
||||
// - Pull-based: a Notify announces "I have new content", the peer fetches
|
||||
// via Request only when it actually needs the data. Saves bandwidth on
|
||||
// high-latency transports versus pushing every change.
|
||||
// - zlib-compressed payloads.
|
||||
// - Caps negotiation so each side knows the other's per-format max size.
|
||||
//
|
||||
// The extension reuses message opcodes 3 (ServerCutText) and 6 (ClientCutText)
|
||||
// and signals "extended" by encoding the length field as a negative int32;
|
||||
// the absolute value is the payload size in bytes. The first 4 bytes of
|
||||
// payload are a flags word: top byte is the action, low 16 bits are the
|
||||
// format mask.
|
||||
const pseudoEncExtendedClipboard = -1063131698 // 0xC0A1E5CE as int32
|
||||
|
||||
const (
|
||||
extClipActionCaps uint32 = 0x01000000
|
||||
extClipActionRequest uint32 = 0x02000000
|
||||
extClipActionPeek uint32 = 0x04000000
|
||||
extClipActionNotify uint32 = 0x08000000
|
||||
extClipActionProvide uint32 = 0x10000000
|
||||
extClipActionMask uint32 = 0x1F000000
|
||||
|
||||
extClipFormatText uint32 = 0x00000001
|
||||
extClipFormatRTF uint32 = 0x00000002
|
||||
extClipFormatHTML uint32 = 0x00000004
|
||||
extClipFormatDIB uint32 = 0x00000008
|
||||
extClipFormatFiles uint32 = 0x00000010
|
||||
extClipFormatMask uint32 = 0x0000FFFF
|
||||
|
||||
// extClipMaxText caps our accepted text payload. Mirrors the legacy
|
||||
// maxCutTextBytes (1 MiB); advertised in Caps and enforced on Provide.
|
||||
extClipMaxText = maxCutTextBytes
|
||||
|
||||
// extClipMaxPayload bounds the raw on-wire payload we will read for an
|
||||
// extended CutText message. Includes flags header, length prefixes, NUL,
|
||||
// and zlib framing overhead on top of the text body.
|
||||
extClipMaxPayload = extClipMaxText + 1024
|
||||
)
|
||||
|
||||
// buildExtClipCaps emits the Caps payload. The flags word advertises every
|
||||
// action we support in the high byte (Caps + Request + Peek + Notify +
|
||||
// Provide) and every format we accept in the low 16 bits. Clients use
|
||||
// these action bits to decide whether to auto-Request on Notify; without
|
||||
// Request in our Caps a conforming client silently drops our Notify
|
||||
// messages. After the flags word we emit one uint32 max size per format
|
||||
// bit set, in ascending bit order.
|
||||
func buildExtClipCaps() []byte {
|
||||
flags := extClipActionCaps | extClipActionRequest | extClipActionPeek |
|
||||
extClipActionNotify | extClipActionProvide | extClipFormatText
|
||||
payload := make([]byte, 4+4)
|
||||
binary.BigEndian.PutUint32(payload[0:4], flags)
|
||||
binary.BigEndian.PutUint32(payload[4:8], uint32(extClipMaxText))
|
||||
return payload
|
||||
}
|
||||
|
||||
// buildExtClipNotify emits a Notify announcing that we have new clipboard
|
||||
// content available in the given format mask. No data is shipped; the peer
|
||||
// pulls via Request when it actually needs to paste.
|
||||
func buildExtClipNotify(formats uint32) []byte {
|
||||
payload := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(payload, extClipActionNotify|formats)
|
||||
return payload
|
||||
}
|
||||
|
||||
// buildExtClipRequest emits a Request asking the peer to send Provide for
|
||||
// the given format mask. Sent in response to an inbound Notify.
|
||||
func buildExtClipRequest(formats uint32) []byte {
|
||||
payload := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(payload, extClipActionRequest|formats)
|
||||
return payload
|
||||
}
|
||||
|
||||
// buildExtClipProvideText emits a Provide carrying UTF-8 text. The inner
|
||||
// stream (4-byte length including the trailing NUL, then UTF-8 bytes, then
|
||||
// NUL) is zlib-compressed; each Provide uses an independent zlib context
|
||||
// per the extension spec. Rejects oversized input so a caller bug can't
|
||||
// produce a payload larger than the size advertised in our Caps.
|
||||
func buildExtClipProvideText(text string) ([]byte, error) {
|
||||
if len(text)+1 > extClipMaxText {
|
||||
return nil, fmt.Errorf("clipboard text exceeds extClipMaxText (%d > %d)", len(text)+1, extClipMaxText)
|
||||
}
|
||||
body := make([]byte, 0, 4+len(text)+1)
|
||||
var lenBuf [4]byte
|
||||
binary.BigEndian.PutUint32(lenBuf[:], uint32(len(text)+1))
|
||||
body = append(body, lenBuf[:]...)
|
||||
body = append(body, text...)
|
||||
body = append(body, 0)
|
||||
|
||||
var compressed bytes.Buffer
|
||||
zw := zlib.NewWriter(&compressed)
|
||||
if _, err := zw.Write(body); err != nil {
|
||||
return nil, fmt.Errorf("zlib write: %w", err)
|
||||
}
|
||||
if err := zw.Close(); err != nil {
|
||||
return nil, fmt.Errorf("zlib close: %w", err)
|
||||
}
|
||||
|
||||
payload := make([]byte, 4+compressed.Len())
|
||||
binary.BigEndian.PutUint32(payload[0:4], extClipActionProvide|extClipFormatText)
|
||||
copy(payload[4:], compressed.Bytes())
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
// parseExtClipProvideText decompresses a Provide payload (the bytes after
|
||||
// the 4-byte flags header) and returns the UTF-8 text record if the text
|
||||
// format bit is set. Records for other formats are skipped. The trailing
|
||||
// NUL byte the spec appends to text records is stripped.
|
||||
func parseExtClipProvideText(flags uint32, payload []byte) (string, error) {
|
||||
zr, err := zlib.NewReader(bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("zlib reader: %w", err)
|
||||
}
|
||||
defer zr.Close()
|
||||
|
||||
limited := io.LimitReader(zr, int64(extClipMaxText)+16)
|
||||
var text string
|
||||
for bit := uint32(1); bit <= extClipFormatFiles; bit <<= 1 {
|
||||
if flags&bit == 0 {
|
||||
continue
|
||||
}
|
||||
var sizeBuf [4]byte
|
||||
if _, err := io.ReadFull(limited, sizeBuf[:]); err != nil {
|
||||
if bit == extClipFormatText && err == io.EOF {
|
||||
return "", nil
|
||||
}
|
||||
return "", fmt.Errorf("read record size: %w", err)
|
||||
}
|
||||
size := binary.BigEndian.Uint32(sizeBuf[:])
|
||||
if size > uint32(extClipMaxText) {
|
||||
return "", fmt.Errorf("record too large: %d", size)
|
||||
}
|
||||
rec := make([]byte, size)
|
||||
if _, err := io.ReadFull(limited, rec); err != nil {
|
||||
return "", fmt.Errorf("read record: %w", err)
|
||||
}
|
||||
if bit == extClipFormatText {
|
||||
if len(rec) > 0 && rec[len(rec)-1] == 0 {
|
||||
rec = rec[:len(rec)-1]
|
||||
}
|
||||
text = string(rec)
|
||||
}
|
||||
}
|
||||
return text, nil
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user