Compare commits

...

25 Commits

Author SHA1 Message Date
Viktor Liu
b754df1171 Add embedded VNC server with JWT auth, DXGI capture, and dashboard integration 2026-04-20 12:26:40 +02:00
Zoltan Papp
3098f48b25 [client] fix ios network addresses mac filter (#5906)
* fix(client): skip MAC address filter for network addresses on iOS

iOS does not expose hardware (MAC) addresses due to Apple's privacy
restrictions (since iOS 14), causing networkAddresses() to return an
empty list because all interfaces are filtered out by the HardwareAddr
check. Move networkAddresses() to platform-specific files so iOS can
skip this filter.
2026-04-20 11:49:38 +02:00
Zoltan Papp
7f023ce801 [client] Android debug bundle support (#5888)
Add Android debug bundle support with Troubleshoot UI
2026-04-20 11:26:30 +02:00
Michael Uray
e361126515 [client] Fix WGIface.Close deadlock when DNS filter hook re-enters GetDevice (#5916)
WGIface.Close() took w.mu and held it across w.tun.Close(). The
underlying wireguard-go device waits for its send/receive goroutines to
drain before Close() returns, and some of those goroutines re-enter
WGIface during shutdown. In particular, the userspace packet filter DNS
hook in client/internal/dns.ServiceViaMemory.filterDNSTraffic calls
s.wgInterface.GetDevice() on every packet, which also needs w.mu. With
the Close-side holding the mutex, the read goroutine blocks in
GetDevice and Close waits forever for that goroutine to exit:

  goroutine N (TestDNSPermanent_updateUpstream):
    WGIface.Close -> holds w.mu -> tun.Close -> sync.WaitGroup.Wait
  goroutine M (wireguard read routine):
    FilteredDevice.Read -> filterOutbound -> udpHooksDrop ->
    filterDNSTraffic.func1 -> WGIface.GetDevice -> sync.Mutex.Lock

This surfaces as a 5 minute test timeout on the macOS Client/Unit
CI job (panic: test timed out after 5m0s, running tests:
TestDNSPermanent_updateUpstream).

Release w.mu before calling w.tun.Close(). The other Close steps
(wgProxyFactory.Free, waitUntilRemoved, Destroy) do not mutate any
fields guarded by w.mu beyond what Free() already does, so the lock
is not needed once the tun has started shutting down. A new unit test
in iface_close_test.go uses a fake WGTunDevice to reproduce the
deadlock deterministically without requiring CAP_NET_ADMIN.
2026-04-20 10:36:19 +02:00
Viktor Liu
95213f7157 [client] Use Match host+exec instead of Host+Match in SSH client config (#5903) 2026-04-20 10:24:11 +02:00
Viktor Liu
2e0e3a3601 [client] Replace exclusion routes with scoped default + IP_BOUND_IF on macOS (#5918) 2026-04-20 10:01:01 +02:00
Nicolas Frati
8ae8f2098f [management] chores: fix lint error on google workspace (#5907)
* chores: fix lint error on google workspace

* chores: updated google api dependency

* update google golang api sdk to latest
2026-04-16 20:02:09 +02:00
Viktor Liu
a39787d679 [infrastructure] Add CrowdSec LAPI container to self-hosted setup script (#5880) 2026-04-16 18:06:38 +02:00
Maycon Santos
53b04e512a [management] Reuse a single cache store across all management server consumers (#5889)
* Add support for legacy IDP cache environment variable

* Centralize cache store creation to reuse a single Redis connection pool

Each cache consumer (IDP cache, token store, PKCE store, secrets manager,
EDR validator) was independently calling NewStore, creating separate Redis
clients with their own connection pools — up to 1400 potential connections
from a single management server process.

Introduce a shared CacheStore() singleton on BaseServer that creates one
store at boot and injects it into all consumers. Consumer constructors now
receive a store.StoreInterface instead of creating their own.

For Redis mode, all consumers share one connection pool (1000 max conns).
For in-memory mode, all consumers share one GoCache instance.

* Update management-integrations module to latest version

* sync go.sum

* Export `GetAddrFromEnv` to allow reuse across packages

* Update management-integrations module version in go.mod and go.sum

* Update management-integrations module version in go.mod and go.sum
2026-04-16 16:04:53 +02:00
Viktor Liu
633dde8d1f [client] Reconnect conntrack netlink listener on error (#5885) 2026-04-16 22:30:36 +09:00
Michael Uray
7e4542adde fix(client): populate NetworkAddresses on iOS for posture checks (#5900)
The iOS GetInfo() function never populated NetworkAddresses, causing
the peer_network_range_check posture check to fail for all iOS clients.

This adds the same networkAddresses() call that macOS, Linux, Windows,
and FreeBSD already use.

Fixes: #3968
Fixes: #4657
2026-04-16 14:25:55 +02:00
Viktor Liu
d4c61ed38b [client] Add mangle FORWARD guard to prevent Docker DNAT bypass of ACL rules (#5697) 2026-04-16 14:02:52 +02:00
Viktor Liu
6b540d145c [client] Add --disable-networks flag to block network selection (#5896) 2026-04-16 14:02:31 +02:00
Bethuel Mmbaga
08f624507d [management] Enforce peer or peer groups requirement for network routers (#5894) 2026-04-16 13:12:19 +03:00
Viktor Liu
95bc01e48f [client] Allow clearing saved service env vars with --service-env "" (#5893) 2026-04-15 19:22:08 +02:00
Viktor Liu
0d86de47df [client] Add PCP support (#5219) 2026-04-15 11:43:16 +02:00
Viktor Liu
e804a705b7 [infrastructure] Update sign pipeline version to v0.1.2 (#5884) 2026-04-14 17:08:35 +02:00
Pascal Fischer
46fc8c9f65 [proxy] direct redirect to SSO (#5874) 2026-04-14 13:47:02 +02:00
Viktor Liu
d7ad908962 [misc] Add CI check for proto version string changes (#5854)
* Add CI check for proto version string changes

* Handle pagination and missing patch data in proto version check
2026-04-14 13:36:26 +02:00
Pascal Fischer
c5623307cc [management] add context cancel monitoring (#5879) 2026-04-14 12:49:18 +02:00
Vlad
7f666b8022 [management] revert ctx dependency in get account with backpressure (#5878) 2026-04-14 12:16:03 +02:00
Viktor Liu
0a30b9b275 [management, proxy] Add CrowdSec IP reputation integration for reverse proxy (#5722) 2026-04-14 12:14:58 +02:00
Viktor Liu
4eed459f27 [client] Fix DNS resolution with userspace WireGuard and kernel firewall (#5873) 2026-04-13 16:23:57 +02:00
Zoltan Papp
13539543af [client] Fix/grpc retry (#5750)
* [client] Fix flow client Receive retry loop not stopping after Close

Use backoff.Permanent for canceled gRPC errors so Receive returns
immediately instead of retrying until context deadline when the
connection is already closed. Add TestNewClient_PermanentClose to
verify the behavior.

The connectivity.Shutdown check was meaningless because when the connection is
shut down, c.realClient.Events(ctx, grpc.WaitForReady(true)) on the nex line
already fails with codes.Canceled — which is now handled as a permanent error.
The explicit state check was just duplicating what gRPC already reports
through its normal error path.

* [client] remove WaitForReady from stream open call

grpc.WaitForReady(true) parks the RPC call internally until the
connection reaches READY, only unblocking on ctx cancellation.
This means the external backoff.Retry loop in Receive() never gets
control back during a connection outage — it cannot tick, log, or
apply its retry intervals while WaitForReady is blocking.

Removing it restores fail-fast behaviour: Events() returns immediately
with codes.Unavailable when the connection is not ready, which is
exactly what the backoff loop expects. The backoff becomes the single
authority over retry timing and cadence, as originally intended.

* [client] Add connection recreation and improve flow client error handling

Store gRPC dial options on the client to enable connection recreation
on Internal errors (RST_STREAM/PROTOCOL_ERROR). Treat Unauthenticated,
PermissionDenied, and Unimplemented as permanent failures. Unify mutex
usage and add reconnection logging for better observability.

* [client] Remove Unauthenticated, PermissionDenied, and Unimplemented from permanent error handling

* [client] Fix error handling in Receive to properly re-establish stream and improve reconnection messaging

* Fix test

* [client] Add graceful shutdown handling and test for concurrent Close during Receive

Prevent reconnection attempts after client closure by tracking a `closed` flag. Use `backoff.Permanent` for errors caused by operations on a closed client. Add a test to ensure `Close` does not block when `Receive` is actively running.

* [client] Fix connection swap to properly close old gRPC connection

Close the old `gRPC.ClientConn` after successfully swapping to a new connection during reconnection.

* [client] Reset backoff

* [client] Ensure stream closure on error during initialization

* [client] Add test for handling server-side stream closure and reconnection

Introduce `TestReceive_ServerClosesStream` to verify the client's ability to recover and process acknowledgments after the server closes the stream. Enhance test server with a controlled stream closure mechanism.

* [client] Add protocol error simulation and enhance reconnection test

Introduce `connTrackListener` to simulate HTTP/2 RST_STREAM with PROTOCOL_ERROR for testing. Refactor and rename `TestReceive_ServerClosesStream` to `TestReceive_ProtocolErrorStreamReconnect` to verify client recovery on protocol errors.

* [client] Update Close error message in test for clarity

* [client] Fine-tune the tests

* [client] Adjust connection tracking in reconnection test

* [client] Wait for Events handler to exit in RST_STREAM reconnection test

Ensure the old `Events` handler exits fully before proceeding in the reconnection test to avoid dropped acknowledgments on a broken stream. Add a `handlerDone` channel to synchronize handler exits.

* [client] Prevent panic on nil connection during Close

* [client] Refactor connection handling to use explicit target tracking

Introduce `target` field to store the gRPC connection target directly, simplifying reconnections and ensuring consistent connection reuse logic.

* [client] Rename `isCancellation` to `isContextDone` and extend handling for `DeadlineExceeded`

Refactor error handling to include `DeadlineExceeded` scenarios alongside `Canceled`. Update related condition checks for consistency.

* [client] Add connection generation tracking to prevent stale reconnections

Introduce `connGen` to track connection generations and ensure that stale `recreateConnection` calls do not override newer connections. Update stream establishment and reconnection logic to incorporate generation validation.

* [client] Add backoff reset condition to prevent short-lived retry cycles

Refine backoff reset logic to ensure it only occurs for sufficiently long-lived stream connections, avoiding interference with `MaxElapsedTime`.

* [client] Introduce `minHealthyDuration` to refine backoff reset logic

Add `minHealthyDuration` constant to ensure stream retries only reset the backoff timer if the stream survives beyond a minimum duration. Prevents unhealthy, short-lived streams from interfering with `MaxElapsedTime`.

* [client] IPv6 friendly connection

parsedURL.Hostname() strips IPv6 brackets. For http://[::1]:443, this turns it into ::1:443, which is not a valid host:port target for gRPC. Additionally, fmt.Sprintf("%s:%s", hostname, port) produces a trailing colon when the URL has no explicit port—http://example.com becomes example.com:. Both cases break the initial dial and reconnect paths. Use parsedURL.Host directly instead.

* [client] Add `handlerStarted` channel to synchronize stream establishment in tests

Introduce `handlerStarted` channel in the test server to signal when the server-side handler begins, ensuring robust synchronization between client and server during stream establishment. Update relevant test cases to wait for this signal before proceeding.

* [client] Replace `receivedAcks` map with atomic counter and improve stream establishment sync in tests

Refactor acknowledgment tracking in tests to use an `atomic.Int32` counter instead of a map. Replace fixed sleep with robust synchronization by waiting on `handlerStarted` signal for stream establishment.

* [client] Extract `handleReceiveError` to simplify receive logic

Refactor error handling in `receive` to a dedicated `handleReceiveError` method. Streamlines the main logic and isolates error recovery, including backoff reset and connection recreation.

* [client] recreate gRPC ClientConn on every retry to prevent dual backoff

The flow client had two competing retry loops: our custom exponential
backoff and gRPC's internal subchannel reconnection. When establishStream
failed, the same ClientConn was reused, allowing gRPC's internal backoff
state to accumulate and control dial timing independently.

Changes:
- Consolidate error handling into handleRetryableError, which now
 handles context cancellation, permanent errors, backoff reset,
 and connection recreation in a single path
- Call recreateConnection on every retryable error so each retry
 gets a fresh ClientConn with no internal backoff state
- Remove connGen tracking since Receive is sequential and protected
 by a new receiving guard against concurrent calls
- Reduce RandomizationFactor from 1 to 0.5 to avoid near-zero
 backoff intervals
2026-04-13 10:42:24 +02:00
Zoltan Papp
7483fec048 Fix Android internet blackhole caused by stale route re-injection on TUN rebuild (#5865)
extraInitialRoutes() was meant to preserve only the fake IP route
(240.0.0.0/8) across TUN rebuilds, but it re-injected any initial
route missing from the current set. When the management server
advertised exit node routes (0.0.0.0/0) that were later filtered
by the route selector, extraInitialRoutes() re-added them, causing
the Android VPN to capture all traffic with no peer to handle it.

Store the fake IP route explicitly and append only that in notify(),
removing the overly broad initial route diffing.
2026-04-13 09:38:38 +02:00
210 changed files with 16685 additions and 3225 deletions

View File

@@ -0,0 +1,62 @@
name: Proto Version Check
on:
pull_request:
paths:
- "**/*.pb.go"
jobs:
check-proto-versions:
runs-on: ubuntu-latest
steps:
- name: Check for proto tool version changes
uses: actions/github-script@v7
with:
script: |
const files = await github.paginate(github.rest.pulls.listFiles, {
owner: context.repo.owner,
repo: context.repo.repo,
pull_number: context.issue.number,
per_page: 100,
});
const pbFiles = files.filter(f => f.filename.endsWith('.pb.go'));
const missingPatch = pbFiles.filter(f => !f.patch).map(f => f.filename);
if (missingPatch.length > 0) {
core.setFailed(
`Cannot inspect patch data for:\n` +
missingPatch.map(f => `- ${f}`).join('\n') +
`\nThis can happen with very large PRs. Verify proto versions manually.`
);
return;
}
const versionPattern = /^[+-]\s*\/\/\s+protoc(?:-gen-go)?\s+v[\d.]+/;
const violations = [];
for (const file of pbFiles) {
const changed = file.patch
.split('\n')
.filter(line => versionPattern.test(line));
if (changed.length > 0) {
violations.push({
file: file.filename,
lines: changed,
});
}
}
if (violations.length > 0) {
const details = violations.map(v =>
`${v.file}:\n${v.lines.map(l => ' ' + l).join('\n')}`
).join('\n\n');
core.setFailed(
`Proto version strings changed in generated files.\n` +
`This usually means the wrong protoc or protoc-gen-go version was used.\n` +
`Regenerate with the matching tool versions.\n\n` +
details
);
return;
}
console.log('No proto version string changes detected');

View File

@@ -9,7 +9,7 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.1.1"
SIGN_PIPE_VER: "v0.1.2"
GORELEASER_VER: "v2.14.3"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"

View File

@@ -5,7 +5,7 @@ GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
$(GOLANGCI_LINT):
@echo "Installing golangci-lint..."
@mkdir -p ./bin
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@latest
# Lint only changed files (fast, for pre-push)
lint: $(GOLANGCI_LINT)

View File

@@ -8,6 +8,7 @@ import (
"os"
"slices"
"sync"
"time"
"golang.org/x/exp/maps"
@@ -15,6 +16,7 @@ import (
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
@@ -26,6 +28,7 @@ import (
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
types "github.com/netbirdio/netbird/upload-server/types"
)
// ConnectionListener export internal Listener for mobile
@@ -68,7 +71,30 @@ type Client struct {
uiVersion string
networkChangeListener listener.NetworkChangeListener
stateMu sync.RWMutex
connectClient *internal.ConnectClient
config *profilemanager.Config
cacheDir string
}
func (c *Client) setState(cfg *profilemanager.Config, cacheDir string, cc *internal.ConnectClient) {
c.stateMu.Lock()
defer c.stateMu.Unlock()
c.config = cfg
c.cacheDir = cacheDir
c.connectClient = cc
}
func (c *Client) stateSnapshot() (*profilemanager.Config, string, *internal.ConnectClient) {
c.stateMu.RLock()
defer c.stateMu.RUnlock()
return c.config, c.cacheDir, c.connectClient
}
func (c *Client) getConnectClient() *internal.ConnectClient {
c.stateMu.RLock()
defer c.stateMu.RUnlock()
return c.connectClient
}
// NewClient instantiate a new Client
@@ -93,6 +119,7 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
cfgFile := platformFiles.ConfigurationFilePath()
stateFile := platformFiles.StateFilePath()
cacheDir := platformFiles.CacheDir()
log.Infof("Starting client with config: %s, state: %s", cfgFile, stateFile)
@@ -124,8 +151,9 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
c.setState(cfg, cacheDir, connectClient)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
}
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
@@ -135,6 +163,7 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
cfgFile := platformFiles.ConfigurationFilePath()
stateFile := platformFiles.StateFilePath()
cacheDir := platformFiles.CacheDir()
log.Infof("Starting client without login with config: %s, state: %s", cfgFile, stateFile)
@@ -157,8 +186,9 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
c.setState(cfg, cacheDir, connectClient)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
}
// Stop the internal client and free the resources
@@ -173,11 +203,12 @@ func (c *Client) Stop() {
}
func (c *Client) RenewTun(fd int) error {
if c.connectClient == nil {
cc := c.getConnectClient()
if cc == nil {
return fmt.Errorf("engine not running")
}
e := c.connectClient.Engine()
e := cc.Engine()
if e == nil {
return fmt.Errorf("engine not initialized")
}
@@ -185,6 +216,73 @@ func (c *Client) RenewTun(fd int) error {
return e.RenewTun(fd)
}
// DebugBundle generates a debug bundle, uploads it, and returns the upload key.
// It works both with and without a running engine.
func (c *Client) DebugBundle(platformFiles PlatformFiles, anonymize bool) (string, error) {
cfg, cacheDir, cc := c.stateSnapshot()
// If the engine hasn't been started, load config from disk
if cfg == nil {
var err error
cfg, err = profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: platformFiles.ConfigurationFilePath(),
})
if err != nil {
return "", fmt.Errorf("load config: %w", err)
}
cacheDir = platformFiles.CacheDir()
}
deps := debug.GeneratorDependencies{
InternalConfig: cfg,
StatusRecorder: c.recorder,
TempDir: cacheDir,
}
if cc != nil {
resp, err := cc.GetLatestSyncResponse()
if err != nil {
log.Warnf("get latest sync response: %v", err)
}
deps.SyncResponse = resp
if e := cc.Engine(); e != nil {
if cm := e.GetClientMetrics(); cm != nil {
deps.ClientMetrics = cm
}
}
}
bundleGenerator := debug.NewBundleGenerator(
deps,
debug.BundleConfig{
Anonymize: anonymize,
IncludeSystemInfo: true,
},
)
path, err := bundleGenerator.Generate()
if err != nil {
return "", fmt.Errorf("generate debug bundle: %w", err)
}
defer func() {
if err := os.Remove(path); err != nil {
log.Errorf("failed to remove debug bundle file: %v", err)
}
}()
uploadCtx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
key, err := debug.UploadDebugBundle(uploadCtx, types.DefaultBundleURL, cfg.ManagementURL.String(), path)
if err != nil {
return "", fmt.Errorf("upload debug bundle: %w", err)
}
log.Infof("debug bundle uploaded with key %s", key)
return key, nil
}
// SetTraceLogLevel configure the logger to trace level
func (c *Client) SetTraceLogLevel() {
log.SetLevel(log.TraceLevel)
@@ -214,12 +312,13 @@ func (c *Client) PeersList() *PeerInfoArray {
}
func (c *Client) Networks() *NetworkArray {
if c.connectClient == nil {
cc := c.getConnectClient()
if cc == nil {
log.Error("not connected")
return nil
}
engine := c.connectClient.Engine()
engine := cc.Engine()
if engine == nil {
log.Error("could not get engine")
return nil
@@ -300,7 +399,7 @@ func (c *Client) toggleRoute(command routeCommand) error {
}
func (c *Client) getRouteManager() (routemanager.Manager, error) {
client := c.connectClient
client := c.getConnectClient()
if client == nil {
return nil, fmt.Errorf("not connected")
}

View File

@@ -7,4 +7,5 @@ package android
type PlatformFiles interface {
ConfigurationFilePath() string
StateFilePath() string
CacheDir() string
}

View File

@@ -75,6 +75,7 @@ var (
mtu uint16
profilesDisabled bool
updateSettingsDisabled bool
networksDisabled bool
rootCmd = &cobra.Command{
Use: "netbird",
@@ -150,6 +151,7 @@ func init() {
rootCmd.AddCommand(logoutCmd)
rootCmd.AddCommand(versionCmd)
rootCmd.AddCommand(sshCmd)
rootCmd.AddCommand(vncCmd)
rootCmd.AddCommand(networksCMD)
rootCmd.AddCommand(forwardingRulesCmd)
rootCmd.AddCommand(debugCmd)

View File

@@ -44,10 +44,13 @@ func init() {
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd, resetParamsCmd)
serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles")
serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings")
serviceCmd.PersistentFlags().BoolVar(&networksDisabled, "disable-networks", false, "Disables network selection. If enabled, the client will not allow listing, selecting, or deselecting networks. To persist, use: netbird service install --disable-networks")
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
serviceEnvDesc := `Sets extra environment variables for the service. ` +
`You can specify a comma-separated list of KEY=VALUE pairs. ` +
`New keys are merged with previously saved env vars; existing keys are overwritten. ` +
`Use --service-env "" to clear all saved env vars. ` +
`E.g. --service-env NB_LOG_LEVEL=debug,CUSTOM_VAR=value`
installCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)

View File

@@ -61,7 +61,7 @@ func (p *program) Start(svc service.Service) error {
}
}
serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled)
serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled, networksDisabled)
if err := serverInstance.Start(); err != nil {
log.Fatalf("failed to start daemon: %v", err)
}

View File

@@ -59,6 +59,10 @@ func buildServiceArguments() []string {
args = append(args, "--disable-update-settings")
}
if networksDisabled {
args = append(args, "--disable-networks")
}
return args
}

View File

@@ -28,6 +28,7 @@ type serviceParams struct {
LogFiles []string `json:"log_files,omitempty"`
DisableProfiles bool `json:"disable_profiles,omitempty"`
DisableUpdateSettings bool `json:"disable_update_settings,omitempty"`
DisableNetworks bool `json:"disable_networks,omitempty"`
ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"`
}
@@ -78,11 +79,12 @@ func currentServiceParams() *serviceParams {
LogFiles: logFiles,
DisableProfiles: profilesDisabled,
DisableUpdateSettings: updateSettingsDisabled,
DisableNetworks: networksDisabled,
}
if len(serviceEnvVars) > 0 {
parsed, err := parseServiceEnvVars(serviceEnvVars)
if err == nil && len(parsed) > 0 {
if err == nil {
params.ServiceEnvVars = parsed
}
}
@@ -142,31 +144,46 @@ func applyServiceParams(cmd *cobra.Command, params *serviceParams) {
updateSettingsDisabled = params.DisableUpdateSettings
}
if !serviceCmd.PersistentFlags().Changed("disable-networks") {
networksDisabled = params.DisableNetworks
}
applyServiceEnvParams(cmd, params)
}
// applyServiceEnvParams merges saved service environment variables.
// If --service-env was explicitly set, explicit values win on key conflict
// but saved keys not in the explicit set are carried over.
// If --service-env was explicitly set with values, explicit values win on key
// conflict but saved keys not in the explicit set are carried over.
// If --service-env was explicitly set to empty, all saved env vars are cleared.
// If --service-env was not set, saved env vars are used entirely.
func applyServiceEnvParams(cmd *cobra.Command, params *serviceParams) {
if len(params.ServiceEnvVars) == 0 {
return
}
if !cmd.Flags().Changed("service-env") {
// No explicit env vars: rebuild serviceEnvVars from saved params.
serviceEnvVars = envMapToSlice(params.ServiceEnvVars)
if len(params.ServiceEnvVars) > 0 {
// No explicit env vars: rebuild serviceEnvVars from saved params.
serviceEnvVars = envMapToSlice(params.ServiceEnvVars)
}
return
}
// Explicit env vars were provided: merge saved values underneath.
// Flag was explicitly set: parse what the user provided.
explicit, err := parseServiceEnvVars(serviceEnvVars)
if err != nil {
cmd.PrintErrf("Warning: parse explicit service env vars for merge: %v\n", err)
return
}
// If the user passed an empty value (e.g. --service-env ""), clear all
// saved env vars rather than merging.
if len(explicit) == 0 {
serviceEnvVars = nil
return
}
if len(params.ServiceEnvVars) == 0 {
return
}
// Merge saved values underneath explicit ones.
merged := make(map[string]string, len(params.ServiceEnvVars)+len(explicit))
maps.Copy(merged, params.ServiceEnvVars)
maps.Copy(merged, explicit) // explicit wins on conflict

View File

@@ -327,6 +327,41 @@ func TestApplyServiceEnvParams_NotChanged(t *testing.T) {
assert.Equal(t, map[string]string{"FROM_SAVED": "val"}, result)
}
func TestApplyServiceEnvParams_ExplicitEmptyClears(t *testing.T) {
origServiceEnvVars := serviceEnvVars
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
// Simulate --service-env "" which produces [""] in the slice.
serviceEnvVars = []string{""}
cmd := &cobra.Command{}
cmd.Flags().StringSlice("service-env", nil, "")
require.NoError(t, cmd.Flags().Set("service-env", ""))
saved := &serviceParams{
ServiceEnvVars: map[string]string{"OLD_VAR": "should_be_cleared"},
}
applyServiceEnvParams(cmd, saved)
assert.Nil(t, serviceEnvVars, "explicit empty --service-env should clear all saved env vars")
}
func TestCurrentServiceParams_EmptyEnvVarsAfterParse(t *testing.T) {
origServiceEnvVars := serviceEnvVars
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
// Simulate --service-env "" which produces [""] in the slice.
serviceEnvVars = []string{""}
params := currentServiceParams()
// After parsing, the empty string is skipped, resulting in an empty map.
// The map should still be set (not nil) so it overwrites saved values.
assert.NotNil(t, params.ServiceEnvVars, "empty env vars should produce empty map, not nil")
assert.Empty(t, params.ServiceEnvVars, "no valid env vars should be parsed from empty string")
}
// TestServiceParams_FieldsCoveredInFunctions ensures that all serviceParams fields are
// referenced in both currentServiceParams() and applyServiceParams(). If a new field is
// added to serviceParams but not wired into these functions, this test fails.
@@ -500,6 +535,7 @@ func fieldToGlobalVar(field string) string {
"LogFiles": "logFiles",
"DisableProfiles": "profilesDisabled",
"DisableUpdateSettings": "updateSettingsDisabled",
"DisableNetworks": "networksDisabled",
"ServiceEnvVars": "serviceEnvVars",
}
if v, ok := m[field]; ok {

View File

@@ -36,7 +36,10 @@ const (
enableSSHLocalPortForwardFlag = "enable-ssh-local-port-forwarding"
enableSSHRemotePortForwardFlag = "enable-ssh-remote-port-forwarding"
disableSSHAuthFlag = "disable-ssh-auth"
sshJWTCacheTTLFlag = "ssh-jwt-cache-ttl"
jwtCacheTTLFlag = "jwt-cache-ttl"
// Alias for backward compatibility.
sshJWTCacheTTLFlag = "ssh-jwt-cache-ttl"
)
var (
@@ -61,7 +64,7 @@ var (
enableSSHLocalPortForward bool
enableSSHRemotePortForward bool
disableSSHAuth bool
sshJWTCacheTTL int
jwtCacheTTL int
)
func init() {
@@ -71,7 +74,9 @@ func init() {
upCmd.PersistentFlags().BoolVar(&enableSSHLocalPortForward, enableSSHLocalPortForwardFlag, false, "Enable local port forwarding for SSH server")
upCmd.PersistentFlags().BoolVar(&enableSSHRemotePortForward, enableSSHRemotePortForwardFlag, false, "Enable remote port forwarding for SSH server")
upCmd.PersistentFlags().BoolVar(&disableSSHAuth, disableSSHAuthFlag, false, "Disable SSH authentication")
upCmd.PersistentFlags().IntVar(&sshJWTCacheTTL, sshJWTCacheTTLFlag, 0, "SSH JWT token cache TTL in seconds (0=disabled)")
upCmd.PersistentFlags().IntVar(&jwtCacheTTL, jwtCacheTTLFlag, 0, "JWT token cache TTL in seconds (0=disabled)")
upCmd.PersistentFlags().IntVar(&jwtCacheTTL, sshJWTCacheTTLFlag, 0, "JWT token cache TTL in seconds (alias for --jwt-cache-ttl)")
_ = upCmd.PersistentFlags().MarkDeprecated(sshJWTCacheTTLFlag, "use --jwt-cache-ttl instead")
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", sshserver.DefaultSSHPort, "Remote SSH port")
sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", sshUsernameDesc)

View File

@@ -13,6 +13,8 @@ import (
"github.com/netbirdio/management-integrations/integrations"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
@@ -100,9 +102,16 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
jobManager := job.NewJobManager(nil, store, peersmanager)
iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, settingsManagerMock, eventStore)
ctx := context.Background()
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100)
if err != nil {
t.Fatal(err)
}
iv, _ := integrations.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore)
metrics, err := telemetry.NewDefaultAppMetrics(ctx)
require.NoError(t, err)
settingsMockManager := settings.NewMockManager(ctrl)
@@ -113,12 +122,11 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
Return(&types.Settings{}, nil).
AnyTimes()
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config)
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
accountManager, err := mgmt.BuildManager(ctx, config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore)
if err != nil {
t.Fatal(err)
}
@@ -152,7 +160,7 @@ func startClientDaemon(
s := grpc.NewServer()
server := client.New(ctx,
"", "", false, false)
"", "", false, false, false)
if err := server.Start(); err != nil {
t.Fatal(err)
}

View File

@@ -356,6 +356,9 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
if cmd.Flag(serverSSHAllowedFlag).Changed {
req.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(serverVNCAllowedFlag).Changed {
req.ServerVNCAllowed = &serverVNCAllowed
}
if cmd.Flag(enableSSHRootFlag).Changed {
req.EnableSSHRoot = &enableSSHRoot
}
@@ -371,9 +374,12 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
if cmd.Flag(disableSSHAuthFlag).Changed {
req.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
req.SshJWTCacheTTL = &sshJWTCacheTTL32
if cmd.Flag(disableVNCAuthFlag).Changed {
req.DisableVNCAuth = &disableVNCAuth
}
if cmd.Flag(jwtCacheTTLFlag).Changed || cmd.Flag(sshJWTCacheTTLFlag).Changed {
jwtCacheTTL32 := int32(jwtCacheTTL)
req.SshJWTCacheTTL = &jwtCacheTTL32
}
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil {
@@ -458,6 +464,9 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
if cmd.Flag(serverSSHAllowedFlag).Changed {
ic.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(serverVNCAllowedFlag).Changed {
ic.ServerVNCAllowed = &serverVNCAllowed
}
if cmd.Flag(enableSSHRootFlag).Changed {
ic.EnableSSHRoot = &enableSSHRoot
@@ -479,8 +488,12 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
ic.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
ic.SSHJWTCacheTTL = &sshJWTCacheTTL
if cmd.Flag(disableVNCAuthFlag).Changed {
ic.DisableVNCAuth = &disableVNCAuth
}
if cmd.Flag(jwtCacheTTLFlag).Changed || cmd.Flag(sshJWTCacheTTLFlag).Changed {
ic.SSHJWTCacheTTL = &jwtCacheTTL
}
if cmd.Flag(interfaceNameFlag).Changed {
@@ -582,6 +595,9 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
if cmd.Flag(serverSSHAllowedFlag).Changed {
loginRequest.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(serverVNCAllowedFlag).Changed {
loginRequest.ServerVNCAllowed = &serverVNCAllowed
}
if cmd.Flag(enableSSHRootFlag).Changed {
loginRequest.EnableSSHRoot = &enableSSHRoot
@@ -603,9 +619,13 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
loginRequest.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
loginRequest.SshJWTCacheTTL = &sshJWTCacheTTL32
if cmd.Flag(disableVNCAuthFlag).Changed {
loginRequest.DisableVNCAuth = &disableVNCAuth
}
if cmd.Flag(jwtCacheTTLFlag).Changed || cmd.Flag(sshJWTCacheTTLFlag).Changed {
jwtCacheTTL32 := int32(jwtCacheTTL)
loginRequest.SshJWTCacheTTL = &jwtCacheTTL32
}
if cmd.Flag(disableAutoConnectFlag).Changed {

271
client/cmd/vnc.go Normal file
View File

@@ -0,0 +1,271 @@
package cmd
import (
"context"
"encoding/binary"
"fmt"
"io"
"net"
"os"
"os/signal"
"os/user"
"strings"
"syscall"
"time"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/util"
)
var (
vncUsername string
vncHost string
vncMode string
vncListen string
vncNoBrowser bool
vncNoCache bool
)
func init() {
vncCmd.PersistentFlags().StringVar(&vncUsername, "user", "", "OS username for session mode")
vncCmd.PersistentFlags().StringVar(&vncMode, "mode", "attach", "Connection mode: attach (view current display) or session (virtual desktop)")
vncCmd.PersistentFlags().StringVar(&vncListen, "listen", "", "Start local VNC proxy on this address (e.g., :5900) for external VNC viewers")
vncCmd.PersistentFlags().BoolVar(&vncNoBrowser, noBrowserFlag, false, noBrowserDesc)
vncCmd.PersistentFlags().BoolVar(&vncNoCache, "no-cache", false, "Skip cached JWT token and force fresh authentication")
}
var vncCmd = &cobra.Command{
Use: "vnc [flags] [user@]host",
Short: "Connect to a NetBird peer via VNC",
Long: `Connect to a NetBird peer using VNC with JWT-based authentication.
The target peer must have the VNC server enabled.
Two modes are available:
- attach: view the current physical display (remote support)
- session: start a virtual desktop as the specified user (passwordless login)
Use --listen to start a local proxy for external VNC viewers:
netbird vnc --listen :5900 peer-hostname
vncviewer localhost:5900
Examples:
netbird vnc peer-hostname
netbird vnc --mode session --user alice peer-hostname
netbird vnc --listen :5900 peer-hostname`,
Args: cobra.MinimumNArgs(1),
RunE: vncFn,
}
func vncFn(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd)
SetFlagsFromEnvVars(cmd)
cmd.SetOut(cmd.OutOrStdout())
logOutput := "console"
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
logOutput = firstLogFile
}
if err := util.InitLog(logLevel, logOutput); err != nil {
return fmt.Errorf("init log: %w", err)
}
if err := parseVNCHostArg(args[0]); err != nil {
return err
}
ctx := internal.CtxInitState(cmd.Context())
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
vncCtx, cancel := context.WithCancel(ctx)
errCh := make(chan error, 1)
go func() {
if err := runVNC(vncCtx, cmd); err != nil {
errCh <- err
}
cancel()
}()
select {
case <-sig:
cancel()
<-vncCtx.Done()
return nil
case err := <-errCh:
return err
case <-vncCtx.Done():
}
return nil
}
func parseVNCHostArg(arg string) error {
if strings.Contains(arg, "@") {
parts := strings.SplitN(arg, "@", 2)
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
return fmt.Errorf("invalid user@host format")
}
if vncUsername == "" {
vncUsername = parts[0]
}
vncHost = parts[1]
if vncMode == "attach" {
vncMode = "session"
}
} else {
vncHost = arg
}
if vncMode == "session" && vncUsername == "" {
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
vncUsername = sudoUser
} else if currentUser, err := user.Current(); err == nil {
vncUsername = currentUser.Username
}
}
return nil
}
func runVNC(ctx context.Context, cmd *cobra.Command) error {
grpcAddr := strings.TrimPrefix(daemonAddr, "tcp://")
grpcConn, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return fmt.Errorf("connect to daemon: %w", err)
}
defer func() { _ = grpcConn.Close() }()
daemonClient := proto.NewDaemonServiceClient(grpcConn)
if vncMode == "session" {
cmd.Printf("Connecting to %s@%s [session mode]...\n", vncUsername, vncHost)
} else {
cmd.Printf("Connecting to %s [attach mode]...\n", vncHost)
}
// Obtain JWT token. If the daemon has no SSO configured, proceed without one
// (the server will accept unauthenticated connections if --disable-vnc-auth is set).
var jwtToken string
hint := profilemanager.GetLoginHint()
var browserOpener func(string) error
if !vncNoBrowser {
browserOpener = util.OpenBrowser
}
token, err := nbssh.RequestJWTToken(ctx, daemonClient, nil, cmd.ErrOrStderr(), !vncNoCache, hint, browserOpener)
if err != nil {
log.Debugf("JWT authentication unavailable, connecting without token: %v", err)
} else {
jwtToken = token
log.Debug("JWT authentication successful")
}
// Connect to the VNC server on the standard port (5900). The peer's firewall
// DNATs 5900 -> 25900 (internal), so both ports work on the overlay network.
vncAddr := net.JoinHostPort(vncHost, "5900")
vncConn, err := net.DialTimeout("tcp", vncAddr, vncDialTimeout)
if err != nil {
return fmt.Errorf("connect to VNC at %s: %w", vncAddr, err)
}
defer vncConn.Close()
// Send session header with mode, username, and JWT.
if err := sendVNCHeader(vncConn, vncMode, vncUsername, jwtToken); err != nil {
return fmt.Errorf("send VNC header: %w", err)
}
cmd.Printf("VNC connected to %s\n", vncHost)
if vncListen != "" {
return runVNCLocalProxy(ctx, cmd, vncConn)
}
// No --listen flag: inform the user they need to use --listen for external viewers.
cmd.Printf("VNC tunnel established. Use --listen :5900 to proxy for local VNC viewers.\n")
cmd.Printf("Press Ctrl+C to disconnect.\n")
<-ctx.Done()
return nil
}
const vncDialTimeout = 15 * time.Second
// sendVNCHeader writes the NetBird VNC session header.
func sendVNCHeader(conn net.Conn, mode, username, jwt string) error {
var modeByte byte
if mode == "session" {
modeByte = 1
}
usernameBytes := []byte(username)
jwtBytes := []byte(jwt)
hdr := make([]byte, 3+len(usernameBytes)+2+len(jwtBytes))
hdr[0] = modeByte
binary.BigEndian.PutUint16(hdr[1:3], uint16(len(usernameBytes)))
off := 3
copy(hdr[off:], usernameBytes)
off += len(usernameBytes)
binary.BigEndian.PutUint16(hdr[off:off+2], uint16(len(jwtBytes)))
off += 2
copy(hdr[off:], jwtBytes)
_, err := conn.Write(hdr)
return err
}
// runVNCLocalProxy listens on the given address and proxies incoming
// connections to the already-established VNC tunnel.
func runVNCLocalProxy(ctx context.Context, cmd *cobra.Command, vncConn net.Conn) error {
listener, err := net.Listen("tcp", vncListen)
if err != nil {
return fmt.Errorf("listen on %s: %w", vncListen, err)
}
defer listener.Close()
cmd.Printf("VNC proxy listening on %s - connect with your VNC viewer\n", listener.Addr())
cmd.Printf("Press Ctrl+C to stop.\n")
go func() {
<-ctx.Done()
listener.Close()
}()
// Accept a single viewer connection. VNC is single-session: the RFB
// handshake completes on vncConn for the first viewer, so subsequent
// viewers would get a mid-stream connection. The loop handles transient
// accept errors until a valid connection arrives.
for {
clientConn, err := listener.Accept()
if err != nil {
select {
case <-ctx.Done():
return nil
default:
}
log.Debugf("accept VNC proxy client: %v", err)
continue
}
cmd.Printf("VNC viewer connected from %s\n", clientConn.RemoteAddr())
// Bidirectional copy.
done := make(chan struct{})
go func() {
io.Copy(vncConn, clientConn)
close(done)
}()
io.Copy(clientConn, vncConn)
<-done
clientConn.Close()
cmd.Printf("VNC viewer disconnected\n")
return nil
}
}

62
client/cmd/vnc_agent.go Normal file
View File

@@ -0,0 +1,62 @@
//go:build windows
package cmd
import (
"net/netip"
"os"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
)
var vncAgentPort string
func init() {
vncAgentCmd.Flags().StringVar(&vncAgentPort, "port", "15900", "Port for the VNC agent to listen on")
rootCmd.AddCommand(vncAgentCmd)
}
// vncAgentCmd runs a VNC server in the current user session, listening on
// localhost. It is spawned by the NetBird service (Session 0) via
// CreateProcessAsUser into the interactive console session.
var vncAgentCmd = &cobra.Command{
Use: "vnc-agent",
Short: "Run VNC capture agent (internal, spawned by service)",
Hidden: true,
RunE: func(cmd *cobra.Command, args []string) error {
// Agent's stderr is piped to the service which relogs it.
// Use JSON format with caller info for structured parsing.
log.SetReportCaller(true)
log.SetFormatter(&log.JSONFormatter{})
log.SetOutput(os.Stderr)
sessionID := vncserver.GetCurrentSessionID()
log.Infof("VNC agent starting on 127.0.0.1:%s (session %d)", vncAgentPort, sessionID)
capturer := vncserver.NewDesktopCapturer()
injector := vncserver.NewWindowsInputInjector()
srv := vncserver.New(capturer, injector, "")
// Auth is handled by the service. The agent verifies a token on each
// connection to ensure only the service process can connect.
// The token is passed via environment variable to avoid exposing it
// in the process command line (visible via tasklist/wmic).
srv.SetDisableAuth(true)
srv.SetAgentToken(os.Getenv("NB_VNC_AGENT_TOKEN"))
port, err := netip.ParseAddrPort("127.0.0.1:" + vncAgentPort)
if err != nil {
return err
}
loopback := netip.PrefixFrom(netip.AddrFrom4([4]byte{127, 0, 0, 0}), 8)
if err := srv.Start(cmd.Context(), port, loopback); err != nil {
return err
}
<-cmd.Context().Done()
return srv.Stop()
},
}

16
client/cmd/vnc_flags.go Normal file
View File

@@ -0,0 +1,16 @@
package cmd
const (
serverVNCAllowedFlag = "allow-server-vnc"
disableVNCAuthFlag = "disable-vnc-auth"
)
var (
serverVNCAllowed bool
disableVNCAuth bool
)
func init() {
upCmd.PersistentFlags().BoolVar(&serverVNCAllowed, serverVNCAllowedFlag, false, "Allow embedded VNC server on peer")
upCmd.PersistentFlags().BoolVar(&disableVNCAuth, disableVNCAuthFlag, false, "Disable JWT authentication for VNC")
}

View File

@@ -0,0 +1,229 @@
package cmd
import (
"crypto/ecdh"
"crypto/rand"
"encoding/base64"
"fmt"
"os"
"os/signal"
"path/filepath"
"strings"
"syscall"
"text/tabwriter"
"time"
"github.com/spf13/cobra"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
"github.com/netbirdio/netbird/util"
)
var vncRecDir string
func init() {
vncRecPlayCmd.Flags().StringVar(&vncRecDir, "dir", "", "Recording directory (default: auto-detect)")
vncRecListCmd.Flags().StringVar(&vncRecDir, "dir", "", "Recording directory (default: auto-detect)")
vncRecCmd.AddCommand(vncRecListCmd)
vncRecCmd.AddCommand(vncRecPlayCmd)
vncRecCmd.AddCommand(vncRecKeygenCmd)
vncCmd.AddCommand(vncRecCmd)
}
var vncRecCmd = &cobra.Command{
Use: "rec",
Short: "Manage VNC session recordings",
}
var vncRecKeygenCmd = &cobra.Command{
Use: "keygen",
Short: "Generate an X25519 keypair for recording encryption",
Long: `Generates an X25519 keypair. Put the public key in management settings
(Session Recording > Encryption Key). Keep the private key safe for decrypting recordings.`,
RunE: vncRecKeygenFn,
}
var vncRecListCmd = &cobra.Command{
Use: "list",
Short: "List VNC session recordings",
RunE: vncRecListFn,
}
var vncRecPlayCmd = &cobra.Command{
Use: "play <file-or-name>",
Short: "Open a VNC recording in the browser",
Long: `Opens a browser-based player with playback controls:
play/pause, seek, speed (0.25x to 8x), keyboard shortcuts.
Examples:
netbird vnc rec play last
netbird vnc rec play 20260416-104433_vnc.rec`,
Args: cobra.ExactArgs(1),
RunE: vncRecPlayFn,
}
func vncRecListFn(cmd *cobra.Command, _ []string) error {
dir, err := resolveVNCRecDir()
if err != nil {
return err
}
entries, err := os.ReadDir(dir)
if err != nil {
return fmt.Errorf("read recording dir %s: %w", dir, err)
}
w := tabwriter.NewWriter(cmd.OutOrStdout(), 0, 0, 2, ' ', 0)
fmt.Fprintln(w, "FILE\tSIZE\tDIMENSIONS\tUSER\tREMOTE\tMODE\tDATE")
for _, entry := range entries {
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".rec") {
continue
}
filePath := filepath.Join(dir, entry.Name())
info, err := entry.Info()
if err != nil {
continue
}
header, err := vncserver.ReadRecordingHeader(filePath)
if err != nil {
fmt.Fprintf(w, "%s\t%s\t?\t?\t?\t?\t?\n", entry.Name(), vncFormatSize(info.Size()))
continue
}
fmt.Fprintf(w, "%s\t%s\t%dx%d\t%s\t%s\t%s\t%s\n",
entry.Name(),
vncFormatSize(info.Size()),
header.Width, header.Height,
header.Meta.User,
header.Meta.RemoteAddr,
header.Meta.Mode,
header.StartTime.Format("2006-01-02 15:04:05"),
)
}
return w.Flush()
}
func vncRecPlayFn(cmd *cobra.Command, args []string) error {
filePath, err := resolveVNCRecFile(args[0])
if err != nil {
return err
}
header, err := vncserver.ReadRecordingHeader(filePath)
if err != nil {
return fmt.Errorf("read recording: %w", err)
}
cmd.Printf("Recording: %s (%dx%d)\n", filepath.Base(filePath), header.Width, header.Height)
url, err := vncserver.ServeWebPlayer(filePath, "localhost:0")
if err != nil {
return fmt.Errorf("start web player: %w", err)
}
cmd.Printf("Player: %s\n", url)
if err := util.OpenBrowser(url); err != nil {
cmd.Printf("Open %s in your browser\n", url)
}
cmd.Printf("Press Ctrl+C to stop.\n")
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
<-sig
return nil
}
func vncRecKeygenFn(cmd *cobra.Command, _ []string) error {
priv, err := ecdh.X25519().GenerateKey(rand.Reader)
if err != nil {
return fmt.Errorf("generate key: %w", err)
}
privB64 := base64.StdEncoding.EncodeToString(priv.Bytes())
pubB64 := base64.StdEncoding.EncodeToString(priv.PublicKey().Bytes())
cmd.Printf("Private key (keep secret, for decrypting recordings):\n %s\n\n", privB64)
cmd.Printf("Public key (paste into management Settings > Session Recording > Encryption Key):\n %s\n", pubB64)
return nil
}
func vncFormatSize(size int64) string {
switch {
case size >= 1<<20:
return fmt.Sprintf("%.1fM", float64(size)/float64(1<<20))
case size >= 1<<10:
return fmt.Sprintf("%.1fK", float64(size)/float64(1<<10))
default:
return fmt.Sprintf("%dB", size)
}
}
func resolveVNCRecDir() (string, error) {
if vncRecDir != "" {
return vncRecDir, nil
}
candidates := []string{
"/var/lib/netbird/recordings/vnc",
filepath.Join(os.Getenv("HOME"), ".netbird/recordings/vnc"),
}
for _, dir := range candidates {
if fi, err := os.Stat(dir); err == nil && fi.IsDir() {
return dir, nil
}
}
return "", fmt.Errorf("no VNC recording directory found; use --dir to specify")
}
func resolveVNCRecFile(arg string) (string, error) {
if strings.Contains(arg, "/") || strings.Contains(arg, string(os.PathSeparator)) {
return arg, nil
}
dir, err := resolveVNCRecDir()
if err != nil && arg != "last" {
return arg, nil
}
if arg == "last" {
if err != nil {
return "", err
}
return findLatestRec(dir)
}
full := filepath.Join(dir, arg)
if _, err := os.Stat(full); err == nil {
return full, nil
}
return arg, nil
}
func findLatestRec(dir string) (string, error) {
entries, err := os.ReadDir(dir)
if err != nil {
return "", fmt.Errorf("read dir: %w", err)
}
var latest string
var latestTime time.Time
for _, entry := range entries {
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".rec") {
continue
}
info, err := entry.Info()
if err != nil {
continue
}
if info.ModTime().After(latestTime) {
latestTime = info.ModTime()
latest = filepath.Join(dir, entry.Name())
}
}
if latest == "" {
return "", fmt.Errorf("no recordings found in %s", dir)
}
return latest, nil
}

View File

@@ -56,6 +56,13 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogg
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
}
// Native firewall handles packet filtering, but the userspace WireGuard bind
// needs a device filter for DNS interception hooks. Install a minimal
// hooks-only filter that passes all traffic through to the kernel firewall.
if err := iface.SetFilter(&uspfilter.HooksFilter{}); err != nil {
log.Warnf("failed to set hooks filter, DNS via memory hooks will not work: %v", err)
}
return fm, nil
}

View File

@@ -21,6 +21,10 @@ const (
// rules chains contains the effective ACL rules
chainNameInputRules = "NETBIRD-ACL-INPUT"
// mangleFwdKey is the entries map key for mangle FORWARD guard rules that prevent
// external DNAT from bypassing ACL rules.
mangleFwdKey = "MANGLE-FORWARD"
)
type aclEntries map[string][][]string
@@ -274,6 +278,12 @@ func (m *aclManager) cleanChains() error {
}
}
for _, rule := range m.entries[mangleFwdKey] {
if err := m.iptablesClient.DeleteIfExists(tableMangle, chainFORWARD, rule...); err != nil {
log.Errorf("failed to delete mangle FORWARD guard rule: %v, %s", rule, err)
}
}
for _, ipsetName := range m.ipsetStore.ipsetNames() {
if err := m.flushIPSet(ipsetName); err != nil {
if errors.Is(err, ipset.ErrSetNotExist) {
@@ -303,6 +313,10 @@ func (m *aclManager) createDefaultChains() error {
}
for chainName, rules := range m.entries {
// mangle FORWARD guard rules are handled separately below
if chainName == mangleFwdKey {
continue
}
for _, rule := range rules {
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
log.Debugf("failed to create input chain jump rule: %s", err)
@@ -322,6 +336,13 @@ func (m *aclManager) createDefaultChains() error {
}
clear(m.optionalEntries)
// Insert mangle FORWARD guard rules to prevent external DNAT bypass.
for _, rule := range m.entries[mangleFwdKey] {
if err := m.iptablesClient.AppendUnique(tableMangle, chainFORWARD, rule...); err != nil {
log.Errorf("failed to add mangle FORWARD guard rule: %v", err)
}
}
return nil
}
@@ -343,6 +364,22 @@ func (m *aclManager) seedInitialEntries() {
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", chainRTFWDOUT})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainRTFWDIN})
// Mangle FORWARD guard: when external DNAT redirects traffic from the wg interface, it
// traverses FORWARD instead of INPUT, bypassing ACL rules. ACCEPT rules in filter FORWARD
// can be inserted above ours. Mangle runs before filter, so these guard rules enforce the
// ACL mark check where it cannot be overridden.
m.appendToEntries(mangleFwdKey, []string{
"-i", m.wgIface.Name(),
"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED",
"-j", "ACCEPT",
})
m.appendToEntries(mangleFwdKey, []string{
"-i", m.wgIface.Name(),
"-m", "conntrack", "--ctstate", "DNAT",
"-m", "mark", "!", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
"-j", "DROP",
})
}
func (m *aclManager) seedInitialOptionalEntries() {

View File

@@ -0,0 +1,37 @@
package common
import (
"net/netip"
"sync/atomic"
)
// PacketHook stores a registered hook for a specific IP:port.
type PacketHook struct {
IP netip.Addr
Port uint16
Fn func([]byte) bool
}
// HookMatches checks if a packet's destination matches the hook and invokes it.
func HookMatches(h *PacketHook, dstIP netip.Addr, dport uint16, packetData []byte) bool {
if h == nil {
return false
}
if h.IP == dstIP && h.Port == dport {
return h.Fn(packetData)
}
return false
}
// SetHook atomically stores a hook, handling nil removal.
func SetHook(ptr *atomic.Pointer[PacketHook], ip netip.Addr, dPort uint16, hook func([]byte) bool) {
if hook == nil {
ptr.Store(nil)
return
}
ptr.Store(&PacketHook{
IP: ip,
Port: dPort,
Fn: hook,
})
}

View File

@@ -142,15 +142,8 @@ type Manager struct {
mssClampEnabled bool
// Only one hook per protocol is supported. Outbound direction only.
udpHookOut atomic.Pointer[packetHook]
tcpHookOut atomic.Pointer[packetHook]
}
// packetHook stores a registered hook for a specific IP:port.
type packetHook struct {
ip netip.Addr
port uint16
fn func([]byte) bool
udpHookOut atomic.Pointer[common.PacketHook]
tcpHookOut atomic.Pointer[common.PacketHook]
}
// decoder for packages
@@ -912,21 +905,11 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
}
func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
return hookMatches(m.udpHookOut.Load(), dstIP, dport, packetData)
return common.HookMatches(m.udpHookOut.Load(), dstIP, dport, packetData)
}
func (m *Manager) tcpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
return hookMatches(m.tcpHookOut.Load(), dstIP, dport, packetData)
}
func hookMatches(h *packetHook, dstIP netip.Addr, dport uint16, packetData []byte) bool {
if h == nil {
return false
}
if h.ip == dstIP && h.port == dport {
return h.fn(packetData)
}
return false
return common.HookMatches(m.tcpHookOut.Load(), dstIP, dport, packetData)
}
// filterInbound implements filtering logic for incoming packets.
@@ -1337,28 +1320,12 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
// SetUDPPacketHook sets the outbound UDP packet hook. Pass nil hook to remove.
func (m *Manager) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
if hook == nil {
m.udpHookOut.Store(nil)
return
}
m.udpHookOut.Store(&packetHook{
ip: ip,
port: dPort,
fn: hook,
})
common.SetHook(&m.udpHookOut, ip, dPort, hook)
}
// SetTCPPacketHook sets the outbound TCP packet hook. Pass nil hook to remove.
func (m *Manager) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
if hook == nil {
m.tcpHookOut.Store(nil)
return
}
m.tcpHookOut.Store(&packetHook{
ip: ip,
port: dPort,
fn: hook,
})
common.SetHook(&m.tcpHookOut, ip, dPort, hook)
}
// SetLogLevel sets the log level for the firewall manager

View File

@@ -202,9 +202,9 @@ func TestSetUDPPacketHook(t *testing.T) {
h := manager.udpHookOut.Load()
require.NotNil(t, h)
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip)
assert.Equal(t, uint16(8000), h.port)
assert.True(t, h.fn(nil))
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.IP)
assert.Equal(t, uint16(8000), h.Port)
assert.True(t, h.Fn(nil))
assert.True(t, called)
manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, nil)
@@ -226,9 +226,9 @@ func TestSetTCPPacketHook(t *testing.T) {
h := manager.tcpHookOut.Load()
require.NotNil(t, h)
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip)
assert.Equal(t, uint16(53), h.port)
assert.True(t, h.fn(nil))
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.IP)
assert.Equal(t, uint16(53), h.Port)
assert.True(t, h.Fn(nil))
assert.True(t, called)
manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, nil)

View File

@@ -0,0 +1,90 @@
package uspfilter
import (
"encoding/binary"
"net/netip"
"sync/atomic"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
"github.com/netbirdio/netbird/client/iface/device"
)
const (
ipv4HeaderMinLen = 20
ipv4ProtoOffset = 9
ipv4FlagsOffset = 6
ipv4DstOffset = 16
ipProtoUDP = 17
ipProtoTCP = 6
ipv4FragOffMask = 0x1fff
// dstPortOffset is the offset of the destination port within a UDP or TCP header.
dstPortOffset = 2
)
// HooksFilter is a minimal packet filter that only handles outbound DNS hooks.
// It is installed on the WireGuard interface when the userspace bind is active
// but a full firewall filter (Manager) is not needed because a native kernel
// firewall (nftables/iptables) handles packet filtering.
type HooksFilter struct {
udpHook atomic.Pointer[common.PacketHook]
tcpHook atomic.Pointer[common.PacketHook]
}
var _ device.PacketFilter = (*HooksFilter)(nil)
// FilterOutbound checks outbound packets for DNS hook matches.
// Only IPv4 packets matching the registered hook IP:port are intercepted.
// IPv6 and non-IP packets pass through unconditionally.
func (f *HooksFilter) FilterOutbound(packetData []byte, _ int) bool {
if len(packetData) < ipv4HeaderMinLen {
return false
}
// Only process IPv4 packets, let everything else pass through.
if packetData[0]>>4 != 4 {
return false
}
ihl := int(packetData[0]&0x0f) * 4
if ihl < ipv4HeaderMinLen || len(packetData) < ihl+4 {
return false
}
// Skip non-first fragments: they don't carry L4 headers.
flagsAndOffset := binary.BigEndian.Uint16(packetData[ipv4FlagsOffset : ipv4FlagsOffset+2])
if flagsAndOffset&ipv4FragOffMask != 0 {
return false
}
dstIP, ok := netip.AddrFromSlice(packetData[ipv4DstOffset : ipv4DstOffset+4])
if !ok {
return false
}
proto := packetData[ipv4ProtoOffset]
dstPort := binary.BigEndian.Uint16(packetData[ihl+dstPortOffset : ihl+dstPortOffset+2])
switch proto {
case ipProtoUDP:
return common.HookMatches(f.udpHook.Load(), dstIP, dstPort, packetData)
case ipProtoTCP:
return common.HookMatches(f.tcpHook.Load(), dstIP, dstPort, packetData)
default:
return false
}
}
// FilterInbound allows all inbound packets (native firewall handles filtering).
func (f *HooksFilter) FilterInbound([]byte, int) bool {
return false
}
// SetUDPPacketHook registers the UDP packet hook.
func (f *HooksFilter) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func([]byte) bool) {
common.SetHook(&f.udpHook, ip, dPort, hook)
}
// SetTCPPacketHook registers the TCP packet hook.
func (f *HooksFilter) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func([]byte) bool) {
common.SetHook(&f.tcpHook, ip, dPort, hook)
}

View File

@@ -217,7 +217,6 @@ func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
// Close closes the tunnel interface
func (w *WGIface) Close() error {
w.mu.Lock()
defer w.mu.Unlock()
var result *multierror.Error
@@ -225,7 +224,15 @@ func (w *WGIface) Close() error {
result = multierror.Append(result, fmt.Errorf("failed to free WireGuard proxy: %w", err))
}
if err := w.tun.Close(); err != nil {
// Release w.mu before calling w.tun.Close(): the underlying
// wireguard-go device.Close() waits for its send/receive goroutines
// to drain. Some of those goroutines re-enter WGIface methods that
// take w.mu (e.g. the packet filter DNS hook calls GetDevice()), so
// holding the mutex here would deadlock the shutdown path.
tun := w.tun
w.mu.Unlock()
if err := tun.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
}

View File

@@ -0,0 +1,113 @@
//go:build !android
package iface
import (
"errors"
"sync"
"testing"
"time"
wgdevice "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// fakeTunDevice implements WGTunDevice and lets the test control when
// Close() returns. It mimics the wireguard-go shutdown path, which blocks
// until its goroutines drain. Some of those goroutines (e.g. the packet
// filter DNS hook in client/internal/dns) call back into WGIface, so if
// WGIface.Close() held w.mu across tun.Close() the shutdown would
// deadlock.
type fakeTunDevice struct {
closeStarted chan struct{}
unblockClose chan struct{}
}
func (f *fakeTunDevice) Create() (device.WGConfigurer, error) {
return nil, errors.New("not implemented")
}
func (f *fakeTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
return nil, errors.New("not implemented")
}
func (f *fakeTunDevice) UpdateAddr(wgaddr.Address) error { return nil }
func (f *fakeTunDevice) WgAddress() wgaddr.Address { return wgaddr.Address{} }
func (f *fakeTunDevice) MTU() uint16 { return DefaultMTU }
func (f *fakeTunDevice) DeviceName() string { return "nb-close-test" }
func (f *fakeTunDevice) FilteredDevice() *device.FilteredDevice { return nil }
func (f *fakeTunDevice) Device() *wgdevice.Device { return nil }
func (f *fakeTunDevice) GetNet() *netstack.Net { return nil }
func (f *fakeTunDevice) GetICEBind() device.EndpointManager { return nil }
func (f *fakeTunDevice) Close() error {
close(f.closeStarted)
<-f.unblockClose
return nil
}
type fakeProxyFactory struct{}
func (fakeProxyFactory) GetProxy() wgproxy.Proxy { return nil }
func (fakeProxyFactory) GetProxyPort() uint16 { return 0 }
func (fakeProxyFactory) Free() error { return nil }
// TestWGIface_CloseReleasesMutexBeforeTunClose guards against a deadlock
// that surfaces as a macOS test-timeout in
// TestDNSPermanent_updateUpstream: WGIface.Close() used to hold w.mu
// while waiting for the wireguard-go device goroutines to finish, and
// one of those goroutines (the DNS filter hook) calls back into
// WGIface.GetDevice() which needs the same mutex. The fix is to drop
// the lock before tun.Close() returns control.
func TestWGIface_CloseReleasesMutexBeforeTunClose(t *testing.T) {
tun := &fakeTunDevice{
closeStarted: make(chan struct{}),
unblockClose: make(chan struct{}),
}
w := &WGIface{
tun: tun,
wgProxyFactory: fakeProxyFactory{},
}
closeDone := make(chan error, 1)
go func() {
closeDone <- w.Close()
}()
select {
case <-tun.closeStarted:
case <-time.After(2 * time.Second):
close(tun.unblockClose)
t.Fatal("tun.Close() was never invoked")
}
// Simulate the WireGuard read goroutine calling back into WGIface
// via the packet filter's DNS hook. If Close() still held w.mu
// during tun.Close(), this would block until the test timeout.
getDeviceDone := make(chan struct{})
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
_ = w.GetDevice()
close(getDeviceDone)
}()
select {
case <-getDeviceDone:
case <-time.After(2 * time.Second):
close(tun.unblockClose)
wg.Wait()
t.Fatal("GetDevice() deadlocked while WGIface.Close was closing the tun")
}
close(tun.unblockClose)
select {
case <-closeDone:
case <-time.After(2 * time.Second):
t.Fatal("WGIface.Close() never returned after the tun was unblocked")
}
}

View File

@@ -171,7 +171,7 @@ func (u *UDPConn) performFilterCheck(addr net.Addr) error {
}
if u.address.Network.Contains(a) {
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
log.Warnf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
}
@@ -181,7 +181,7 @@ func (u *UDPConn) performFilterCheck(addr net.Addr) error {
u.addrCache.Store(addr.String(), isRouted)
if isRouted {
// Extra log, as the error only shows up with ICE logging enabled
log.Infof("Address %s is part of routed network %s, refusing to write", addr, prefix)
log.Infof("address %s is part of routed network %s, refusing to write", addr, prefix)
return fmt.Errorf("address %s is part of routed network %s, refusing to write", addr, prefix)
}
}

View File

@@ -315,6 +315,7 @@ func (a *Auth) setSystemInfoFlags(info *system.Info) {
a.config.RosenpassEnabled,
a.config.RosenpassPermissive,
a.config.ServerSSHAllowed,
a.config.ServerVNCAllowed,
a.config.DisableClientRoutes,
a.config.DisableServerRoutes,
a.config.DisableDNS,
@@ -327,6 +328,7 @@ func (a *Auth) setSystemInfoFlags(info *system.Info) {
a.config.EnableSSHLocalPortForwarding,
a.config.EnableSSHRemotePortForwarding,
a.config.DisableSSHAuth,
a.config.DisableVNCAuth,
)
}

View File

@@ -94,6 +94,7 @@ func (c *ConnectClient) RunOnAndroid(
dnsAddresses []netip.AddrPort,
dnsReadyListener dns.ReadyListener,
stateFilePath string,
cacheDir string,
) error {
// in case of non Android os these variables will be nil
mobileDependency := MobileDependency{
@@ -103,6 +104,7 @@ func (c *ConnectClient) RunOnAndroid(
HostDNSAddresses: dnsAddresses,
DnsReadyListener: dnsReadyListener,
StateFilePath: stateFilePath,
TempDir: cacheDir,
}
return c.run(mobileDependency, nil, "")
}
@@ -338,6 +340,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
log.Error(err)
return wrapErr(err)
}
engineConfig.TempDir = mobileDependency.TempDir
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), engineConfig.MTU)
c.statusRecorder.SetRelayMgr(relayManager)
@@ -543,11 +546,13 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
RosenpassEnabled: config.RosenpassEnabled,
RosenpassPermissive: config.RosenpassPermissive,
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
ServerVNCAllowed: config.ServerVNCAllowed != nil && *config.ServerVNCAllowed,
EnableSSHRoot: config.EnableSSHRoot,
EnableSSHSFTP: config.EnableSSHSFTP,
EnableSSHLocalPortForwarding: config.EnableSSHLocalPortForwarding,
EnableSSHRemotePortForwarding: config.EnableSSHRemotePortForwarding,
DisableSSHAuth: config.DisableSSHAuth,
DisableVNCAuth: config.DisableVNCAuth,
DNSRouteInterval: config.DNSRouteInterval,
DisableClientRoutes: config.DisableClientRoutes,
@@ -624,6 +629,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
config.RosenpassEnabled,
config.RosenpassPermissive,
config.ServerSSHAllowed,
config.ServerVNCAllowed,
config.DisableClientRoutes,
config.DisableServerRoutes,
config.DisableDNS,
@@ -636,6 +642,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
config.EnableSSHLocalPortForwarding,
config.EnableSSHRemotePortForwarding,
config.DisableSSHAuth,
config.DisableVNCAuth,
)
return client.Login(sysInfo, pubSSHKey, config.DNSLabels)
}

View File

@@ -16,7 +16,6 @@ import (
"path/filepath"
"runtime"
"runtime/pprof"
"slices"
"sort"
"strings"
"time"
@@ -31,7 +30,6 @@ import (
"github.com/netbirdio/netbird/client/internal/updater/installer"
nbstatus "github.com/netbirdio/netbird/client/status"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"
)
const readmeContent = `Netbird debug bundle
@@ -234,6 +232,7 @@ type BundleGenerator struct {
statusRecorder *peer.Status
syncResponse *mgmProto.SyncResponse
logPath string
tempDir string
cpuProfile []byte
refreshStatus func() // Optional callback to refresh status before bundle generation
clientMetrics MetricsExporter
@@ -256,6 +255,7 @@ type GeneratorDependencies struct {
StatusRecorder *peer.Status
SyncResponse *mgmProto.SyncResponse
LogPath string
TempDir string // Directory for temporary bundle zip files. If empty, os.TempDir() is used.
CPUProfile []byte
RefreshStatus func() // Optional callback to refresh status before bundle generation
ClientMetrics MetricsExporter
@@ -275,6 +275,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
statusRecorder: deps.StatusRecorder,
syncResponse: deps.SyncResponse,
logPath: deps.LogPath,
tempDir: deps.TempDir,
cpuProfile: deps.CPUProfile,
refreshStatus: deps.RefreshStatus,
clientMetrics: deps.ClientMetrics,
@@ -287,7 +288,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
// Generate creates a debug bundle and returns the location.
func (g *BundleGenerator) Generate() (resp string, err error) {
bundlePath, err := os.CreateTemp("", "netbird.debug.*.zip")
bundlePath, err := os.CreateTemp(g.tempDir, "netbird.debug.*.zip")
if err != nil {
return "", fmt.Errorf("create zip file: %w", err)
}
@@ -373,15 +374,8 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add wg show output: %v", err)
}
if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) {
if err := g.addLogfile(); err != nil {
log.Errorf("failed to add log file to debug bundle: %v", err)
if err := g.trySystemdLogFallback(); err != nil {
log.Errorf("failed to add systemd logs as fallback: %v", err)
}
}
} else if err := g.trySystemdLogFallback(); err != nil {
log.Errorf("failed to add systemd logs: %v", err)
if err := g.addPlatformLog(); err != nil {
log.Errorf("failed to add logs to debug bundle: %v", err)
}
if err := g.addUpdateLogs(); err != nil {

View File

@@ -0,0 +1,41 @@
//go:build android
package debug
import (
"fmt"
"io"
"os/exec"
log "github.com/sirupsen/logrus"
)
func (g *BundleGenerator) addPlatformLog() error {
cmd := exec.Command("/system/bin/logcat", "-d")
stdout, err := cmd.StdoutPipe()
if err != nil {
return fmt.Errorf("logcat stdout pipe: %w", err)
}
if err := cmd.Start(); err != nil {
return fmt.Errorf("start logcat: %w", err)
}
var logReader io.Reader = stdout
if g.anonymize {
var pw *io.PipeWriter
logReader, pw = io.Pipe()
go anonymizeLog(stdout, pw, g.anonymizer)
}
if err := g.addFileToZip(logReader, "logcat.txt"); err != nil {
return fmt.Errorf("add logcat to zip: %w", err)
}
if err := cmd.Wait(); err != nil {
return fmt.Errorf("wait logcat: %w", err)
}
log.Debug("added logcat output to debug bundle")
return nil
}

View File

@@ -0,0 +1,25 @@
//go:build !android
package debug
import (
"slices"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/util"
)
func (g *BundleGenerator) addPlatformLog() error {
if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) {
if err := g.addLogfile(); err != nil {
log.Errorf("failed to add log file to debug bundle: %v", err)
if err := g.trySystemdLogFallback(); err != nil {
return err
}
}
} else if err := g.trySystemdLogFallback(); err != nil {
return err
}
return nil
}

View File

@@ -117,11 +117,13 @@ type EngineConfig struct {
RosenpassPermissive bool
ServerSSHAllowed bool
ServerVNCAllowed bool
EnableSSHRoot *bool
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
EnableSSHRemotePortForwarding *bool
DisableSSHAuth *bool
DisableVNCAuth *bool
DNSRouteInterval time.Duration
@@ -140,6 +142,7 @@ type EngineConfig struct {
ProfileConfig *profilemanager.Config
LogPath string
TempDir string
}
// EngineServices holds the external service dependencies required by the Engine.
@@ -197,6 +200,7 @@ type Engine struct {
networkMonitor *networkmonitor.NetworkMonitor
sshServer sshServer
vncSrv vncServer
statusRecorder *peer.Status
@@ -310,6 +314,10 @@ func (e *Engine) Stop() error {
log.Warnf("failed to stop SSH server: %v", err)
}
if err := e.stopVNCServer(); err != nil {
log.Warnf("failed to stop VNC server: %v", err)
}
e.cleanupSSHConfig()
if e.ingressGatewayMgr != nil {
@@ -997,6 +1005,7 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
e.config.RosenpassEnabled,
e.config.RosenpassPermissive,
&e.config.ServerSSHAllowed,
&e.config.ServerVNCAllowed,
e.config.DisableClientRoutes,
e.config.DisableServerRoutes,
e.config.DisableDNS,
@@ -1009,6 +1018,7 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
e.config.EnableSSHLocalPortForwarding,
e.config.EnableSSHRemotePortForwarding,
e.config.DisableSSHAuth,
e.config.DisableVNCAuth,
)
if err := e.mgmClient.SyncMeta(info); err != nil {
@@ -1036,6 +1046,10 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
}
}
if err := e.updateVNC(conf.GetSshConfig()); err != nil {
log.Warnf("failed handling VNC server setup: %v", err)
}
state := e.statusRecorder.GetLocalPeerState()
state.IP = e.wgInterface.Address().String()
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
@@ -1095,6 +1109,7 @@ func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobR
StatusRecorder: e.statusRecorder,
SyncResponse: syncResponse,
LogPath: e.config.LogPath,
TempDir: e.config.TempDir,
ClientMetrics: e.clientMetrics,
RefreshStatus: func() {
e.RunHealthProbes(true)
@@ -1137,6 +1152,7 @@ func (e *Engine) receiveManagementEvents() {
e.config.RosenpassEnabled,
e.config.RosenpassPermissive,
&e.config.ServerSSHAllowed,
&e.config.ServerVNCAllowed,
e.config.DisableClientRoutes,
e.config.DisableServerRoutes,
e.config.DisableDNS,
@@ -1149,6 +1165,7 @@ func (e *Engine) receiveManagementEvents() {
e.config.EnableSSHLocalPortForwarding,
e.config.EnableSSHRemotePortForwarding,
e.config.DisableSSHAuth,
e.config.DisableVNCAuth,
)
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
@@ -1323,6 +1340,11 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
}
e.updateSSHServerAuth(networkMap.GetSshAuth())
// VNC auth: use dedicated VNCAuth if present.
if vncAuth := networkMap.GetVncAuth(); vncAuth != nil {
e.updateVNCServerAuth(vncAuth)
}
}
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
@@ -1732,6 +1754,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
e.config.RosenpassEnabled,
e.config.RosenpassPermissive,
&e.config.ServerSSHAllowed,
&e.config.ServerVNCAllowed,
e.config.DisableClientRoutes,
e.config.DisableServerRoutes,
e.config.DisableDNS,
@@ -1744,6 +1767,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
e.config.EnableSSHLocalPortForwarding,
e.config.EnableSSHRemotePortForwarding,
e.config.DisableSSHAuth,
e.config.DisableVNCAuth,
)
netMap, err := e.mgmClient.GetNetworkMap(info)

View File

@@ -55,6 +55,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
@@ -1634,7 +1635,12 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
peersManager := peers.NewManager(store, permissionsManager)
jobManager := job.NewJobManager(nil, store, peersManager)
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore)
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
if err != nil {
return nil, "", err
}
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
@@ -1656,7 +1662,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
if err != nil {
return nil, "", err
}

View File

@@ -0,0 +1,309 @@
package internal
import (
"context"
"errors"
"fmt"
"net/netip"
"os"
"path/filepath"
log "github.com/sirupsen/logrus"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)
const envVNCForceRecording = "NB_VNC_FORCE_RECORDING"
const (
vncExternalPort uint16 = 5900
vncInternalPort uint16 = 25900
)
type vncServer interface {
Start(ctx context.Context, addr netip.AddrPort, network netip.Prefix) error
Stop() error
}
func (e *Engine) setupVNCPortRedirection() error {
if e.firewall == nil || e.wgInterface == nil {
return nil
}
localAddr := e.wgInterface.Address().IP
if !localAddr.IsValid() {
return errors.New("invalid local NetBird address")
}
if err := e.firewall.AddInboundDNAT(localAddr, firewallManager.ProtocolTCP, vncExternalPort, vncInternalPort); err != nil {
return fmt.Errorf("add VNC port redirection: %w", err)
}
log.Infof("VNC port redirection: %s:%d -> %s:%d", localAddr, vncExternalPort, localAddr, vncInternalPort)
return nil
}
func (e *Engine) cleanupVNCPortRedirection() error {
if e.firewall == nil || e.wgInterface == nil {
return nil
}
localAddr := e.wgInterface.Address().IP
if !localAddr.IsValid() {
return errors.New("invalid local NetBird address")
}
if err := e.firewall.RemoveInboundDNAT(localAddr, firewallManager.ProtocolTCP, vncExternalPort, vncInternalPort); err != nil {
return fmt.Errorf("remove VNC port redirection: %w", err)
}
return nil
}
// updateVNC handles starting/stopping the VNC server based on the config flag.
// sshConf provides the JWT identity provider config (shared with SSH).
func (e *Engine) updateVNC(sshConf *mgmProto.SSHConfig) error {
if !e.config.ServerVNCAllowed {
if e.vncSrv != nil {
log.Info("VNC server disabled, stopping")
}
return e.stopVNCServer()
}
if e.config.BlockInbound {
log.Info("VNC server disabled because inbound connections are blocked")
return e.stopVNCServer()
}
if e.vncSrv != nil {
// Update JWT config on existing server in case management sent new config.
e.updateVNCServerJWT(sshConf)
return nil
}
return e.startVNCServer(sshConf)
}
func (e *Engine) startVNCServer(sshConf *mgmProto.SSHConfig) error {
if e.wgInterface == nil {
return errors.New("wg interface not initialized")
}
capturer, injector := newPlatformVNC()
if capturer == nil || injector == nil {
log.Debug("VNC server not supported on this platform")
return nil
}
netbirdIP := e.wgInterface.Address().IP
srv := vncserver.New(capturer, injector, "")
if vncNeedsServiceMode() {
log.Info("VNC: running in Session 0, enabling service mode (agent proxy)")
srv.SetServiceMode(true)
}
// Configure VNC authentication.
if e.config.DisableVNCAuth != nil && *e.config.DisableVNCAuth {
log.Info("VNC: authentication disabled by config")
srv.SetDisableAuth(true)
} else if protoJWT := sshConf.GetJwtConfig(); protoJWT != nil {
audiences := protoJWT.GetAudiences()
if len(audiences) == 0 && protoJWT.GetAudience() != "" {
audiences = []string{protoJWT.GetAudience()}
}
srv.SetJWTConfig(&vncserver.JWTConfig{
Issuer: protoJWT.GetIssuer(),
Audiences: audiences,
KeysLocation: protoJWT.GetKeysLocation(),
MaxTokenAge: protoJWT.GetMaxTokenAge(),
})
log.Debugf("VNC: JWT authentication configured (issuer=%s)", protoJWT.GetIssuer())
}
e.configureVNCRecording(srv, sshConf)
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
srv.SetNetstackNet(netstackNet)
}
listenAddr := netip.AddrPortFrom(netbirdIP, vncInternalPort)
network := e.wgInterface.Address().Network
if err := srv.Start(e.ctx, listenAddr, network); err != nil {
return fmt.Errorf("start VNC server: %w", err)
}
e.vncSrv = srv
if registrar, ok := e.firewall.(interface {
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
}); ok {
registrar.RegisterNetstackService(nftypes.TCP, vncInternalPort)
log.Debugf("registered VNC service for TCP:%d", vncInternalPort)
}
if err := e.setupVNCPortRedirection(); err != nil {
log.Warnf("setup VNC port redirection: %v", err)
}
log.Info("VNC server enabled")
return nil
}
// configureVNCRecording enables session recording on the VNC server from the
// management-supplied settings. The env var NB_VNC_FORCE_RECORDING overrides
// the API for local development: when set, recording is always enabled and
// writes into that directory. Otherwise recordings go next to the state file
// under vnc-recordings/.
func (e *Engine) configureVNCRecording(srv *vncserver.Server, sshConf *mgmProto.SSHConfig) {
recDir := os.Getenv(envVNCForceRecording)
apiEnabled := sshConf.GetEnableRecording()
if recDir == "" && !apiEnabled {
log.Debugf("VNC recording disabled (env=%q, api=%v)", recDir, apiEnabled)
return
}
if recDir == "" {
base := e.defaultRecordingBase()
if base == "" {
log.Warn("VNC recording requested by management but no state directory is available")
return
}
recDir = filepath.Join(base, "vnc-recordings")
} else {
recDir = filepath.Join(recDir, "vnc")
}
srv.SetRecordingDir(recDir)
log.Infof("VNC recording enabled (dir=%s, source=%s)", recDir, recordingSource(apiEnabled))
encKey := string(sshConf.GetRecordingEncryptionKey())
if encKey == "" {
encKey = os.Getenv("NB_VNC_RECORDING_ENCRYPTION_KEY")
}
if encKey != "" {
srv.SetRecordingEncryptionKey(encKey)
log.Info("VNC recording encryption enabled")
}
}
func (e *Engine) defaultRecordingBase() string {
if e.stateManager == nil {
return ""
}
p := e.stateManager.FilePath()
if p == "" {
return ""
}
return filepath.Dir(p)
}
func recordingSource(api bool) string {
if api {
return "management"
}
return "env"
}
// updateVNCServerJWT configures the JWT validation for the VNC server using
// the same JWT config as SSH (same identity provider).
func (e *Engine) updateVNCServerJWT(sshConf *mgmProto.SSHConfig) {
if e.vncSrv == nil {
return
}
vncSrv, ok := e.vncSrv.(*vncserver.Server)
if !ok {
return
}
if e.config.DisableVNCAuth != nil && *e.config.DisableVNCAuth {
vncSrv.SetDisableAuth(true)
return
}
protoJWT := sshConf.GetJwtConfig()
if protoJWT == nil {
return
}
audiences := protoJWT.GetAudiences()
if len(audiences) == 0 && protoJWT.GetAudience() != "" {
audiences = []string{protoJWT.GetAudience()}
}
vncSrv.SetJWTConfig(&vncserver.JWTConfig{
Issuer: protoJWT.GetIssuer(),
Audiences: audiences,
KeysLocation: protoJWT.GetKeysLocation(),
MaxTokenAge: protoJWT.GetMaxTokenAge(),
})
}
// updateVNCServerAuth updates VNC fine-grained access control from management.
func (e *Engine) updateVNCServerAuth(vncAuth *mgmProto.VNCAuth) {
if vncAuth == nil || e.vncSrv == nil {
return
}
vncSrv, ok := e.vncSrv.(*vncserver.Server)
if !ok {
return
}
protoUsers := vncAuth.GetAuthorizedUsers()
authorizedUsers := make([]sshuserhash.UserIDHash, len(protoUsers))
for i, hash := range protoUsers {
if len(hash) != 16 {
log.Warnf("invalid VNC auth hash length %d, expected 16", len(hash))
return
}
authorizedUsers[i] = sshuserhash.UserIDHash(hash)
}
machineUsers := make(map[string][]uint32)
for osUser, indexes := range vncAuth.GetMachineUsers() {
machineUsers[osUser] = indexes.GetIndexes()
}
vncSrv.UpdateVNCAuth(&sshauth.Config{
UserIDClaim: vncAuth.GetUserIDClaim(),
AuthorizedUsers: authorizedUsers,
MachineUsers: machineUsers,
})
}
// GetVNCServerStatus returns whether the VNC server is running.
func (e *Engine) GetVNCServerStatus() bool {
return e.vncSrv != nil
}
func (e *Engine) stopVNCServer() error {
if e.vncSrv == nil {
return nil
}
if err := e.cleanupVNCPortRedirection(); err != nil {
log.Warnf("cleanup VNC port redirection: %v", err)
}
if registrar, ok := e.firewall.(interface {
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
}); ok {
registrar.UnregisterNetstackService(nftypes.TCP, vncInternalPort)
}
log.Info("stopping VNC server")
err := e.vncSrv.Stop()
e.vncSrv = nil
if err != nil {
return fmt.Errorf("stop VNC server: %w", err)
}
return nil
}

View File

@@ -0,0 +1,23 @@
//go:build darwin && !ios
package internal
import (
log "github.com/sirupsen/logrus"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
)
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector) {
capturer := vncserver.NewMacPoller()
injector, err := vncserver.NewMacInputInjector()
if err != nil {
log.Debugf("VNC: macOS input injector: %v", err)
return capturer, &vncserver.StubInputInjector{}
}
return capturer, injector
}
func vncNeedsServiceMode() bool {
return false
}

View File

@@ -0,0 +1,13 @@
//go:build !windows && !darwin && !freebsd && !(linux && !android)
package internal
import vncserver "github.com/netbirdio/netbird/client/vnc/server"
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector) {
return nil, nil
}
func vncNeedsServiceMode() bool {
return false
}

View File

@@ -0,0 +1,13 @@
//go:build windows
package internal
import vncserver "github.com/netbirdio/netbird/client/vnc/server"
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector) {
return vncserver.NewDesktopCapturer(), vncserver.NewWindowsInputInjector()
}
func vncNeedsServiceMode() bool {
return vncserver.GetCurrentSessionID() == 0
}

View File

@@ -0,0 +1,23 @@
//go:build (linux && !android) || freebsd
package internal
import (
log "github.com/sirupsen/logrus"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
)
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector) {
capturer := vncserver.NewX11Poller("")
injector, err := vncserver.NewX11InputInjector("")
if err != nil {
log.Debugf("VNC: X11 input injector: %v", err)
return capturer, &vncserver.StubInputInjector{}
}
return capturer, injector
}
func vncNeedsServiceMode() bool {
return false
}

View File

@@ -22,4 +22,8 @@ type MobileDependency struct {
DnsManager dns.IosDnsManager
FileDescriptor int32
StateFilePath string
// TempDir is a writable directory for temporary files (e.g., debug bundle zip).
// On Android, this should be set to the app's cache directory.
TempDir string
}

View File

@@ -7,7 +7,9 @@ import (
"fmt"
"net/netip"
"sync"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
nfct "github.com/ti-mo/conntrack"
@@ -17,31 +19,64 @@ import (
nbnet "github.com/netbirdio/netbird/client/net"
)
const defaultChannelSize = 100
const (
defaultChannelSize = 100
reconnectInitInterval = 5 * time.Second
reconnectMaxInterval = 5 * time.Minute
reconnectRandomization = 0.5
)
// listener abstracts a netlink conntrack connection for testability.
type listener interface {
Listen(evChan chan<- nfct.Event, numWorkers uint8, groups []netfilter.NetlinkGroup) (chan error, error)
Close() error
}
// ConnTrack manages kernel-based conntrack events
type ConnTrack struct {
flowLogger nftypes.FlowLogger
iface nftypes.IFaceMapper
conn *nfct.Conn
conn listener
mux sync.Mutex
dial func() (listener, error)
instanceID uuid.UUID
started bool
done chan struct{}
sysctlModified bool
}
// DialFunc is a constructor for netlink conntrack connections.
type DialFunc func() (listener, error)
// Option configures a ConnTrack instance.
type Option func(*ConnTrack)
// WithDialer overrides the default netlink dialer, primarily for testing.
func WithDialer(dial DialFunc) Option {
return func(c *ConnTrack) {
c.dial = dial
}
}
func defaultDial() (listener, error) {
return nfct.Dial(nil)
}
// New creates a new connection tracker that interfaces with the kernel's conntrack system
func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper) *ConnTrack {
return &ConnTrack{
func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper, opts ...Option) *ConnTrack {
ct := &ConnTrack{
flowLogger: flowLogger,
iface: iface,
instanceID: uuid.New(),
started: false,
dial: defaultDial,
done: make(chan struct{}, 1),
}
for _, opt := range opts {
opt(ct)
}
return ct
}
// Start begins tracking connections by listening for conntrack events. This method is idempotent.
@@ -59,8 +94,9 @@ func (c *ConnTrack) Start(enableCounters bool) error {
c.EnableAccounting()
}
conn, err := nfct.Dial(nil)
conn, err := c.dial()
if err != nil {
c.RestoreAccounting()
return fmt.Errorf("dial conntrack: %w", err)
}
c.conn = conn
@@ -76,9 +112,16 @@ func (c *ConnTrack) Start(enableCounters bool) error {
log.Errorf("Error closing conntrack connection: %v", err)
}
c.conn = nil
c.RestoreAccounting()
return fmt.Errorf("start conntrack listener: %w", err)
}
// Drain any stale stop signal from a previous cycle.
select {
case <-c.done:
default:
}
c.started = true
go c.receiverRoutine(events, errChan)
@@ -92,17 +135,98 @@ func (c *ConnTrack) receiverRoutine(events chan nfct.Event, errChan chan error)
case event := <-events:
c.handleEvent(event)
case err := <-errChan:
log.Errorf("Error from conntrack event listener: %v", err)
if err := c.conn.Close(); err != nil {
log.Errorf("Error closing conntrack connection: %v", err)
if events, errChan = c.handleListenerError(err); events == nil {
return
}
return
case <-c.done:
return
}
}
}
// handleListenerError closes the failed connection and attempts to reconnect.
// Returns new channels on success, or nil if shutdown was requested.
func (c *ConnTrack) handleListenerError(err error) (chan nfct.Event, chan error) {
log.Warnf("conntrack event listener failed: %v", err)
c.closeConn()
return c.reconnect()
}
func (c *ConnTrack) closeConn() {
c.mux.Lock()
defer c.mux.Unlock()
if c.conn != nil {
if err := c.conn.Close(); err != nil {
log.Debugf("close conntrack connection: %v", err)
}
c.conn = nil
}
}
// reconnect attempts to re-establish the conntrack netlink listener with exponential backoff.
// Returns new channels on success, or nil if shutdown was requested.
func (c *ConnTrack) reconnect() (chan nfct.Event, chan error) {
bo := &backoff.ExponentialBackOff{
InitialInterval: reconnectInitInterval,
RandomizationFactor: reconnectRandomization,
Multiplier: backoff.DefaultMultiplier,
MaxInterval: reconnectMaxInterval,
MaxElapsedTime: 0, // retry indefinitely
Clock: backoff.SystemClock,
}
bo.Reset()
for {
delay := bo.NextBackOff()
log.Infof("reconnecting conntrack listener in %s", delay)
select {
case <-c.done:
c.mux.Lock()
c.started = false
c.mux.Unlock()
return nil, nil
case <-time.After(delay):
}
conn, err := c.dial()
if err != nil {
log.Warnf("reconnect conntrack dial: %v", err)
continue
}
events := make(chan nfct.Event, defaultChannelSize)
errChan, err := conn.Listen(events, 1, []netfilter.NetlinkGroup{
netfilter.GroupCTNew,
netfilter.GroupCTDestroy,
})
if err != nil {
log.Warnf("reconnect conntrack listen: %v", err)
if closeErr := conn.Close(); closeErr != nil {
log.Debugf("close conntrack connection: %v", closeErr)
}
continue
}
c.mux.Lock()
if !c.started {
// Stop() ran while we were reconnecting.
c.mux.Unlock()
if closeErr := conn.Close(); closeErr != nil {
log.Debugf("close conntrack connection: %v", closeErr)
}
return nil, nil
}
c.conn = conn
c.mux.Unlock()
log.Infof("conntrack listener reconnected successfully")
return events, errChan
}
}
// Stop stops the connection tracking. This method is idempotent.
func (c *ConnTrack) Stop() {
c.mux.Lock()
@@ -136,23 +260,27 @@ func (c *ConnTrack) Close() error {
c.mux.Lock()
defer c.mux.Unlock()
if c.started {
select {
case c.done <- struct{}{}:
default:
}
if !c.started {
return nil
}
select {
case c.done <- struct{}{}:
default:
}
c.started = false
var closeErr error
if c.conn != nil {
err := c.conn.Close()
closeErr = c.conn.Close()
c.conn = nil
c.started = false
}
c.RestoreAccounting()
c.RestoreAccounting()
if err != nil {
return fmt.Errorf("close conntrack: %w", err)
}
if closeErr != nil {
return fmt.Errorf("close conntrack: %w", closeErr)
}
return nil

View File

@@ -0,0 +1,224 @@
//go:build linux && !android
package conntrack
import (
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nfct "github.com/ti-mo/conntrack"
"github.com/ti-mo/netfilter"
)
type mockListener struct {
errChan chan error
closed atomic.Bool
closedCh chan struct{}
}
func newMockListener() *mockListener {
return &mockListener{
errChan: make(chan error, 1),
closedCh: make(chan struct{}),
}
}
func (m *mockListener) Listen(evChan chan<- nfct.Event, _ uint8, _ []netfilter.NetlinkGroup) (chan error, error) {
return m.errChan, nil
}
func (m *mockListener) Close() error {
if m.closed.CompareAndSwap(false, true) {
close(m.closedCh)
}
return nil
}
func TestReconnectAfterError(t *testing.T) {
first := newMockListener()
second := newMockListener()
third := newMockListener()
listeners := []*mockListener{first, second, third}
callCount := atomic.Int32{}
ct := New(nil, nil, WithDialer(func() (listener, error) {
n := int(callCount.Add(1)) - 1
return listeners[n], nil
}))
err := ct.Start(false)
require.NoError(t, err)
// Inject an error on the first listener.
first.errChan <- assert.AnError
// Wait for reconnect to complete.
require.Eventually(t, func() bool {
return callCount.Load() >= 2
}, 15*time.Second, 100*time.Millisecond, "reconnect should dial a new connection")
// The first connection must have been closed.
select {
case <-first.closedCh:
case <-time.After(2 * time.Second):
t.Fatal("first connection was not closed")
}
// Verify the receiver is still running by injecting and handling a second error.
second.errChan <- assert.AnError
require.Eventually(t, func() bool {
return callCount.Load() >= 3
}, 15*time.Second, 100*time.Millisecond, "second reconnect should succeed")
ct.Stop()
}
func TestStopDuringReconnectBackoff(t *testing.T) {
mock := newMockListener()
ct := New(nil, nil, WithDialer(func() (listener, error) {
return mock, nil
}))
err := ct.Start(false)
require.NoError(t, err)
// Trigger an error so the receiver enters reconnect.
mock.errChan <- assert.AnError
// Wait for the error handler to close the old listener before calling Stop.
select {
case <-mock.closedCh:
case <-time.After(5 * time.Second):
t.Fatal("timed out waiting for reconnect to start")
}
// Stop while reconnecting.
ct.Stop()
ct.mux.Lock()
assert.False(t, ct.started, "started should be false after Stop")
assert.Nil(t, ct.conn, "conn should be nil after Stop")
ct.mux.Unlock()
}
func TestStopRaceWithReconnectDial(t *testing.T) {
first := newMockListener()
dialStarted := make(chan struct{})
dialProceed := make(chan struct{})
second := newMockListener()
callCount := atomic.Int32{}
ct := New(nil, nil, WithDialer(func() (listener, error) {
n := callCount.Add(1)
if n == 1 {
return first, nil
}
// Second dial: signal that we're in progress, wait for test to call Stop.
close(dialStarted)
<-dialProceed
return second, nil
}))
err := ct.Start(false)
require.NoError(t, err)
// Trigger error to enter reconnect.
first.errChan <- assert.AnError
// Wait for reconnect's second dial to begin.
select {
case <-dialStarted:
case <-time.After(15 * time.Second):
t.Fatal("timed out waiting for reconnect dial")
}
// Stop while dial is in progress (conn is nil at this point).
ct.Stop()
// Let the dial complete. reconnect should detect started==false and close the new conn.
close(dialProceed)
// The second connection should be closed (not leaked).
select {
case <-second.closedCh:
case <-time.After(2 * time.Second):
t.Fatal("second connection was leaked after Stop")
}
ct.mux.Lock()
assert.False(t, ct.started)
assert.Nil(t, ct.conn)
ct.mux.Unlock()
}
func TestCloseRaceWithReconnectDial(t *testing.T) {
first := newMockListener()
dialStarted := make(chan struct{})
dialProceed := make(chan struct{})
second := newMockListener()
callCount := atomic.Int32{}
ct := New(nil, nil, WithDialer(func() (listener, error) {
n := callCount.Add(1)
if n == 1 {
return first, nil
}
close(dialStarted)
<-dialProceed
return second, nil
}))
err := ct.Start(false)
require.NoError(t, err)
first.errChan <- assert.AnError
select {
case <-dialStarted:
case <-time.After(15 * time.Second):
t.Fatal("timed out waiting for reconnect dial")
}
// Close while dial is in progress (conn is nil).
require.NoError(t, ct.Close())
close(dialProceed)
// The second connection should be closed (not leaked).
select {
case <-second.closedCh:
case <-time.After(2 * time.Second):
t.Fatal("second connection was leaked after Close")
}
ct.mux.Lock()
assert.False(t, ct.started)
assert.Nil(t, ct.conn)
ct.mux.Unlock()
}
func TestStartIsIdempotent(t *testing.T) {
mock := newMockListener()
callCount := atomic.Int32{}
ct := New(nil, nil, WithDialer(func() (listener, error) {
callCount.Add(1)
return mock, nil
}))
err := ct.Start(false)
require.NoError(t, err)
// Second Start should be a no-op.
err = ct.Start(false)
require.NoError(t, err)
assert.Equal(t, int32(1), callCount.Load(), "dial should only be called once")
ct.Stop()
}

View File

@@ -8,18 +8,27 @@ import (
)
const (
envDisableNATMapper = "NB_DISABLE_NAT_MAPPER"
envDisableNATMapper = "NB_DISABLE_NAT_MAPPER"
envDisablePCPHealthCheck = "NB_DISABLE_PCP_HEALTH_CHECK"
)
func isDisabledByEnv() bool {
val := os.Getenv(envDisableNATMapper)
return parseBoolEnv(envDisableNATMapper)
}
func isHealthCheckDisabled() bool {
return parseBoolEnv(envDisablePCPHealthCheck)
}
func parseBoolEnv(key string) bool {
val := os.Getenv(key)
if val == "" {
return false
}
disabled, err := strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", envDisableNATMapper, err)
log.Warnf("failed to parse %s: %v", key, err)
return false
}
return disabled

View File

@@ -12,12 +12,15 @@ import (
"github.com/libp2p/go-nat"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/portforward/pcp"
)
const (
defaultMappingTTL = 2 * time.Hour
discoveryTimeout = 10 * time.Second
mappingDescription = "NetBird"
defaultMappingTTL = 2 * time.Hour
healthCheckInterval = 1 * time.Minute
discoveryTimeout = 10 * time.Second
mappingDescription = "NetBird"
)
// upnpErrPermanentLeaseOnly matches UPnP error 725 in SOAP fault XML,
@@ -154,7 +157,7 @@ func (m *Manager) setup(ctx context.Context) (nat.NAT, *Mapping, error) {
discoverCtx, discoverCancel := context.WithTimeout(ctx, discoveryTimeout)
defer discoverCancel()
gateway, err := nat.DiscoverGateway(discoverCtx)
gateway, err := discoverGateway(discoverCtx)
if err != nil {
return nil, nil, fmt.Errorf("discover gateway: %w", err)
}
@@ -189,7 +192,6 @@ func (m *Manager) createMapping(ctx context.Context, gateway nat.NAT) (*Mapping,
externalIP, err := gateway.GetExternalAddress()
if err != nil {
log.Debugf("failed to get external address: %v", err)
// todo return with err?
}
mapping := &Mapping{
@@ -208,27 +210,87 @@ func (m *Manager) createMapping(ctx context.Context, gateway nat.NAT) (*Mapping,
func (m *Manager) renewLoop(ctx context.Context, gateway nat.NAT, ttl time.Duration) {
if ttl == 0 {
// Permanent mappings don't expire, just wait for cancellation.
<-ctx.Done()
// Permanent mappings don't expire, just wait for cancellation
// but still run health checks for PCP gateways.
m.permanentLeaseLoop(ctx, gateway)
return
}
ticker := time.NewTicker(ttl / 2)
defer ticker.Stop()
renewTicker := time.NewTicker(ttl / 2)
healthTicker := time.NewTicker(healthCheckInterval)
defer renewTicker.Stop()
defer healthTicker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
case <-renewTicker.C:
if err := m.renewMapping(ctx, gateway); err != nil {
log.Warnf("failed to renew port mapping: %v", err)
continue
}
case <-healthTicker.C:
if m.checkHealthAndRecreate(ctx, gateway) {
renewTicker.Reset(ttl / 2)
}
}
}
}
func (m *Manager) permanentLeaseLoop(ctx context.Context, gateway nat.NAT) {
healthTicker := time.NewTicker(healthCheckInterval)
defer healthTicker.Stop()
for {
select {
case <-ctx.Done():
return
case <-healthTicker.C:
m.checkHealthAndRecreate(ctx, gateway)
}
}
}
func (m *Manager) checkHealthAndRecreate(ctx context.Context, gateway nat.NAT) bool {
if isHealthCheckDisabled() {
return false
}
m.mappingLock.Lock()
hasMapping := m.mapping != nil
m.mappingLock.Unlock()
if !hasMapping {
return false
}
pcpNAT, ok := gateway.(*pcp.NAT)
if !ok {
return false
}
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
epoch, serverRestarted, err := pcpNAT.CheckServerHealth(ctx)
if err != nil {
log.Debugf("PCP health check failed: %v", err)
return false
}
if serverRestarted {
log.Warnf("PCP server restart detected (epoch=%d), recreating port mapping", epoch)
if err := m.renewMapping(ctx, gateway); err != nil {
log.Errorf("failed to recreate port mapping after server restart: %v", err)
return false
}
return true
}
return false
}
func (m *Manager) renewMapping(ctx context.Context, gateway nat.NAT) error {
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()

View File

@@ -0,0 +1,408 @@
package pcp
import (
"context"
"crypto/rand"
"errors"
"fmt"
"net"
"net/netip"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
const (
defaultTimeout = 3 * time.Second
responseBufferSize = 128
// RFC 6887 Section 8.1.1 retry timing
initialRetryDelay = 3 * time.Second
maxRetryDelay = 1024 * time.Second
maxRetries = 4 // 3s + 6s + 12s + 24s = 45s total worst case
)
// Client is a PCP protocol client.
// All methods are safe for concurrent use.
type Client struct {
gateway netip.Addr
timeout time.Duration
mu sync.Mutex
// localIP caches the resolved local IP address.
localIP netip.Addr
// lastEpoch is the last observed server epoch value.
lastEpoch uint32
// epochTime tracks when lastEpoch was received for state loss detection.
epochTime time.Time
// externalIP caches the external IP from the last successful MAP response.
externalIP netip.Addr
// epochStateLost is set when epoch indicates server restart.
epochStateLost bool
}
// NewClient creates a new PCP client for the gateway at the given IP.
func NewClient(gateway net.IP) *Client {
addr, ok := netip.AddrFromSlice(gateway)
if !ok {
log.Debugf("invalid gateway IP: %v", gateway)
}
return &Client{
gateway: addr.Unmap(),
timeout: defaultTimeout,
}
}
// NewClientWithTimeout creates a new PCP client with a custom timeout.
func NewClientWithTimeout(gateway net.IP, timeout time.Duration) *Client {
addr, ok := netip.AddrFromSlice(gateway)
if !ok {
log.Debugf("invalid gateway IP: %v", gateway)
}
return &Client{
gateway: addr.Unmap(),
timeout: timeout,
}
}
// SetLocalIP sets the local IP address to use in PCP requests.
func (c *Client) SetLocalIP(ip net.IP) {
addr, ok := netip.AddrFromSlice(ip)
if !ok {
log.Debugf("invalid local IP: %v", ip)
}
c.mu.Lock()
c.localIP = addr.Unmap()
c.mu.Unlock()
}
// Gateway returns the gateway IP address.
func (c *Client) Gateway() net.IP {
return c.gateway.AsSlice()
}
// Announce sends a PCP ANNOUNCE request to discover PCP support.
// Returns the server's epoch time on success.
func (c *Client) Announce(ctx context.Context) (epoch uint32, err error) {
localIP, err := c.getLocalIP()
if err != nil {
return 0, fmt.Errorf("get local IP: %w", err)
}
req := buildAnnounceRequest(localIP)
resp, err := c.sendRequest(ctx, req)
if err != nil {
return 0, fmt.Errorf("send announce: %w", err)
}
parsed, err := parseResponse(resp)
if err != nil {
return 0, fmt.Errorf("parse announce response: %w", err)
}
if parsed.ResultCode != ResultSuccess {
return 0, fmt.Errorf("PCP ANNOUNCE failed: %s", ResultCodeString(parsed.ResultCode))
}
c.mu.Lock()
if c.updateEpochLocked(parsed.Epoch) {
log.Warnf("PCP server epoch indicates state loss - mappings may need refresh")
}
c.mu.Unlock()
return parsed.Epoch, nil
}
// AddPortMapping requests a port mapping from the PCP server.
func (c *Client) AddPortMapping(ctx context.Context, protocol string, internalPort int, lifetime time.Duration) (*MapResponse, error) {
return c.addPortMappingWithHint(ctx, protocol, internalPort, internalPort, netip.Addr{}, lifetime)
}
// AddPortMappingWithHint requests a port mapping with suggested external port and IP.
// Use lifetime <= 0 to delete a mapping.
func (c *Client) AddPortMappingWithHint(ctx context.Context, protocol string, internalPort, suggestedExtPort int, suggestedExtIP net.IP, lifetime time.Duration) (*MapResponse, error) {
var extIP netip.Addr
if suggestedExtIP != nil {
var ok bool
extIP, ok = netip.AddrFromSlice(suggestedExtIP)
if !ok {
log.Debugf("invalid suggested external IP: %v", suggestedExtIP)
}
extIP = extIP.Unmap()
}
return c.addPortMappingWithHint(ctx, protocol, internalPort, suggestedExtPort, extIP, lifetime)
}
func (c *Client) addPortMappingWithHint(ctx context.Context, protocol string, internalPort, suggestedExtPort int, suggestedExtIP netip.Addr, lifetime time.Duration) (*MapResponse, error) {
localIP, err := c.getLocalIP()
if err != nil {
return nil, fmt.Errorf("get local IP: %w", err)
}
proto, err := protocolNumber(protocol)
if err != nil {
return nil, fmt.Errorf("parse protocol: %w", err)
}
var nonce [12]byte
if _, err := rand.Read(nonce[:]); err != nil {
return nil, fmt.Errorf("generate nonce: %w", err)
}
// Convert lifetime to seconds. Lifetime 0 means delete, so only apply
// default for positive durations that round to 0 seconds.
var lifetimeSec uint32
if lifetime > 0 {
lifetimeSec = uint32(lifetime.Seconds())
if lifetimeSec == 0 {
lifetimeSec = DefaultLifetime
}
}
req := buildMapRequest(localIP, nonce, proto, uint16(internalPort), uint16(suggestedExtPort), suggestedExtIP, lifetimeSec)
resp, err := c.sendRequest(ctx, req)
if err != nil {
return nil, fmt.Errorf("send map request: %w", err)
}
mapResp, err := parseMapResponse(resp)
if err != nil {
return nil, fmt.Errorf("parse map response: %w", err)
}
if mapResp.Nonce != nonce {
return nil, fmt.Errorf("nonce mismatch in response")
}
if mapResp.Protocol != proto {
return nil, fmt.Errorf("protocol mismatch: requested %d, got %d", proto, mapResp.Protocol)
}
if mapResp.InternalPort != uint16(internalPort) {
return nil, fmt.Errorf("internal port mismatch: requested %d, got %d", internalPort, mapResp.InternalPort)
}
if mapResp.ResultCode != ResultSuccess {
return nil, &Error{
Code: mapResp.ResultCode,
Message: ResultCodeString(mapResp.ResultCode),
}
}
c.mu.Lock()
if c.updateEpochLocked(mapResp.Epoch) {
log.Warnf("PCP server epoch indicates state loss - mappings may need refresh")
}
c.cacheExternalIPLocked(mapResp.ExternalIP)
c.mu.Unlock()
return mapResp, nil
}
// DeletePortMapping removes a port mapping by requesting zero lifetime.
func (c *Client) DeletePortMapping(ctx context.Context, protocol string, internalPort int) error {
if _, err := c.addPortMappingWithHint(ctx, protocol, internalPort, 0, netip.Addr{}, 0); err != nil {
var pcpErr *Error
if errors.As(err, &pcpErr) && pcpErr.Code == ResultNotAuthorized {
return nil
}
return fmt.Errorf("delete mapping: %w", err)
}
return nil
}
// GetExternalAddress returns the external IP address.
// First checks for a cached value from previous MAP responses.
// If not cached, creates a short-lived mapping to discover the external IP.
func (c *Client) GetExternalAddress(ctx context.Context) (net.IP, error) {
c.mu.Lock()
if c.externalIP.IsValid() {
ip := c.externalIP.AsSlice()
c.mu.Unlock()
return ip, nil
}
c.mu.Unlock()
// Use an ephemeral port in the dynamic range (49152-65535).
// Port 0 is not valid with UDP/TCP protocols per RFC 6887.
ephemeralPort := 49152 + int(uint16(time.Now().UnixNano()))%(65535-49152)
// Use minimal lifetime (1 second) for discovery.
resp, err := c.AddPortMapping(ctx, "udp", ephemeralPort, time.Second)
if err != nil {
return nil, fmt.Errorf("create temporary mapping: %w", err)
}
if err := c.DeletePortMapping(ctx, "udp", ephemeralPort); err != nil {
log.Debugf("cleanup temporary PCP mapping: %v", err)
}
return resp.ExternalIP.AsSlice(), nil
}
// LastEpoch returns the last observed server epoch value.
// A decrease in epoch indicates the server may have restarted and mappings may be lost.
func (c *Client) LastEpoch() uint32 {
c.mu.Lock()
defer c.mu.Unlock()
return c.lastEpoch
}
// EpochStateLost returns true if epoch state loss was detected and clears the flag.
func (c *Client) EpochStateLost() bool {
c.mu.Lock()
defer c.mu.Unlock()
lost := c.epochStateLost
c.epochStateLost = false
return lost
}
// updateEpoch updates the epoch tracking and detects potential state loss.
// Returns true if state loss was detected (server likely restarted).
// Caller must hold c.mu.
func (c *Client) updateEpochLocked(newEpoch uint32) bool {
now := time.Now()
stateLost := false
// RFC 6887 Section 8.5: Detect invalid epoch indicating server state loss.
// client_delta = time since last response
// server_delta = epoch change since last response
// Invalid if: client_delta+2 < server_delta - server_delta/16
// OR: server_delta+2 < client_delta - client_delta/16
// The +2 handles quantization, /16 (6.25%) handles clock drift.
if !c.epochTime.IsZero() && c.lastEpoch > 0 {
clientDelta := uint32(now.Sub(c.epochTime).Seconds())
serverDelta := newEpoch - c.lastEpoch
// Check for epoch going backwards or jumping unexpectedly.
// Subtraction is safe: serverDelta/16 is always <= serverDelta.
if clientDelta+2 < serverDelta-(serverDelta/16) ||
serverDelta+2 < clientDelta-(clientDelta/16) {
stateLost = true
c.epochStateLost = true
}
}
c.lastEpoch = newEpoch
c.epochTime = now
return stateLost
}
// cacheExternalIP stores the external IP from a successful MAP response.
// Caller must hold c.mu.
func (c *Client) cacheExternalIPLocked(ip netip.Addr) {
if ip.IsValid() && !ip.IsUnspecified() {
c.externalIP = ip
}
}
// sendRequest sends a PCP request with retries per RFC 6887 Section 8.1.1.
func (c *Client) sendRequest(ctx context.Context, req []byte) ([]byte, error) {
addr := &net.UDPAddr{IP: c.gateway.AsSlice(), Port: Port}
var lastErr error
delay := initialRetryDelay
for range maxRetries {
resp, err := c.sendOnce(ctx, addr, req)
if err == nil {
return resp, nil
}
lastErr = err
if ctx.Err() != nil {
return nil, ctx.Err()
}
// RFC 6887 Section 8.1.1: RT = (1 + RAND) * MIN(2 * RTprev, MRT)
// RAND is random between -0.1 and +0.1
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(retryDelayWithJitter(delay)):
}
delay = min(delay*2, maxRetryDelay)
}
return nil, fmt.Errorf("PCP request failed after %d retries: %w", maxRetries, lastErr)
}
// retryDelayWithJitter applies RFC 6887 jitter: multiply by (1 + RAND) where RAND is [-0.1, +0.1].
func retryDelayWithJitter(d time.Duration) time.Duration {
var b [1]byte
_, _ = rand.Read(b[:])
// Convert byte to range [-0.1, +0.1]: (b/255 * 0.2) - 0.1
jitter := (float64(b[0])/255.0)*0.2 - 0.1
return time.Duration(float64(d) * (1 + jitter))
}
func (c *Client) sendOnce(ctx context.Context, addr *net.UDPAddr, req []byte) ([]byte, error) {
// Use ListenUDP instead of DialUDP to validate response source address per RFC 6887 §8.3.
conn, err := net.ListenUDP("udp", nil)
if err != nil {
return nil, fmt.Errorf("listen: %w", err)
}
defer func() {
if err := conn.Close(); err != nil {
log.Debugf("close UDP connection: %v", err)
}
}()
timeout := c.timeout
if deadline, ok := ctx.Deadline(); ok {
if remaining := time.Until(deadline); remaining < timeout {
timeout = remaining
}
}
if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil {
return nil, fmt.Errorf("set deadline: %w", err)
}
if _, err := conn.WriteToUDP(req, addr); err != nil {
return nil, fmt.Errorf("write: %w", err)
}
resp := make([]byte, responseBufferSize)
n, from, err := conn.ReadFromUDP(resp)
if err != nil {
return nil, fmt.Errorf("read: %w", err)
}
// RFC 6887 §8.3: Validate response came from expected PCP server.
if !from.IP.Equal(addr.IP) {
return nil, fmt.Errorf("response from unexpected source %s (expected %s)", from.IP, addr.IP)
}
return resp[:n], nil
}
func (c *Client) getLocalIP() (netip.Addr, error) {
c.mu.Lock()
defer c.mu.Unlock()
if !c.localIP.IsValid() {
return netip.Addr{}, fmt.Errorf("local IP not set for gateway %s", c.gateway)
}
return c.localIP, nil
}
func protocolNumber(protocol string) (uint8, error) {
switch protocol {
case "udp", "UDP":
return ProtoUDP, nil
case "tcp", "TCP":
return ProtoTCP, nil
default:
return 0, fmt.Errorf("unsupported protocol: %s", protocol)
}
}
// Error represents a PCP error response.
type Error struct {
Code uint8
Message string
}
func (e *Error) Error() string {
return fmt.Sprintf("PCP error: %s (%d)", e.Message, e.Code)
}

View File

@@ -0,0 +1,187 @@
package pcp
import (
"context"
"net"
"net/netip"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestAddrConversion(t *testing.T) {
tests := []struct {
name string
addr netip.Addr
}{
{"IPv4", netip.MustParseAddr("192.168.1.100")},
{"IPv4 loopback", netip.MustParseAddr("127.0.0.1")},
{"IPv6", netip.MustParseAddr("2001:db8::1")},
{"IPv6 loopback", netip.MustParseAddr("::1")},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
b16 := addrTo16(tt.addr)
recovered := addrFrom16(b16)
assert.Equal(t, tt.addr, recovered, "address should round-trip")
})
}
}
func TestBuildAnnounceRequest(t *testing.T) {
clientIP := netip.MustParseAddr("192.168.1.100")
req := buildAnnounceRequest(clientIP)
require.Len(t, req, headerSize)
assert.Equal(t, byte(Version), req[0], "version")
assert.Equal(t, byte(OpAnnounce), req[1], "opcode")
// Check client IP is properly encoded as IPv4-mapped IPv6
assert.Equal(t, byte(0xff), req[18], "IPv4-mapped prefix byte 10")
assert.Equal(t, byte(0xff), req[19], "IPv4-mapped prefix byte 11")
assert.Equal(t, byte(192), req[20], "IP octet 1")
assert.Equal(t, byte(168), req[21], "IP octet 2")
assert.Equal(t, byte(1), req[22], "IP octet 3")
assert.Equal(t, byte(100), req[23], "IP octet 4")
}
func TestBuildMapRequest(t *testing.T) {
clientIP := netip.MustParseAddr("192.168.1.100")
nonce := [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}
req := buildMapRequest(clientIP, nonce, ProtoUDP, 51820, 51820, netip.Addr{}, 3600)
require.Len(t, req, mapRequestSize)
assert.Equal(t, byte(Version), req[0], "version")
assert.Equal(t, byte(OpMap), req[1], "opcode")
// Lifetime at bytes 4-7
assert.Equal(t, uint32(3600), (uint32(req[4])<<24)|(uint32(req[5])<<16)|(uint32(req[6])<<8)|uint32(req[7]), "lifetime")
// Nonce at bytes 24-35
assert.Equal(t, nonce[:], req[24:36], "nonce")
// Protocol at byte 36
assert.Equal(t, byte(ProtoUDP), req[36], "protocol")
// Internal port at bytes 40-41
assert.Equal(t, uint16(51820), (uint16(req[40])<<8)|uint16(req[41]), "internal port")
// External port at bytes 42-43
assert.Equal(t, uint16(51820), (uint16(req[42])<<8)|uint16(req[43]), "external port")
}
func TestParseResponse(t *testing.T) {
// Construct a valid ANNOUNCE response
resp := make([]byte, headerSize)
resp[0] = Version
resp[1] = OpAnnounce | OpReply
// Result code = 0 (success)
// Lifetime = 0
// Epoch = 12345
resp[8] = 0
resp[9] = 0
resp[10] = 0x30
resp[11] = 0x39
parsed, err := parseResponse(resp)
require.NoError(t, err)
assert.Equal(t, uint8(Version), parsed.Version)
assert.Equal(t, uint8(OpAnnounce|OpReply), parsed.Opcode)
assert.Equal(t, uint8(ResultSuccess), parsed.ResultCode)
assert.Equal(t, uint32(12345), parsed.Epoch)
}
func TestParseResponseErrors(t *testing.T) {
t.Run("too short", func(t *testing.T) {
_, err := parseResponse([]byte{1, 2, 3})
assert.Error(t, err)
})
t.Run("wrong version", func(t *testing.T) {
resp := make([]byte, headerSize)
resp[0] = 1 // Wrong version
resp[1] = OpReply
_, err := parseResponse(resp)
assert.Error(t, err)
})
t.Run("missing reply bit", func(t *testing.T) {
resp := make([]byte, headerSize)
resp[0] = Version
resp[1] = OpAnnounce // Missing OpReply bit
_, err := parseResponse(resp)
assert.Error(t, err)
})
}
func TestResultCodeString(t *testing.T) {
assert.Equal(t, "SUCCESS", ResultCodeString(ResultSuccess))
assert.Equal(t, "NOT_AUTHORIZED", ResultCodeString(ResultNotAuthorized))
assert.Equal(t, "ADDRESS_MISMATCH", ResultCodeString(ResultAddressMismatch))
assert.Contains(t, ResultCodeString(255), "UNKNOWN")
}
func TestProtocolNumber(t *testing.T) {
proto, err := protocolNumber("udp")
require.NoError(t, err)
assert.Equal(t, uint8(ProtoUDP), proto)
proto, err = protocolNumber("tcp")
require.NoError(t, err)
assert.Equal(t, uint8(ProtoTCP), proto)
proto, err = protocolNumber("UDP")
require.NoError(t, err)
assert.Equal(t, uint8(ProtoUDP), proto)
_, err = protocolNumber("icmp")
assert.Error(t, err)
}
func TestClientCreation(t *testing.T) {
gateway := netip.MustParseAddr("192.168.1.1").AsSlice()
client := NewClient(gateway)
assert.Equal(t, net.IP(gateway), client.Gateway())
assert.Equal(t, defaultTimeout, client.timeout)
clientWithTimeout := NewClientWithTimeout(gateway, 5*time.Second)
assert.Equal(t, 5*time.Second, clientWithTimeout.timeout)
}
func TestNATType(t *testing.T) {
n := NewNAT(netip.MustParseAddr("192.168.1.1").AsSlice(), netip.MustParseAddr("192.168.1.100").AsSlice())
assert.Equal(t, "PCP", n.Type())
}
// Integration test - skipped unless PCP_TEST_GATEWAY env is set
func TestClientIntegration(t *testing.T) {
t.Skip("Integration test - run manually with PCP_TEST_GATEWAY=<gateway-ip>")
gateway := netip.MustParseAddr("10.0.1.1").AsSlice() // Change to your test gateway
localIP := netip.MustParseAddr("10.0.1.100").AsSlice() // Change to your local IP
client := NewClient(gateway)
client.SetLocalIP(localIP)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// Test ANNOUNCE
epoch, err := client.Announce(ctx)
require.NoError(t, err)
t.Logf("Server epoch: %d", epoch)
// Test MAP
resp, err := client.AddPortMapping(ctx, "udp", 51820, 1*time.Hour)
require.NoError(t, err)
t.Logf("Mapping: internal=%d external=%d externalIP=%s",
resp.InternalPort, resp.ExternalPort, resp.ExternalIP)
// Cleanup
err = client.DeletePortMapping(ctx, "udp", 51820)
require.NoError(t, err)
}

View File

@@ -0,0 +1,209 @@
package pcp
import (
"context"
"fmt"
"net"
"net/netip"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/libp2p/go-nat"
"github.com/libp2p/go-netroute"
)
var _ nat.NAT = (*NAT)(nil)
// NAT implements the go-nat NAT interface using PCP.
// Supports dual-stack (IPv4 and IPv6) when available.
// All methods are safe for concurrent use.
//
// TODO: IPv6 pinholes use the local IPv6 address. If the address changes
// (e.g., due to SLAAC rotation or network change), the pinhole becomes stale
// and needs to be recreated with the new address.
type NAT struct {
client *Client
mu sync.RWMutex
// client6 is the IPv6 PCP client, nil if IPv6 is unavailable.
client6 *Client
// localIP6 caches the local IPv6 address used for PCP requests.
localIP6 netip.Addr
}
// NewNAT creates a new NAT instance backed by PCP.
func NewNAT(gateway, localIP net.IP) *NAT {
client := NewClient(gateway)
client.SetLocalIP(localIP)
return &NAT{
client: client,
}
}
// Type returns "PCP" as the NAT type.
func (n *NAT) Type() string {
return "PCP"
}
// GetDeviceAddress returns the gateway IP address.
func (n *NAT) GetDeviceAddress() (net.IP, error) {
return n.client.Gateway(), nil
}
// GetExternalAddress returns the external IP address.
func (n *NAT) GetExternalAddress() (net.IP, error) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
return n.client.GetExternalAddress(ctx)
}
// GetInternalAddress returns the local IP address used to communicate with the gateway.
func (n *NAT) GetInternalAddress() (net.IP, error) {
addr, err := n.client.getLocalIP()
if err != nil {
return nil, err
}
return addr.AsSlice(), nil
}
// AddPortMapping creates a port mapping on both IPv4 and IPv6 (if available).
func (n *NAT) AddPortMapping(ctx context.Context, protocol string, internalPort int, _ string, timeout time.Duration) (int, error) {
resp, err := n.client.AddPortMapping(ctx, protocol, internalPort, timeout)
if err != nil {
return 0, fmt.Errorf("add mapping: %w", err)
}
n.mu.RLock()
client6 := n.client6
localIP6 := n.localIP6
n.mu.RUnlock()
if client6 == nil {
return int(resp.ExternalPort), nil
}
if _, err := client6.AddPortMapping(ctx, protocol, internalPort, timeout); err != nil {
log.Warnf("IPv6 PCP mapping failed (continuing with IPv4): %v", err)
return int(resp.ExternalPort), nil
}
log.Infof("created IPv6 PCP pinhole: %s:%d", localIP6, internalPort)
return int(resp.ExternalPort), nil
}
// DeletePortMapping removes a port mapping from both IPv4 and IPv6.
func (n *NAT) DeletePortMapping(ctx context.Context, protocol string, internalPort int) error {
err := n.client.DeletePortMapping(ctx, protocol, internalPort)
n.mu.RLock()
client6 := n.client6
n.mu.RUnlock()
if client6 != nil {
if err6 := client6.DeletePortMapping(ctx, protocol, internalPort); err6 != nil {
log.Warnf("IPv6 PCP delete mapping failed: %v", err6)
}
}
if err != nil {
return fmt.Errorf("delete mapping: %w", err)
}
return nil
}
// CheckServerHealth sends an ANNOUNCE to verify the server is still responsive.
// Returns the current epoch and whether the server may have restarted (epoch state loss detected).
func (n *NAT) CheckServerHealth(ctx context.Context) (epoch uint32, serverRestarted bool, err error) {
epoch, err = n.client.Announce(ctx)
if err != nil {
return 0, false, fmt.Errorf("announce: %w", err)
}
return epoch, n.client.EpochStateLost(), nil
}
// DiscoverPCP attempts to discover a PCP-capable gateway.
// Returns a NAT interface if PCP is supported, or an error otherwise.
// Discovers both IPv4 and IPv6 gateways when available.
func DiscoverPCP(ctx context.Context) (nat.NAT, error) {
gateway, localIP, err := getDefaultGateway()
if err != nil {
return nil, fmt.Errorf("get default gateway: %w", err)
}
client := NewClient(gateway)
client.SetLocalIP(localIP)
if _, err := client.Announce(ctx); err != nil {
return nil, fmt.Errorf("PCP announce: %w", err)
}
result := &NAT{client: client}
discoverIPv6(ctx, result)
return result, nil
}
func discoverIPv6(ctx context.Context, result *NAT) {
gateway6, localIP6, err := getDefaultGateway6()
if err != nil {
log.Debugf("IPv6 gateway discovery failed: %v", err)
return
}
client6 := NewClient(gateway6)
client6.SetLocalIP(localIP6)
if _, err := client6.Announce(ctx); err != nil {
log.Debugf("PCP IPv6 announce failed: %v", err)
return
}
addr, ok := netip.AddrFromSlice(localIP6)
if !ok {
log.Debugf("invalid IPv6 local IP: %v", localIP6)
return
}
result.mu.Lock()
result.client6 = client6
result.localIP6 = addr
result.mu.Unlock()
log.Debugf("PCP IPv6 gateway discovered: %s (local: %s)", gateway6, localIP6)
}
// getDefaultGateway returns the default IPv4 gateway and local IP using the system routing table.
func getDefaultGateway() (gateway net.IP, localIP net.IP, err error) {
router, err := netroute.New()
if err != nil {
return nil, nil, err
}
_, gateway, localIP, err = router.Route(net.IPv4zero)
if err != nil {
return nil, nil, err
}
if gateway == nil {
return nil, nil, nat.ErrNoNATFound
}
return gateway, localIP, nil
}
// getDefaultGateway6 returns the default IPv6 gateway IP address using the system routing table.
func getDefaultGateway6() (gateway net.IP, localIP net.IP, err error) {
router, err := netroute.New()
if err != nil {
return nil, nil, err
}
_, gateway, localIP, err = router.Route(net.IPv6zero)
if err != nil {
return nil, nil, err
}
if gateway == nil {
return nil, nil, nat.ErrNoNATFound
}
return gateway, localIP, nil
}

View File

@@ -0,0 +1,225 @@
// Package pcp implements the Port Control Protocol (RFC 6887).
//
// # Implemented Features
//
// - ANNOUNCE opcode: Discovers PCP server support
// - MAP opcode: Creates/deletes port mappings (IPv4 NAT) and firewall pinholes (IPv6)
// - Dual-stack: Simultaneous IPv4 and IPv6 support via separate clients
// - Nonce validation: Prevents response spoofing
// - Epoch tracking: Detects server restarts per Section 8.5
// - RFC-compliant retry timing: 3s initial, exponential backoff to 1024s max (Section 8.1.1)
//
// # Not Implemented
//
// - PEER opcode: For outbound peer connections (not needed for inbound NAT traversal)
// - THIRD_PARTY option: For managing mappings on behalf of other devices
// - PREFER_FAILURE option: Requires exact external port or fail (IPv4 NAT only, not needed for IPv6 pinholing)
// - FILTER option: To restrict remote peer addresses
//
// These optional features are omitted because the primary use case is simple
// port forwarding for WireGuard, which only requires MAP with default behavior.
package pcp
import (
"encoding/binary"
"fmt"
"net/netip"
)
const (
// Version is the PCP protocol version (RFC 6887).
Version = 2
// Port is the standard PCP server port.
Port = 5351
// DefaultLifetime is the default requested mapping lifetime in seconds.
DefaultLifetime = 7200 // 2 hours
// Header sizes
headerSize = 24
mapPayloadSize = 36
mapRequestSize = headerSize + mapPayloadSize // 60 bytes
)
// Opcodes
const (
OpAnnounce = 0
OpMap = 1
OpPeer = 2
OpReply = 0x80 // OR'd with opcode in responses
)
// Protocol numbers for MAP requests
const (
ProtoUDP = 17
ProtoTCP = 6
)
// Result codes (RFC 6887 Section 7.4)
const (
ResultSuccess = 0
ResultUnsuppVersion = 1
ResultNotAuthorized = 2
ResultMalformedRequest = 3
ResultUnsuppOpcode = 4
ResultUnsuppOption = 5
ResultMalformedOption = 6
ResultNetworkFailure = 7
ResultNoResources = 8
ResultUnsuppProtocol = 9
ResultUserExQuota = 10
ResultCannotProvideExt = 11
ResultAddressMismatch = 12
ResultExcessiveRemotePeers = 13
)
// ResultCodeString returns a human-readable string for a result code.
func ResultCodeString(code uint8) string {
switch code {
case ResultSuccess:
return "SUCCESS"
case ResultUnsuppVersion:
return "UNSUPP_VERSION"
case ResultNotAuthorized:
return "NOT_AUTHORIZED"
case ResultMalformedRequest:
return "MALFORMED_REQUEST"
case ResultUnsuppOpcode:
return "UNSUPP_OPCODE"
case ResultUnsuppOption:
return "UNSUPP_OPTION"
case ResultMalformedOption:
return "MALFORMED_OPTION"
case ResultNetworkFailure:
return "NETWORK_FAILURE"
case ResultNoResources:
return "NO_RESOURCES"
case ResultUnsuppProtocol:
return "UNSUPP_PROTOCOL"
case ResultUserExQuota:
return "USER_EX_QUOTA"
case ResultCannotProvideExt:
return "CANNOT_PROVIDE_EXTERNAL"
case ResultAddressMismatch:
return "ADDRESS_MISMATCH"
case ResultExcessiveRemotePeers:
return "EXCESSIVE_REMOTE_PEERS"
default:
return fmt.Sprintf("UNKNOWN(%d)", code)
}
}
// Response represents a parsed PCP response header.
type Response struct {
Version uint8
Opcode uint8
ResultCode uint8
Lifetime uint32
Epoch uint32
}
// MapResponse contains the full response to a MAP request.
type MapResponse struct {
Response
Nonce [12]byte
Protocol uint8
InternalPort uint16
ExternalPort uint16
ExternalIP netip.Addr
}
// addrTo16 converts an address to its 16-byte IPv4-mapped IPv6 representation.
func addrTo16(addr netip.Addr) [16]byte {
if addr.Is4() {
return netip.AddrFrom4(addr.As4()).As16()
}
return addr.As16()
}
// addrFrom16 extracts an address from a 16-byte representation, unmapping IPv4.
func addrFrom16(b [16]byte) netip.Addr {
return netip.AddrFrom16(b).Unmap()
}
// buildAnnounceRequest creates a PCP ANNOUNCE request packet.
func buildAnnounceRequest(clientIP netip.Addr) []byte {
req := make([]byte, headerSize)
req[0] = Version
req[1] = OpAnnounce
mapped := addrTo16(clientIP)
copy(req[8:24], mapped[:])
return req
}
// buildMapRequest creates a PCP MAP request packet.
func buildMapRequest(clientIP netip.Addr, nonce [12]byte, protocol uint8, internalPort, suggestedExtPort uint16, suggestedExtIP netip.Addr, lifetime uint32) []byte {
req := make([]byte, mapRequestSize)
// Header
req[0] = Version
req[1] = OpMap
binary.BigEndian.PutUint32(req[4:8], lifetime)
mapped := addrTo16(clientIP)
copy(req[8:24], mapped[:])
// MAP payload
copy(req[24:36], nonce[:])
req[36] = protocol
binary.BigEndian.PutUint16(req[40:42], internalPort)
binary.BigEndian.PutUint16(req[42:44], suggestedExtPort)
if suggestedExtIP.IsValid() {
extMapped := addrTo16(suggestedExtIP)
copy(req[44:60], extMapped[:])
}
return req
}
// parseResponse parses the common PCP response header.
func parseResponse(data []byte) (*Response, error) {
if len(data) < headerSize {
return nil, fmt.Errorf("response too short: %d bytes", len(data))
}
resp := &Response{
Version: data[0],
Opcode: data[1],
ResultCode: data[3], // Byte 2 is reserved, byte 3 is result code (RFC 6887 §7.2)
Lifetime: binary.BigEndian.Uint32(data[4:8]),
Epoch: binary.BigEndian.Uint32(data[8:12]),
}
if resp.Version != Version {
return nil, fmt.Errorf("unsupported PCP version: %d", resp.Version)
}
if resp.Opcode&OpReply == 0 {
return nil, fmt.Errorf("response missing reply bit: opcode=0x%02x", resp.Opcode)
}
return resp, nil
}
// parseMapResponse parses a complete MAP response.
func parseMapResponse(data []byte) (*MapResponse, error) {
if len(data) < mapRequestSize {
return nil, fmt.Errorf("MAP response too short: %d bytes", len(data))
}
resp, err := parseResponse(data)
if err != nil {
return nil, fmt.Errorf("parse header: %w", err)
}
mapResp := &MapResponse{
Response: *resp,
Protocol: data[36],
InternalPort: binary.BigEndian.Uint16(data[40:42]),
ExternalPort: binary.BigEndian.Uint16(data[42:44]),
ExternalIP: addrFrom16([16]byte(data[44:60])),
}
copy(mapResp.Nonce[:], data[24:36])
return mapResp, nil
}

View File

@@ -0,0 +1,63 @@
//go:build !js
package portforward
import (
"context"
"fmt"
"github.com/libp2p/go-nat"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/portforward/pcp"
)
// discoverGateway is the function used for NAT gateway discovery.
// It can be replaced in tests to avoid real network operations.
// Tries PCP first, then falls back to NAT-PMP/UPnP.
var discoverGateway = defaultDiscoverGateway
func defaultDiscoverGateway(ctx context.Context) (nat.NAT, error) {
pcpGateway, err := pcp.DiscoverPCP(ctx)
if err == nil {
return pcpGateway, nil
}
log.Debugf("PCP discovery failed: %v, trying NAT-PMP/UPnP", err)
return nat.DiscoverGateway(ctx)
}
// State is persisted only for crash recovery cleanup
type State struct {
InternalPort uint16 `json:"internal_port,omitempty"`
Protocol string `json:"protocol,omitempty"`
}
func (s *State) Name() string {
return "port_forward_state"
}
// Cleanup implements statemanager.CleanableState for crash recovery
func (s *State) Cleanup() error {
if s.InternalPort == 0 {
return nil
}
log.Infof("cleaning up stale port mapping for port %d", s.InternalPort)
ctx, cancel := context.WithTimeout(context.Background(), discoveryTimeout)
defer cancel()
gateway, err := discoverGateway(ctx)
if err != nil {
// Discovery failure is not an error - gateway may not exist
log.Debugf("cleanup: no gateway found: %v", err)
return nil
}
if err := gateway.DeletePortMapping(ctx, s.Protocol, int(s.InternalPort)); err != nil {
return fmt.Errorf("delete port mapping: %w", err)
}
return nil
}

View File

@@ -64,11 +64,13 @@ type ConfigInput struct {
StateFilePath string
PreSharedKey *string
ServerSSHAllowed *bool
ServerVNCAllowed *bool
EnableSSHRoot *bool
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
EnableSSHRemotePortForwarding *bool
DisableSSHAuth *bool
DisableVNCAuth *bool
SSHJWTCacheTTL *int
NATExternalIPs []string
CustomDNSAddress []byte
@@ -114,11 +116,13 @@ type Config struct {
RosenpassEnabled bool
RosenpassPermissive bool
ServerSSHAllowed *bool
ServerVNCAllowed *bool
EnableSSHRoot *bool
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
EnableSSHRemotePortForwarding *bool
DisableSSHAuth *bool
DisableVNCAuth *bool
SSHJWTCacheTTL *int
DisableClientRoutes bool
@@ -415,6 +419,21 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
if input.ServerVNCAllowed != nil {
if config.ServerVNCAllowed == nil || *input.ServerVNCAllowed != *config.ServerVNCAllowed {
if *input.ServerVNCAllowed {
log.Infof("enabling VNC server")
} else {
log.Infof("disabling VNC server")
}
config.ServerVNCAllowed = input.ServerVNCAllowed
updated = true
}
} else if config.ServerVNCAllowed == nil {
config.ServerVNCAllowed = util.True()
updated = true
}
if input.EnableSSHRoot != nil && input.EnableSSHRoot != config.EnableSSHRoot {
if *input.EnableSSHRoot {
log.Infof("enabling SSH root login")
@@ -465,6 +484,16 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
if input.DisableVNCAuth != nil && input.DisableVNCAuth != config.DisableVNCAuth {
if *input.DisableVNCAuth {
log.Infof("disabling VNC authentication")
} else {
log.Infof("enabling VNC authentication")
}
config.DisableVNCAuth = input.DisableVNCAuth
updated = true
}
if input.SSHJWTCacheTTL != nil && input.SSHJWTCacheTTL != config.SSHJWTCacheTTL {
log.Infof("updating SSH JWT cache TTL to %d seconds", *input.SSHJWTCacheTTL)
config.SSHJWTCacheTTL = input.SSHJWTCacheTTL

View File

@@ -168,6 +168,7 @@ func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) {
NetworkType: route.IPv4Network,
}
cr = append(cr, fakeIPRoute)
m.notifier.SetFakeIPRoute(fakeIPRoute)
}
m.notifier.SetInitialClientRoutes(cr, routesForComparison)

View File

@@ -16,6 +16,7 @@ import (
type Notifier struct {
initialRoutes []*route.Route
currentRoutes []*route.Route
fakeIPRoute *route.Route
listener listener.NetworkChangeListener
listenerMux sync.Mutex
@@ -31,13 +32,17 @@ func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
n.listener = listener
}
// SetInitialClientRoutes stores the full initial route set (including fake IP blocks)
// and a separate comparison set (without fake IP blocks) for diff detection.
// SetInitialClientRoutes stores the initial route sets for TUN configuration.
func (n *Notifier) SetInitialClientRoutes(initialRoutes []*route.Route, routesForComparison []*route.Route) {
n.initialRoutes = filterStatic(initialRoutes)
n.currentRoutes = filterStatic(routesForComparison)
}
// SetFakeIPRoute stores the fake IP route to be included in every TUN rebuild.
func (n *Notifier) SetFakeIPRoute(r *route.Route) {
n.fakeIPRoute = r
}
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
var newRoutes []*route.Route
for _, routes := range idMap {
@@ -69,7 +74,9 @@ func (n *Notifier) notify() {
}
allRoutes := slices.Clone(n.currentRoutes)
allRoutes = append(allRoutes, n.extraInitialRoutes()...)
if n.fakeIPRoute != nil {
allRoutes = append(allRoutes, n.fakeIPRoute)
}
routeStrings := n.routesToStrings(allRoutes)
sort.Strings(routeStrings)
@@ -78,23 +85,6 @@ func (n *Notifier) notify() {
}(n.listener)
}
// extraInitialRoutes returns initialRoutes whose network prefix is absent
// from currentRoutes (e.g. the fake IP block added at setup time).
func (n *Notifier) extraInitialRoutes() []*route.Route {
currentNets := make(map[netip.Prefix]struct{}, len(n.currentRoutes))
for _, r := range n.currentRoutes {
currentNets[r.Network] = struct{}{}
}
var extra []*route.Route
for _, r := range n.initialRoutes {
if _, ok := currentNets[r.Network]; !ok {
extra = append(extra, r)
}
}
return extra
}
func filterStatic(routes []*route.Route) []*route.Route {
out := make([]*route.Route, 0, len(routes))
for _, r := range routes {

View File

@@ -34,6 +34,10 @@ func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) {
// iOS doesn't care about initial routes
}
func (n *Notifier) SetFakeIPRoute(*route.Route) {
// Not used on iOS
}
func (n *Notifier) OnNewRoutes(route.HAMap) {
// Not used on iOS
}

View File

@@ -23,6 +23,10 @@ func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) {
// Not used on non-mobile platforms
}
func (n *Notifier) SetFakeIPRoute(*route.Route) {
// Not used on non-mobile platforms
}
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
// Not used on non-mobile platforms
}

View File

@@ -0,0 +1,10 @@
//go:build (dragonfly || freebsd || netbsd || openbsd) && !darwin
package systemops
// Non-darwin BSDs don't support the IP_BOUND_IF + scoped default model. They
// always fall through to the ref-counter exclusion-route path; these stubs
// exist only so systemops_unix.go compiles.
func (r *SysOps) setupAdvancedRouting() error { return nil }
func (r *SysOps) cleanupAdvancedRouting() error { return nil }
func (r *SysOps) flushPlatformExtras() error { return nil }

View File

@@ -0,0 +1,241 @@
//go:build darwin && !ios
package systemops
import (
"errors"
"fmt"
"net/netip"
"os"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"golang.org/x/net/route"
"golang.org/x/sys/unix"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
nbnet "github.com/netbirdio/netbird/client/net"
)
// scopedRouteBudget bounds retries for the scoped default route. Installing or
// deleting it matters enough that we're willing to spend longer waiting for the
// kernel reply than for per-prefix exclusion routes.
const scopedRouteBudget = 5 * time.Second
// setupAdvancedRouting installs an RTF_IFSCOPE default route per address family
// pinned to the current physical egress, so IP_BOUND_IF scoped lookups can
// resolve gateway'd destinations while the VPN's split default owns the
// unscoped table.
//
// Timing note: this runs during routeManager.Init, which happens before the
// VPN interface is created and before any peer routes propagate. The initial
// mgmt / signal / relay TCP dials always fire before this runs, so those
// sockets miss the IP_BOUND_IF binding and rely on the kernel's normal route
// lookup, which at that point correctly picks the physical default. Those
// already-established TCP flows keep their originally-selected interface for
// their lifetime on Darwin because the kernel caches the egress route
// per-socket at connect time; adding the VPN's 0/1 + 128/1 split default
// afterwards does not migrate them since the original en0 default stays in
// the table. Any subsequent reconnect via nbnet.NewDialer picks up the
// populated bound-iface cache and gets IP_BOUND_IF set cleanly.
func (r *SysOps) setupAdvancedRouting() error {
// Drop any previously-cached egress interface before reinstalling. On a
// refresh, a family that no longer resolves would otherwise keep the stale
// binding, causing new sockets to scope to an interface without a matching
// scoped default.
nbnet.ClearBoundInterfaces()
if err := r.flushScopedDefaults(); err != nil {
log.Warnf("flush residual scoped defaults: %v", err)
}
var merr *multierror.Error
installed := 0
for _, unspec := range []netip.Addr{netip.IPv4Unspecified(), netip.IPv6Unspecified()} {
ok, err := r.installScopedDefaultFor(unspec)
if err != nil {
merr = multierror.Append(merr, err)
continue
}
if ok {
installed++
}
}
if installed == 0 && merr != nil {
return nberrors.FormatErrorOrNil(merr)
}
if merr != nil {
log.Warnf("advanced routing setup partially succeeded: %v", nberrors.FormatErrorOrNil(merr))
}
return nil
}
// installScopedDefaultFor resolves the physical default nexthop for the given
// address family, installs a scoped default via it, and caches the iface for
// subsequent IP_BOUND_IF / IPV6_BOUND_IF socket binds.
func (r *SysOps) installScopedDefaultFor(unspec netip.Addr) (bool, error) {
nexthop, err := GetNextHop(unspec)
if err != nil {
if errors.Is(err, vars.ErrRouteNotFound) {
return false, nil
}
return false, fmt.Errorf("get default nexthop for %s: %w", unspec, err)
}
if nexthop.Intf == nil {
return false, fmt.Errorf("unusable default nexthop for %s (no interface)", unspec)
}
if err := r.addScopedDefault(unspec, nexthop); err != nil {
return false, fmt.Errorf("add scoped default on %s: %w", nexthop.Intf.Name, err)
}
af := unix.AF_INET
if unspec.Is6() {
af = unix.AF_INET6
}
nbnet.SetBoundInterface(af, nexthop.Intf)
via := "point-to-point"
if nexthop.IP.IsValid() {
via = nexthop.IP.String()
}
log.Infof("installed scoped default route via %s on %s for %s", via, nexthop.Intf.Name, afOf(unspec))
return true, nil
}
func (r *SysOps) cleanupAdvancedRouting() error {
nbnet.ClearBoundInterfaces()
return r.flushScopedDefaults()
}
// flushPlatformExtras runs darwin-specific residual cleanup hooked into the
// generic FlushMarkedRoutes path, so a crashed daemon's scoped defaults get
// removed on the next boot regardless of whether a profile is brought up.
func (r *SysOps) flushPlatformExtras() error {
return r.flushScopedDefaults()
}
// flushScopedDefaults removes any scoped default routes tagged with routeProtoFlag.
// Safe to call at startup to clear residual entries from a prior session.
func (r *SysOps) flushScopedDefaults() error {
rib, err := retryFetchRIB()
if err != nil {
return fmt.Errorf("fetch routing table: %w", err)
}
msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
if err != nil {
return fmt.Errorf("parse routing table: %w", err)
}
var merr *multierror.Error
removed := 0
for _, msg := range msgs {
rtMsg, ok := msg.(*route.RouteMessage)
if !ok {
continue
}
if rtMsg.Flags&routeProtoFlag == 0 {
continue
}
if rtMsg.Flags&unix.RTF_IFSCOPE == 0 {
continue
}
info, err := MsgToRoute(rtMsg)
if err != nil {
log.Debugf("skip scoped flush: %v", err)
continue
}
if !info.Dst.IsValid() || info.Dst.Bits() != 0 {
continue
}
if err := r.deleteScopedRoute(rtMsg); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete scoped default %s on index %d: %w",
info.Dst, rtMsg.Index, err))
continue
}
removed++
log.Debugf("flushed residual scoped default %s on index %d", info.Dst, rtMsg.Index)
}
if removed > 0 {
log.Infof("flushed %d residual scoped default route(s)", removed)
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *SysOps) addScopedDefault(unspec netip.Addr, nexthop Nexthop) error {
return r.scopedRouteSocket(unix.RTM_ADD, unspec, nexthop)
}
func (r *SysOps) deleteScopedRoute(rtMsg *route.RouteMessage) error {
// Preserve identifying flags from the stored route (including RTF_GATEWAY
// only if present); kernel-set bits like RTF_DONE don't belong on RTM_DELETE.
keep := unix.RTF_UP | unix.RTF_STATIC | unix.RTF_GATEWAY | unix.RTF_IFSCOPE | routeProtoFlag
del := &route.RouteMessage{
Type: unix.RTM_DELETE,
Flags: rtMsg.Flags & keep,
Version: unix.RTM_VERSION,
Seq: r.getSeq(),
Index: rtMsg.Index,
Addrs: rtMsg.Addrs,
}
return r.writeRouteMessage(del, scopedRouteBudget)
}
func (r *SysOps) scopedRouteSocket(action int, unspec netip.Addr, nexthop Nexthop) error {
flags := unix.RTF_UP | unix.RTF_STATIC | unix.RTF_IFSCOPE | routeProtoFlag
msg := &route.RouteMessage{
Type: action,
Flags: flags,
Version: unix.RTM_VERSION,
ID: uintptr(os.Getpid()),
Seq: r.getSeq(),
Index: nexthop.Intf.Index,
}
const numAddrs = unix.RTAX_NETMASK + 1
addrs := make([]route.Addr, numAddrs)
dst, err := addrToRouteAddr(unspec)
if err != nil {
return fmt.Errorf("build destination: %w", err)
}
mask, err := prefixToRouteNetmask(netip.PrefixFrom(unspec, 0))
if err != nil {
return fmt.Errorf("build netmask: %w", err)
}
addrs[unix.RTAX_DST] = dst
addrs[unix.RTAX_NETMASK] = mask
if nexthop.IP.IsValid() {
msg.Flags |= unix.RTF_GATEWAY
gw, err := addrToRouteAddr(nexthop.IP.Unmap())
if err != nil {
return fmt.Errorf("build gateway: %w", err)
}
addrs[unix.RTAX_GATEWAY] = gw
} else {
addrs[unix.RTAX_GATEWAY] = &route.LinkAddr{
Index: nexthop.Intf.Index,
Name: nexthop.Intf.Name,
}
}
msg.Addrs = addrs
return r.writeRouteMessage(msg, scopedRouteBudget)
}
func afOf(a netip.Addr) string {
if a.Is4() {
return "IPv4"
}
return "IPv6"
}

View File

@@ -21,6 +21,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/client/net/hooks"
)
@@ -31,8 +32,6 @@ var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
var ErrRoutingIsSeparate = errors.New("routing is separate")
func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) error {
stateManager.RegisterState(&ShutdownState{})
@@ -397,12 +396,16 @@ func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
}
// IsAddrRouted checks if the candidate address would route to the vpn, in which case it returns true and the matched prefix.
// When advanced routing is active the WG socket is bound to the physical interface (fwmark on linux,
// IP_UNICAST_IF on windows, IP_BOUND_IF on darwin) and bypasses the main routing table, so the check is skipped.
func IsAddrRouted(addr netip.Addr, vpnRoutes []netip.Prefix) (bool, netip.Prefix) {
localRoutes, err := hasSeparateRouting()
if nbnet.AdvancedRouting() {
return false, netip.Prefix{}
}
localRoutes, err := GetRoutesFromTable()
if err != nil {
if !errors.Is(err, ErrRoutingIsSeparate) {
log.Errorf("Failed to get routes: %v", err)
}
log.Errorf("Failed to get routes: %v", err)
return false, netip.Prefix{}
}

View File

@@ -22,10 +22,6 @@ func GetRoutesFromTable() ([]netip.Prefix, error) {
return []netip.Prefix{}, nil
}
func hasSeparateRouting() ([]netip.Prefix, error) {
return []netip.Prefix{}, nil
}
// GetDetailedRoutesFromTable returns empty routes for WASM.
func GetDetailedRoutesFromTable() ([]DetailedRoute, error) {
return []DetailedRoute{}, nil

View File

@@ -894,13 +894,6 @@ func getAddressFamily(prefix netip.Prefix) int {
return netlink.FAMILY_V6
}
func hasSeparateRouting() ([]netip.Prefix, error) {
if !nbnet.AdvancedRouting() {
return GetRoutesFromTable()
}
return nil, ErrRoutingIsSeparate
}
func isOpErr(err error) bool {
// EAFTNOSUPPORT when ipv6 is disabled via sysctl, EOPNOTSUPP when disabled in boot options or otherwise not supported
if errors.Is(err, syscall.EAFNOSUPPORT) || errors.Is(err, syscall.EOPNOTSUPP) {

View File

@@ -48,10 +48,6 @@ func EnableIPForwarding() error {
return nil
}
func hasSeparateRouting() ([]netip.Prefix, error) {
return GetRoutesFromTable()
}
// GetIPRules returns IP rules for debugging (not supported on non-Linux platforms)
func GetIPRules() ([]IPRule, error) {
log.Infof("IP rules collection is not supported on %s", runtime.GOOS)

View File

@@ -25,6 +25,9 @@ import (
const (
envRouteProtoFlag = "NB_ROUTE_PROTO_FLAG"
// routeBudget bounds retries for per-prefix exclusion route programming.
routeBudget = 1 * time.Second
)
var routeProtoFlag int
@@ -41,26 +44,42 @@ func init() {
}
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
if advancedRouting {
return r.setupAdvancedRouting()
}
log.Infof("Using legacy routing setup with ref counters")
return r.setupRefCounter(initAddresses, stateManager)
}
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error {
if advancedRouting {
return r.cleanupAdvancedRouting()
}
return r.cleanupRefCounter(stateManager)
}
// FlushMarkedRoutes removes single IP exclusion routes marked with the configured RTF_PROTO flag.
// On darwin it also flushes residual RTF_IFSCOPE scoped default routes so a
// crashed prior session can't leave crud in the table.
func (r *SysOps) FlushMarkedRoutes() error {
var merr *multierror.Error
if err := r.flushPlatformExtras(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("flush platform extras: %w", err))
}
rib, err := retryFetchRIB()
if err != nil {
return fmt.Errorf("fetch routing table: %w", err)
return nberrors.FormatErrorOrNil(multierror.Append(merr, fmt.Errorf("fetch routing table: %w", err)))
}
msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
if err != nil {
return fmt.Errorf("parse routing table: %w", err)
return nberrors.FormatErrorOrNil(multierror.Append(merr, fmt.Errorf("parse routing table: %w", err)))
}
var merr *multierror.Error
flushedCount := 0
for _, msg := range msgs {
@@ -117,12 +136,12 @@ func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) e
return fmt.Errorf("invalid prefix: %s", prefix)
}
expBackOff := backoff.NewExponentialBackOff()
expBackOff.InitialInterval = 50 * time.Millisecond
expBackOff.MaxInterval = 500 * time.Millisecond
expBackOff.MaxElapsedTime = 1 * time.Second
msg, err := r.buildRouteMessage(action, prefix, nexthop)
if err != nil {
return fmt.Errorf("build route message: %w", err)
}
if err := backoff.Retry(r.routeOp(action, prefix, nexthop), expBackOff); err != nil {
if err := r.writeRouteMessage(msg, routeBudget); err != nil {
a := "add"
if action == unix.RTM_DELETE {
a = "remove"
@@ -132,50 +151,91 @@ func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) e
return nil
}
func (r *SysOps) routeOp(action int, prefix netip.Prefix, nexthop Nexthop) func() error {
operation := func() error {
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
if err != nil {
return fmt.Errorf("open routing socket: %w", err)
// writeRouteMessage sends a route message over AF_ROUTE and waits for the
// kernel's matching reply, retrying transient failures until budget elapses.
// Callers do not need to manage sockets or seq numbers themselves.
func (r *SysOps) writeRouteMessage(msg *route.RouteMessage, budget time.Duration) error {
expBackOff := backoff.NewExponentialBackOff()
expBackOff.InitialInterval = 50 * time.Millisecond
expBackOff.MaxInterval = 500 * time.Millisecond
expBackOff.MaxElapsedTime = budget
return backoff.Retry(func() error { return routeMessageRoundtrip(msg) }, expBackOff)
}
func routeMessageRoundtrip(msg *route.RouteMessage) error {
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
if err != nil {
return fmt.Errorf("open routing socket: %w", err)
}
defer func() {
if err := unix.Close(fd); err != nil && !errors.Is(err, unix.EBADF) {
log.Warnf("close routing socket: %v", err)
}
defer func() {
if err := unix.Close(fd); err != nil && !errors.Is(err, unix.EBADF) {
log.Warnf("failed to close routing socket: %v", err)
}()
tv := unix.Timeval{Sec: 1}
if err := unix.SetsockoptTimeval(fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &tv); err != nil {
return backoff.Permanent(fmt.Errorf("set recv timeout: %w", err))
}
// AF_ROUTE is a broadcast channel: every route socket on the host sees
// every RTM_* event. With concurrent route programming the default
// per-socket queue overflows and our own reply gets dropped.
if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_RCVBUF, 1<<20); err != nil {
log.Debugf("set SO_RCVBUF on route socket: %v", err)
}
bytes, err := msg.Marshal()
if err != nil {
return backoff.Permanent(fmt.Errorf("marshal: %w", err))
}
if _, err = unix.Write(fd, bytes); err != nil {
if errors.Is(err, unix.ENOBUFS) || errors.Is(err, unix.EAGAIN) {
return fmt.Errorf("write: %w", err)
}
return backoff.Permanent(fmt.Errorf("write: %w", err))
}
return readRouteResponse(fd, msg.Type, msg.Seq)
}
// readRouteResponse reads from the AF_ROUTE socket until it sees a reply
// matching our write (same type, seq, and pid). AF_ROUTE SOCK_RAW is a
// broadcast channel: interface up/down, third-party route changes and neighbor
// discovery events can all land between our write and read, so we must filter.
func readRouteResponse(fd, wantType, wantSeq int) error {
pid := int32(os.Getpid())
resp := make([]byte, 2048)
deadline := time.Now().Add(time.Second)
for {
if time.Now().After(deadline) {
// Transient: under concurrent pressure the kernel can drop our reply
// from the socket buffer. Let backoff.Retry re-send with a fresh seq.
return fmt.Errorf("read: timeout waiting for route reply type=%d seq=%d", wantType, wantSeq)
}
n, err := unix.Read(fd, resp)
if err != nil {
if errors.Is(err, unix.EAGAIN) || errors.Is(err, unix.EWOULDBLOCK) {
// SO_RCVTIMEO fired while waiting; loop to re-check the absolute deadline.
continue
}
}()
msg, err := r.buildRouteMessage(action, prefix, nexthop)
if err != nil {
return backoff.Permanent(fmt.Errorf("build route message: %w", err))
return backoff.Permanent(fmt.Errorf("read: %w", err))
}
msgBytes, err := msg.Marshal()
if err != nil {
return backoff.Permanent(fmt.Errorf("marshal route message: %w", err))
if n < int(unsafe.Sizeof(unix.RtMsghdr{})) {
continue
}
if _, err = unix.Write(fd, msgBytes); err != nil {
if errors.Is(err, unix.ENOBUFS) || errors.Is(err, unix.EAGAIN) {
return fmt.Errorf("write: %w", err)
}
return backoff.Permanent(fmt.Errorf("write: %w", err))
hdr := (*unix.RtMsghdr)(unsafe.Pointer(&resp[0]))
// Darwin reflects the sender's pid on replies; matching (Type, Seq, Pid)
// uniquely identifies our own reply among broadcast traffic.
if int(hdr.Type) != wantType || int(hdr.Seq) != wantSeq || hdr.Pid != pid {
continue
}
respBuf := make([]byte, 2048)
n, err := unix.Read(fd, respBuf)
if err != nil {
return backoff.Permanent(fmt.Errorf("read route response: %w", err))
if hdr.Errno != 0 {
return backoff.Permanent(fmt.Errorf("kernel: %w", syscall.Errno(hdr.Errno)))
}
if n > 0 {
if err := r.parseRouteResponse(respBuf[:n]); err != nil {
return backoff.Permanent(err)
}
}
return nil
}
return operation
}
func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) {
@@ -183,6 +243,7 @@ func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Next
Type: action,
Flags: unix.RTF_UP | routeProtoFlag,
Version: unix.RTM_VERSION,
ID: uintptr(os.Getpid()),
Seq: r.getSeq(),
}
@@ -221,19 +282,6 @@ func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Next
return msg, nil
}
func (r *SysOps) parseRouteResponse(buf []byte) error {
if len(buf) < int(unsafe.Sizeof(unix.RtMsghdr{})) {
return nil
}
rtMsg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
if rtMsg.Errno != 0 {
return fmt.Errorf("parse: %d", rtMsg.Errno)
}
return nil
}
// addrToRouteAddr converts a netip.Addr to the appropriate route.Addr (*route.Inet4Addr or *route.Inet6Addr).
func addrToRouteAddr(addr netip.Addr) (route.Addr, error) {
if addr.Is4() {

View File

@@ -74,6 +74,14 @@ func New(filePath string) *Manager {
}
}
// FilePath returns the path of the underlying state file.
func (m *Manager) FilePath() string {
if m == nil {
return ""
}
return m.filePath
}
// Start starts the state manager periodic save routine
func (m *Manager) Start() {
if m == nil {

View File

@@ -0,0 +1,5 @@
package net
func (d *Dialer) init() {
d.Dialer.Control = applyBoundIfToSocket
}

View File

@@ -1,4 +1,4 @@
//go:build !linux && !windows
//go:build !linux && !windows && !darwin
package net

View File

@@ -1,24 +0,0 @@
//go:build android
package net
// Init initializes the network environment for Android
func Init() {
// No initialization needed on Android
}
// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes.
// Always returns true on Android since we cannot handle routes dynamically.
func AdvancedRouting() bool {
return true
}
// SetVPNInterfaceName is a no-op on Android
func SetVPNInterfaceName(name string) {
// No-op on Android - not needed for Android VPN service
}
// GetVPNInterfaceName returns empty string on Android
func GetVPNInterfaceName() string {
return ""
}

View File

@@ -1,4 +1,4 @@
//go:build windows
//go:build (darwin && !ios) || windows
package net
@@ -24,17 +24,22 @@ func Init() {
}
func checkAdvancedRoutingSupport() bool {
var err error
var legacyRouting bool
legacyRouting := false
if val := os.Getenv(envUseLegacyRouting); val != "" {
legacyRouting, err = strconv.ParseBool(val)
parsed, err := strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", envUseLegacyRouting, err)
log.Warnf("ignoring unparsable %s=%q: %v", envUseLegacyRouting, val, err)
} else {
legacyRouting = parsed
}
}
if legacyRouting || netstack.IsEnabled() {
log.Info("advanced routing has been requested to be disabled")
if legacyRouting {
log.Infof("advanced routing disabled: legacy routing requested via %s", envUseLegacyRouting)
return false
}
if netstack.IsEnabled() {
log.Info("advanced routing disabled: netstack mode is enabled")
return false
}

View File

@@ -1,4 +1,4 @@
//go:build !linux && !windows && !android
//go:build !linux && !windows && !darwin
package net

25
client/net/env_mobile.go Normal file
View File

@@ -0,0 +1,25 @@
//go:build ios || android
package net
// Init initializes the network environment for mobile platforms.
func Init() {
// no-op on mobile: routing scope is owned by the VPN extension.
}
// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes.
// Always returns true on mobile since routes cannot be handled dynamically and the VPN extension
// owns the routing scope.
func AdvancedRouting() bool {
return true
}
// SetVPNInterfaceName is a no-op on mobile.
func SetVPNInterfaceName(string) {
// no-op on mobile: the VPN extension manages the interface.
}
// GetVPNInterfaceName returns an empty string on mobile.
func GetVPNInterfaceName() string {
return ""
}

View File

@@ -0,0 +1,5 @@
package net
func (l *ListenerConfig) init() {
l.ListenConfig.Control = applyBoundIfToSocket
}

View File

@@ -1,4 +1,4 @@
//go:build !linux && !windows
//go:build !linux && !windows && !darwin
package net

160
client/net/net_darwin.go Normal file
View File

@@ -0,0 +1,160 @@
package net
import (
"fmt"
"net"
"net/netip"
"strconv"
"strings"
"sync"
"syscall"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
// On darwin IPV6_BOUND_IF also scopes v4-mapped egress from dual-stack
// (IPV6_V6ONLY=0) AF_INET6 sockets, so a single setsockopt on "udp6"/"tcp6"
// covers both families. Setting IP_BOUND_IF on an AF_INET6 socket returns
// EINVAL regardless of V6ONLY because the IPPROTO_IP ctloutput path is
// dispatched by socket domain (AF_INET only) not by inp_vflag.
// boundIface holds the physical interface chosen at routing setup time. Sockets
// created via nbnet.NewDialer / nbnet.NewListener bind to it via IP_BOUND_IF
// (IPv4) or IPV6_BOUND_IF (IPv6 / dual-stack) so their scoped route lookup
// hits the RTF_IFSCOPE default installed by the routemanager, rather than
// following the VPN's split default.
var (
boundIfaceMu sync.RWMutex
boundIface4 *net.Interface
boundIface6 *net.Interface
)
// SetBoundInterface records the egress interface for an address family. Called
// by the routemanager after a scoped default route has been installed.
// af must be unix.AF_INET or unix.AF_INET6; other values are ignored.
// nil iface is rejected — use ClearBoundInterfaces to clear all slots.
func SetBoundInterface(af int, iface *net.Interface) {
if iface == nil {
log.Warnf("SetBoundInterface: nil iface for AF %d, ignored", af)
return
}
boundIfaceMu.Lock()
defer boundIfaceMu.Unlock()
switch af {
case unix.AF_INET:
boundIface4 = iface
case unix.AF_INET6:
boundIface6 = iface
default:
log.Warnf("SetBoundInterface: unsupported address family %d", af)
}
}
// ClearBoundInterfaces resets the cached egress interfaces. Called by the
// routemanager during cleanup.
func ClearBoundInterfaces() {
boundIfaceMu.Lock()
defer boundIfaceMu.Unlock()
boundIface4 = nil
boundIface6 = nil
}
// boundInterfaceFor returns the cached egress interface for a socket's address
// family, falling back to the other family if the preferred slot is empty.
// The kernel stores both IP_BOUND_IF and IPV6_BOUND_IF in inp_boundifp, so
// either setsockopt scopes the socket; preferring same-family still matters
// when v4 and v6 defaults egress different NICs.
func boundInterfaceFor(network, address string) *net.Interface {
if iface := zoneInterface(address); iface != nil {
return iface
}
boundIfaceMu.RLock()
defer boundIfaceMu.RUnlock()
primary, secondary := boundIface4, boundIface6
if isV6Network(network) {
primary, secondary = boundIface6, boundIface4
}
if primary != nil {
return primary
}
return secondary
}
func isV6Network(network string) bool {
return strings.HasSuffix(network, "6")
}
// zoneInterface extracts an explicit interface from an IPv6 link-local zone (e.g. fe80::1%en0).
func zoneInterface(address string) *net.Interface {
if address == "" {
return nil
}
addr, err := netip.ParseAddrPort(address)
if err != nil {
a, err := netip.ParseAddr(address)
if err != nil {
return nil
}
addr = netip.AddrPortFrom(a, 0)
}
zone := addr.Addr().Zone()
if zone == "" {
return nil
}
if iface, err := net.InterfaceByName(zone); err == nil {
return iface
}
if idx, err := strconv.Atoi(zone); err == nil {
if iface, err := net.InterfaceByIndex(idx); err == nil {
return iface
}
}
return nil
}
func setIPv4BoundIf(fd uintptr, iface *net.Interface) error {
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_BOUND_IF, iface.Index); err != nil {
return fmt.Errorf("set IP_BOUND_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index)
}
return nil
}
func setIPv6BoundIf(fd uintptr, iface *net.Interface) error {
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, iface.Index); err != nil {
return fmt.Errorf("set IPV6_BOUND_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index)
}
return nil
}
// applyBoundIfToSocket binds the socket to the cached physical egress interface
// so scoped route lookup avoids the VPN utun and egresses the underlay directly.
func applyBoundIfToSocket(network, address string, c syscall.RawConn) error {
if !AdvancedRouting() {
return nil
}
iface := boundInterfaceFor(network, address)
if iface == nil {
log.Debugf("no bound iface cached for %s to %s, skipping BOUND_IF", network, address)
return nil
}
isV6 := isV6Network(network)
var controlErr error
if err := c.Control(func(fd uintptr) {
if isV6 {
controlErr = setIPv6BoundIf(fd, iface)
} else {
controlErr = setIPv4BoundIf(fd, iface)
}
if controlErr == nil {
log.Debugf("set BOUND_IF=%d on %s for %s to %s", iface.Index, iface.Name, network, address)
}
}); err != nil {
return fmt.Errorf("control: %w", err)
}
return controlErr
}

File diff suppressed because it is too large Load Diff

View File

@@ -209,6 +209,9 @@ message LoginRequest {
optional bool enableSSHRemotePortForwarding = 37;
optional bool disableSSHAuth = 38;
optional int32 sshJWTCacheTTL = 39;
optional bool serverVNCAllowed = 41;
optional bool disableVNCAuth = 42;
}
message LoginResponse {
@@ -316,6 +319,10 @@ message GetConfigResponse {
bool disableSSHAuth = 25;
int32 sshJWTCacheTTL = 26;
bool serverVNCAllowed = 28;
bool disableVNCAuth = 29;
}
// PeerState contains the latest state of a peer
@@ -394,6 +401,11 @@ message SSHServerState {
repeated SSHSessionInfo sessions = 2;
}
// VNCServerState contains the latest state of the VNC server
message VNCServerState {
bool enabled = 1;
}
// FullStatus contains the full state held by the Status instance
message FullStatus {
ManagementState managementState = 1;
@@ -408,6 +420,7 @@ message FullStatus {
bool lazyConnectionEnabled = 9;
SSHServerState sshServerState = 10;
VNCServerState vncServerState = 11;
}
// Networks
@@ -677,6 +690,9 @@ message SetConfigRequest {
optional bool enableSSHRemotePortForwarding = 32;
optional bool disableSSHAuth = 33;
optional int32 sshJWTCacheTTL = 34;
optional bool serverVNCAllowed = 36;
optional bool disableVNCAuth = 37;
}
message SetConfigResponse{}
@@ -727,6 +743,7 @@ message GetFeaturesRequest{}
message GetFeaturesResponse{
bool disable_profiles = 1;
bool disable_update_settings = 2;
bool disable_networks = 3;
}
message TriggerUpdateRequest {}

View File

@@ -9,6 +9,8 @@ import (
"strings"
"golang.org/x/exp/maps"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/route"
@@ -27,6 +29,10 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro
s.mutex.Lock()
defer s.mutex.Unlock()
if s.networksDisabled {
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
}
if s.connectClient == nil {
return nil, fmt.Errorf("not connected")
}
@@ -118,6 +124,10 @@ func (s *Server) SelectNetworks(_ context.Context, req *proto.SelectNetworksRequ
s.mutex.Lock()
defer s.mutex.Unlock()
if s.networksDisabled {
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
}
if s.connectClient == nil {
return nil, fmt.Errorf("not connected")
}
@@ -164,6 +174,10 @@ func (s *Server) DeselectNetworks(_ context.Context, req *proto.SelectNetworksRe
s.mutex.Lock()
defer s.mutex.Unlock()
if s.networksDisabled {
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
}
if s.connectClient == nil {
return nil, fmt.Errorf("not connected")
}

View File

@@ -53,6 +53,7 @@ const (
errRestoreResidualState = "failed to restore residual state: %v"
errProfilesDisabled = "profiles are disabled, you cannot use this feature without profiles enabled"
errUpdateSettingsDisabled = "update settings are disabled, you cannot use this feature without update settings enabled"
errNetworksDisabled = "network selection is disabled by the administrator"
)
var ErrServiceNotUp = errors.New("service is not up")
@@ -88,6 +89,7 @@ type Server struct {
profileManager *profilemanager.ServiceManager
profilesDisabled bool
updateSettingsDisabled bool
networksDisabled bool
sleepHandler *sleephandler.SleepHandler
@@ -104,7 +106,7 @@ type oauthAuthFlow struct {
}
// New server instance constructor.
func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool) *Server {
func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool, networksDisabled bool) *Server {
s := &Server{
rootCtx: ctx,
logFile: logFile,
@@ -113,6 +115,7 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
profileManager: profilemanager.NewServiceManager(configFile),
profilesDisabled: profilesDisabled,
updateSettingsDisabled: updateSettingsDisabled,
networksDisabled: networksDisabled,
jwtCache: newJWTCache(),
}
agent := &serverAgent{s}
@@ -366,6 +369,7 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
config.RosenpassPermissive = msg.RosenpassPermissive
config.DisableAutoConnect = msg.DisableAutoConnect
config.ServerSSHAllowed = msg.ServerSSHAllowed
config.ServerVNCAllowed = msg.ServerVNCAllowed
config.NetworkMonitor = msg.NetworkMonitor
config.DisableClientRoutes = msg.DisableClientRoutes
config.DisableServerRoutes = msg.DisableServerRoutes
@@ -382,6 +386,9 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
if msg.DisableSSHAuth != nil {
config.DisableSSHAuth = msg.DisableSSHAuth
}
if msg.DisableVNCAuth != nil {
config.DisableVNCAuth = msg.DisableVNCAuth
}
if msg.SshJWTCacheTTL != nil {
ttl := int(*msg.SshJWTCacheTTL)
config.SSHJWTCacheTTL = &ttl
@@ -1120,6 +1127,7 @@ func (s *Server) Status(
pbFullStatus := fullStatus.ToProto()
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
pbFullStatus.SshServerState = s.getSSHServerState()
pbFullStatus.VncServerState = s.getVNCServerState()
statusResponse.FullStatus = pbFullStatus
}
@@ -1159,6 +1167,26 @@ func (s *Server) getSSHServerState() *proto.SSHServerState {
return sshServerState
}
// getVNCServerState retrieves the current VNC server state.
func (s *Server) getVNCServerState() *proto.VNCServerState {
s.mutex.Lock()
connectClient := s.connectClient
s.mutex.Unlock()
if connectClient == nil {
return nil
}
engine := connectClient.Engine()
if engine == nil {
return nil
}
return &proto.VNCServerState{
Enabled: engine.GetVNCServerStatus(),
}
}
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
func (s *Server) GetPeerSSHHostKey(
ctx context.Context,
@@ -1500,6 +1528,11 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
disableSSHAuth = *cfg.DisableSSHAuth
}
disableVNCAuth := false
if cfg.DisableVNCAuth != nil {
disableVNCAuth = *cfg.DisableVNCAuth
}
sshJWTCacheTTL := int32(0)
if cfg.SSHJWTCacheTTL != nil {
sshJWTCacheTTL = int32(*cfg.SSHJWTCacheTTL)
@@ -1514,6 +1547,7 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
Mtu: int64(cfg.MTU),
DisableAutoConnect: cfg.DisableAutoConnect,
ServerSSHAllowed: *cfg.ServerSSHAllowed,
ServerVNCAllowed: cfg.ServerVNCAllowed != nil && *cfg.ServerVNCAllowed,
RosenpassEnabled: cfg.RosenpassEnabled,
RosenpassPermissive: cfg.RosenpassPermissive,
LazyConnectionEnabled: cfg.LazyConnectionEnabled,
@@ -1529,6 +1563,7 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
EnableSSHLocalPortForwarding: enableSSHLocalPortForwarding,
EnableSSHRemotePortForwarding: enableSSHRemotePortForwarding,
DisableSSHAuth: disableSSHAuth,
DisableVNCAuth: disableVNCAuth,
SshJWTCacheTTL: sshJWTCacheTTL,
}, nil
}
@@ -1628,6 +1663,7 @@ func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest)
features := &proto.GetFeaturesResponse{
DisableProfiles: s.checkProfilesDisabled(),
DisableUpdateSettings: s.checkUpdateSettingsDisabled(),
DisableNetworks: s.networksDisabled,
}
return features, nil

View File

@@ -36,6 +36,7 @@ import (
daemonProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
@@ -103,7 +104,7 @@ func TestConnectWithRetryRuns(t *testing.T) {
t.Fatalf("failed to set active profile state: %v", err)
}
s := New(ctx, "debug", "", false, false)
s := New(ctx, "debug", "", false, false, false)
s.config = config
@@ -164,7 +165,7 @@ func TestServer_Up(t *testing.T) {
t.Fatalf("failed to set active profile state: %v", err)
}
s := New(ctx, "console", "", false, false)
s := New(ctx, "console", "", false, false, false)
err = s.Start()
require.NoError(t, err)
@@ -234,7 +235,7 @@ func TestServer_SubcribeEvents(t *testing.T) {
t.Fatalf("failed to set active profile state: %v", err)
}
s := New(ctx, "console", "", false, false)
s := New(ctx, "console", "", false, false, false)
err = s.Start()
require.NoError(t, err)
@@ -309,7 +310,12 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
jobManager := job.NewJobManager(nil, store, peersManager)
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore)
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
if err != nil {
return nil, "", err
}
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
@@ -320,7 +326,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore)
if err != nil {
return nil, "", err
}

View File

@@ -53,11 +53,13 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
require.NoError(t, err)
ctx := context.Background()
s := New(ctx, "console", "", false, false)
s := New(ctx, "console", "", false, false, false)
rosenpassEnabled := true
rosenpassPermissive := true
serverSSHAllowed := true
serverVNCAllowed := true
disableVNCAuth := true
interfaceName := "utun100"
wireguardPort := int64(51820)
preSharedKey := "test-psk"
@@ -82,6 +84,8 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
RosenpassEnabled: &rosenpassEnabled,
RosenpassPermissive: &rosenpassPermissive,
ServerSSHAllowed: &serverSSHAllowed,
ServerVNCAllowed: &serverVNCAllowed,
DisableVNCAuth: &disableVNCAuth,
InterfaceName: &interfaceName,
WireguardPort: &wireguardPort,
OptionalPreSharedKey: &preSharedKey,
@@ -125,6 +129,10 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
require.Equal(t, rosenpassPermissive, cfg.RosenpassPermissive)
require.NotNil(t, cfg.ServerSSHAllowed)
require.Equal(t, serverSSHAllowed, *cfg.ServerSSHAllowed)
require.NotNil(t, cfg.ServerVNCAllowed)
require.Equal(t, serverVNCAllowed, *cfg.ServerVNCAllowed)
require.NotNil(t, cfg.DisableVNCAuth)
require.Equal(t, disableVNCAuth, *cfg.DisableVNCAuth)
require.Equal(t, interfaceName, cfg.WgIface)
require.Equal(t, int(wireguardPort), cfg.WgPort)
require.Equal(t, preSharedKey, cfg.PreSharedKey)
@@ -176,6 +184,8 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
"RosenpassEnabled": true,
"RosenpassPermissive": true,
"ServerSSHAllowed": true,
"ServerVNCAllowed": true,
"DisableVNCAuth": true,
"InterfaceName": true,
"WireguardPort": true,
"OptionalPreSharedKey": true,
@@ -236,6 +246,8 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) {
"enable-rosenpass": "RosenpassEnabled",
"rosenpass-permissive": "RosenpassPermissive",
"allow-server-ssh": "ServerSSHAllowed",
"allow-server-vnc": "ServerVNCAllowed",
"disable-vnc-auth": "DisableVNCAuth",
"interface-name": "InterfaceName",
"wireguard-port": "WireguardPort",
"preshared-key": "OptionalPreSharedKey",

View File

@@ -12,7 +12,6 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/client/proto"
)
@@ -138,10 +137,8 @@ func restoreResidualState(ctx context.Context, statePath string) error {
}
// clean up any remaining routes independently of the state file
if !nbnet.AdvancedRouting() {
if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err))
}
if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err))
}
return nberrors.FormatErrorOrNil(merr)

View File

@@ -187,24 +187,23 @@ func (m *Manager) buildPeerConfig(allHostPatterns []string) (string, error) {
return "", fmt.Errorf("get NetBird executable path: %w", err)
}
hostLine := strings.Join(deduplicatedPatterns, " ")
config := fmt.Sprintf("Host %s\n", hostLine)
config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p\"\n", execPath)
config += " PreferredAuthentications password,publickey,keyboard-interactive\n"
config += " PasswordAuthentication yes\n"
config += " PubkeyAuthentication yes\n"
config += " BatchMode no\n"
config += fmt.Sprintf(" ProxyCommand %s ssh proxy %%h %%p\n", execPath)
config += " StrictHostKeyChecking no\n"
hostList := strings.Join(deduplicatedPatterns, ",")
config := fmt.Sprintf("Match host \"%s\" exec \"%s ssh detect %%h %%p\"\n", hostList, execPath)
config += " PreferredAuthentications password,publickey,keyboard-interactive\n"
config += " PasswordAuthentication yes\n"
config += " PubkeyAuthentication yes\n"
config += " BatchMode no\n"
config += fmt.Sprintf(" ProxyCommand %s ssh proxy %%h %%p\n", execPath)
config += " StrictHostKeyChecking no\n"
if runtime.GOOS == "windows" {
config += " UserKnownHostsFile NUL\n"
config += " UserKnownHostsFile NUL\n"
} else {
config += " UserKnownHostsFile /dev/null\n"
config += " UserKnownHostsFile /dev/null\n"
}
config += " CheckHostIP no\n"
config += " LogLevel ERROR\n\n"
config += " CheckHostIP no\n"
config += " LogLevel ERROR\n\n"
return config, nil
}

View File

@@ -116,6 +116,37 @@ func TestManager_PeerLimit(t *testing.T) {
assert.True(t, os.IsNotExist(err), "SSH config should not be created with too many peers")
}
func TestManager_MatchHostFormat(t *testing.T) {
tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
require.NoError(t, err)
defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }()
manager := &Manager{
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
sshConfigFile: "99-netbird.conf",
}
peers := []PeerSSHInfo{
{Hostname: "peer1", IP: "100.125.1.1", FQDN: "peer1.nb.internal"},
{Hostname: "peer2", IP: "100.125.1.2", FQDN: "peer2.nb.internal"},
}
err = manager.SetupSSHClientConfig(peers)
require.NoError(t, err)
configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile)
content, err := os.ReadFile(configPath)
require.NoError(t, err)
configStr := string(content)
// Must use "Match host" with comma-separated patterns, not a bare "Host" directive.
// A bare "Host" followed by "Match exec" is incorrect per ssh_config(5): the Host block
// ends at the next Match keyword, making it a no-op and leaving the Match exec unscoped.
assert.NotContains(t, configStr, "\nHost ", "should not use bare Host directive")
assert.Contains(t, configStr, "Match host \"100.125.1.1,peer1.nb.internal,peer1,100.125.1.2,peer2.nb.internal,peer2\"",
"should use Match host with comma-separated patterns")
}
func TestManager_ForcedSSHConfig(t *testing.T) {
// Set force environment variable
t.Setenv(EnvForceSSHConfig, "true")

View File

@@ -200,8 +200,8 @@ func newLsaString(s string) lsaString {
}
}
// generateS4UUserToken creates a Windows token using S4U authentication
// This is the exact approach OpenSSH for Windows uses for public key authentication
// generateS4UUserToken creates a Windows token using S4U authentication.
// This is the same approach OpenSSH for Windows uses for public key authentication.
func generateS4UUserToken(logger *log.Entry, username, domain string) (windows.Handle, error) {
userCpn := buildUserCpn(username, domain)

View File

@@ -507,27 +507,7 @@ func (s *Server) checkTokenAge(token *gojwt.Token, jwtConfig *JWTConfig) error {
maxTokenAge = DefaultJWTMaxTokenAge
}
claims, ok := token.Claims.(gojwt.MapClaims)
if !ok {
userID := extractUserID(token)
return fmt.Errorf("token has invalid claims format (user=%s)", userID)
}
iat, ok := claims["iat"].(float64)
if !ok {
userID := extractUserID(token)
return fmt.Errorf("token missing iat claim (user=%s)", userID)
}
issuedAt := time.Unix(int64(iat), 0)
tokenAge := time.Since(issuedAt)
maxAge := time.Duration(maxTokenAge) * time.Second
if tokenAge > maxAge {
userID := getUserIDFromClaims(claims)
return fmt.Errorf("token expired for user=%s: age=%v, max=%v", userID, tokenAge, maxAge)
}
return nil
return jwt.CheckTokenAge(token, time.Duration(maxTokenAge)*time.Second)
}
func (s *Server) extractAndValidateUser(token *gojwt.Token) (*auth.UserAuth, error) {
@@ -558,27 +538,7 @@ func (s *Server) hasSSHAccess(userAuth *auth.UserAuth) bool {
}
func extractUserID(token *gojwt.Token) string {
if token == nil {
return "unknown"
}
claims, ok := token.Claims.(gojwt.MapClaims)
if !ok {
return "unknown"
}
return getUserIDFromClaims(claims)
}
func getUserIDFromClaims(claims gojwt.MapClaims) string {
if sub, ok := claims["sub"].(string); ok && sub != "" {
return sub
}
if userID, ok := claims["user_id"].(string); ok && userID != "" {
return userID
}
if email, ok := claims["email"].(string); ok && email != "" {
return email
}
return "unknown"
return jwt.UserIDFromToken(token)
}
func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]interface{}, error) {

View File

@@ -130,6 +130,10 @@ type SSHServerStateOutput struct {
Sessions []SSHSessionOutput `json:"sessions" yaml:"sessions"`
}
type VNCServerStateOutput struct {
Enabled bool `json:"enabled" yaml:"enabled"`
}
type OutputOverview struct {
Peers PeersStateOutput `json:"peers" yaml:"peers"`
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
@@ -151,6 +155,7 @@ type OutputOverview struct {
LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"`
ProfileName string `json:"profileName" yaml:"profileName"`
SSHServerState SSHServerStateOutput `json:"sshServer" yaml:"sshServer"`
VNCServerState VNCServerStateOutput `json:"vncServer" yaml:"vncServer"`
}
// ConvertToStatusOutputOverview converts protobuf status to the output overview.
@@ -171,6 +176,9 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, opts ConvertO
relayOverview := mapRelays(pbFullStatus.GetRelays())
sshServerOverview := mapSSHServer(pbFullStatus.GetSshServerState())
vncServerOverview := VNCServerStateOutput{
Enabled: pbFullStatus.GetVncServerState().GetEnabled(),
}
peersOverview := mapPeers(pbFullStatus.GetPeers(), opts.StatusFilter, opts.PrefixNamesFilter, opts.PrefixNamesFilterMap, opts.IPsFilter, opts.ConnectionTypeFilter)
overview := OutputOverview{
@@ -194,6 +202,7 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, opts ConvertO
LazyConnectionEnabled: pbFullStatus.GetLazyConnectionEnabled(),
ProfileName: opts.ProfileName,
SSHServerState: sshServerOverview,
VNCServerState: vncServerOverview,
}
if opts.Anonymize {
@@ -524,6 +533,11 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
}
}
vncServerStatus := "Disabled"
if o.VNCServerState.Enabled {
vncServerStatus = "Enabled"
}
peersCountString := fmt.Sprintf("%d/%d Connected", o.Peers.Connected, o.Peers.Total)
var forwardingRulesString string
@@ -553,6 +567,7 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
"Quantum resistance: %s\n"+
"Lazy connection: %s\n"+
"SSH Server: %s\n"+
"VNC Server: %s\n"+
"Networks: %s\n"+
"%s"+
"Peers count: %s\n",
@@ -570,6 +585,7 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
rosenpassEnabledStatus,
lazyConnectionEnabledStatus,
sshServerStatus,
vncServerStatus,
networks,
forwardingRulesString,
peersCountString,

View File

@@ -398,6 +398,9 @@ func TestParsingToJSON(t *testing.T) {
"sshServer":{
"enabled":false,
"sessions":[]
},
"vncServer":{
"enabled":false
}
}`
// @formatter:on
@@ -505,6 +508,8 @@ profileName: ""
sshServer:
enabled: false
sessions: []
vncServer:
enabled: false
`
assert.Equal(t, expectedYAML, yaml)
@@ -572,6 +577,7 @@ Interface type: Kernel
Quantum resistance: false
Lazy connection: false
SSH Server: Disabled
VNC Server: Disabled
Networks: 10.10.0.0/24
Peers count: 2/2 Connected
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
@@ -596,6 +602,7 @@ Interface type: Kernel
Quantum resistance: false
Lazy connection: false
SSH Server: Disabled
VNC Server: Disabled
Networks: 10.10.0.0/24
Peers count: 2/2 Connected
`

View File

@@ -2,7 +2,6 @@ package system
import (
"context"
"net"
"net/netip"
"strings"
@@ -63,6 +62,7 @@ type Info struct {
RosenpassEnabled bool
RosenpassPermissive bool
ServerSSHAllowed bool
ServerVNCAllowed bool
DisableClientRoutes bool
DisableServerRoutes bool
@@ -78,21 +78,27 @@ type Info struct {
EnableSSHLocalPortForwarding bool
EnableSSHRemotePortForwarding bool
DisableSSHAuth bool
DisableVNCAuth bool
}
func (i *Info) SetFlags(
rosenpassEnabled, rosenpassPermissive bool,
serverSSHAllowed *bool,
serverVNCAllowed *bool,
disableClientRoutes, disableServerRoutes,
disableDNS, disableFirewall, blockLANAccess, blockInbound, lazyConnectionEnabled bool,
enableSSHRoot, enableSSHSFTP, enableSSHLocalPortForwarding, enableSSHRemotePortForwarding *bool,
disableSSHAuth *bool,
disableVNCAuth *bool,
) {
i.RosenpassEnabled = rosenpassEnabled
i.RosenpassPermissive = rosenpassPermissive
if serverSSHAllowed != nil {
i.ServerSSHAllowed = *serverSSHAllowed
}
if serverVNCAllowed != nil {
i.ServerVNCAllowed = *serverVNCAllowed
}
i.DisableClientRoutes = disableClientRoutes
i.DisableServerRoutes = disableServerRoutes
@@ -118,6 +124,9 @@ func (i *Info) SetFlags(
if disableSSHAuth != nil {
i.DisableSSHAuth = *disableSSHAuth
}
if disableVNCAuth != nil {
i.DisableVNCAuth = *disableVNCAuth
}
}
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
@@ -145,59 +154,6 @@ func extractDeviceName(ctx context.Context, defaultName string) string {
return v
}
func networkAddresses() ([]NetworkAddress, error) {
interfaces, err := net.Interfaces()
if err != nil {
return nil, err
}
var netAddresses []NetworkAddress
for _, iface := range interfaces {
if iface.Flags&net.FlagUp == 0 {
continue
}
if iface.HardwareAddr.String() == "" {
continue
}
addrs, err := iface.Addrs()
if err != nil {
continue
}
for _, address := range addrs {
ipNet, ok := address.(*net.IPNet)
if !ok {
continue
}
if ipNet.IP.IsLoopback() {
continue
}
netAddr := NetworkAddress{
NetIP: netip.MustParsePrefix(ipNet.String()),
Mac: iface.HardwareAddr.String(),
}
if isDuplicated(netAddresses, netAddr) {
continue
}
netAddresses = append(netAddresses, netAddr)
}
}
return netAddresses, nil
}
func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
for _, duplicated := range addresses {
if duplicated.NetIP == addr.NetIP {
return true
}
}
return false
}
// GetInfoWithChecks retrieves and parses the system information with applied checks.
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) {
log.Debugf("gathering system information with checks: %d", len(checks))

View File

@@ -2,12 +2,16 @@ package system
import (
"context"
"net"
"net/netip"
"runtime"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/version"
)
// UpdateStaticInfoAsync is a no-op on Android as there is no static info to update
// UpdateStaticInfoAsync is a no-op on iOS as there is no static info to update
func UpdateStaticInfoAsync() {
// do nothing
}
@@ -15,11 +19,24 @@ func UpdateStaticInfoAsync() {
// GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info {
// Convert fixed-size byte arrays to Go strings
sysName := extractOsName(ctx, "sysName")
swVersion := extractOsVersion(ctx, "swVersion")
gio := &Info{Kernel: sysName, OSVersion: swVersion, Platform: "unknown", OS: sysName, GoOS: runtime.GOOS, CPUs: runtime.NumCPU(), KernelVersion: swVersion}
addrs, err := networkAddresses()
if err != nil {
log.Warnf("failed to discover network addresses: %s", err)
}
gio := &Info{
Kernel: sysName,
OSVersion: swVersion,
Platform: "unknown",
OS: sysName,
GoOS: runtime.GOOS,
CPUs: runtime.NumCPU(),
KernelVersion: swVersion,
NetworkAddresses: addrs,
}
gio.Hostname = extractDeviceName(ctx, "hostname")
gio.NetbirdVersion = version.NetbirdVersion()
gio.UIVersion = extractUserAgent(ctx)
@@ -27,6 +44,66 @@ func GetInfo(ctx context.Context) *Info {
return gio
}
// networkAddresses returns the list of network addresses on iOS.
// On iOS, hardware (MAC) addresses are not available due to Apple's privacy
// restrictions (iOS returns a fixed 02:00:00:00:00:00 placeholder), so we
// leave Mac empty to match Android's behavior. We also skip the HardwareAddr
// check that other platforms use and filter out link-local addresses as they
// are not useful for posture checks.
func networkAddresses() ([]NetworkAddress, error) {
interfaces, err := net.Interfaces()
if err != nil {
return nil, err
}
var netAddresses []NetworkAddress
for _, iface := range interfaces {
if iface.Flags&net.FlagUp == 0 {
continue
}
addrs, err := iface.Addrs()
if err != nil {
continue
}
for _, address := range addrs {
netAddr, ok := toNetworkAddress(address)
if !ok {
continue
}
if isDuplicated(netAddresses, netAddr) {
continue
}
netAddresses = append(netAddresses, netAddr)
}
}
return netAddresses, nil
}
func toNetworkAddress(address net.Addr) (NetworkAddress, bool) {
ipNet, ok := address.(*net.IPNet)
if !ok {
return NetworkAddress{}, false
}
if ipNet.IP.IsLoopback() || ipNet.IP.IsLinkLocalUnicast() || ipNet.IP.IsMulticast() {
return NetworkAddress{}, false
}
prefix, err := netip.ParsePrefix(ipNet.String())
if err != nil {
return NetworkAddress{}, false
}
return NetworkAddress{NetIP: prefix, Mac: ""}, true
}
func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
for _, duplicated := range addresses {
if duplicated.NetIP == addr.NetIP {
return true
}
}
return false
}
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
func checkFileAndProcess(paths []string) ([]File, error) {
return []File{}, nil

View File

@@ -0,0 +1,66 @@
//go:build !ios
package system
import (
"net"
"net/netip"
)
func networkAddresses() ([]NetworkAddress, error) {
interfaces, err := net.Interfaces()
if err != nil {
return nil, err
}
var netAddresses []NetworkAddress
for _, iface := range interfaces {
if iface.Flags&net.FlagUp == 0 {
continue
}
if iface.HardwareAddr.String() == "" {
continue
}
addrs, err := iface.Addrs()
if err != nil {
continue
}
mac := iface.HardwareAddr.String()
for _, address := range addrs {
netAddr, ok := toNetworkAddress(address, mac)
if !ok {
continue
}
if isDuplicated(netAddresses, netAddr) {
continue
}
netAddresses = append(netAddresses, netAddr)
}
}
return netAddresses, nil
}
func toNetworkAddress(address net.Addr, mac string) (NetworkAddress, bool) {
ipNet, ok := address.(*net.IPNet)
if !ok {
return NetworkAddress{}, false
}
if ipNet.IP.IsLoopback() {
return NetworkAddress{}, false
}
prefix, err := netip.ParsePrefix(ipNet.String())
if err != nil {
return NetworkAddress{}, false
}
return NetworkAddress{NetIP: prefix, Mac: mac}, true
}
func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
for _, duplicated := range addresses {
if duplicated.NetIP == addr.NetIP {
return true
}
}
return false
}

View File

@@ -314,6 +314,7 @@ type serviceClient struct {
lastNotifiedVersion string
settingsEnabled bool
profilesEnabled bool
networksEnabled bool
showNetworks bool
wNetworks fyne.Window
wProfiles fyne.Window
@@ -368,6 +369,7 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
showAdvancedSettings: args.showSettings,
showNetworks: args.showNetworks,
networksEnabled: true,
}
s.eventHandler = newEventHandler(s)
@@ -920,8 +922,10 @@ func (s *serviceClient) updateStatus() error {
s.mStatus.SetIcon(s.icConnectedDot)
s.mUp.Disable()
s.mDown.Enable()
s.mNetworks.Enable()
s.mExitNode.Enable()
if s.networksEnabled {
s.mNetworks.Enable()
s.mExitNode.Enable()
}
s.startExitNodeRefresh()
systrayIconState = true
case status.Status == string(internal.StatusConnecting):
@@ -1093,14 +1097,14 @@ func (s *serviceClient) onTrayReady() {
s.getSrvConfig()
time.Sleep(100 * time.Millisecond) // To prevent race condition caused by systray not being fully initialized and ignoring setIcon
for {
// Check features before status so menus respect disable flags before being enabled
s.checkAndUpdateFeatures()
err := s.updateStatus()
if err != nil {
log.Errorf("error while updating status: %v", err)
}
// Check features periodically to handle daemon restarts
s.checkAndUpdateFeatures()
time.Sleep(2 * time.Second)
}
}()
@@ -1299,6 +1303,16 @@ func (s *serviceClient) checkAndUpdateFeatures() {
s.mProfile.setEnabled(profilesEnabled)
}
}
// Update networks and exit node menus based on current features
s.networksEnabled = features == nil || !features.DisableNetworks
if s.networksEnabled && s.connected {
s.mNetworks.Enable()
s.mExitNode.Enable()
} else {
s.mNetworks.Disable()
s.mExitNode.Disable()
}
}
// getFeatures from the daemon to determine which features are enabled/disabled.

View File

@@ -0,0 +1,474 @@
//go:build windows
package server
import (
crand "crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net"
"os"
"sync"
"time"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
const (
agentPort = "15900"
// agentTokenLen is the length of the random authentication token
// used to verify that connections to the agent come from the service.
agentTokenLen = 32
stillActive = 259
tokenPrimary = 1
securityImpersonation = 2
tokenSessionID = 12
createUnicodeEnvironment = 0x00000400
createNoWindow = 0x08000000
)
var (
kernel32 = windows.NewLazySystemDLL("kernel32.dll")
advapi32 = windows.NewLazySystemDLL("advapi32.dll")
userenv = windows.NewLazySystemDLL("userenv.dll")
procWTSGetActiveConsoleSessionId = kernel32.NewProc("WTSGetActiveConsoleSessionId")
procSetTokenInformation = advapi32.NewProc("SetTokenInformation")
procCreateEnvironmentBlock = userenv.NewProc("CreateEnvironmentBlock")
procDestroyEnvironmentBlock = userenv.NewProc("DestroyEnvironmentBlock")
wtsapi32 = windows.NewLazySystemDLL("wtsapi32.dll")
procWTSEnumerateSessionsW = wtsapi32.NewProc("WTSEnumerateSessionsW")
procWTSFreeMemory = wtsapi32.NewProc("WTSFreeMemory")
)
// GetCurrentSessionID returns the session ID of the current process.
func GetCurrentSessionID() uint32 {
var token windows.Token
if err := windows.OpenProcessToken(windows.CurrentProcess(),
windows.TOKEN_QUERY, &token); err != nil {
return 0
}
defer token.Close()
var id uint32
var ret uint32
_ = windows.GetTokenInformation(token, windows.TokenSessionId,
(*byte)(unsafe.Pointer(&id)), 4, &ret)
return id
}
func getConsoleSessionID() uint32 {
r, _, _ := procWTSGetActiveConsoleSessionId.Call()
return uint32(r)
}
const (
wtsActive = 0
wtsConnected = 1
wtsDisconnected = 4
)
type wtsSessionInfo struct {
SessionID uint32
WinStationName [66]byte // actually *uint16, but we just need the struct size
State uint32
}
// getActiveSessionID returns the session ID of the best session to attach to.
// Prefers an active (logged-in, interactive) session over the console session.
// This avoids kicking out an RDP user when the console is at the login screen.
func getActiveSessionID() uint32 {
var sessionInfo uintptr
var count uint32
r, _, _ := procWTSEnumerateSessionsW.Call(
0, // WTS_CURRENT_SERVER_HANDLE
0, // reserved
1, // version
uintptr(unsafe.Pointer(&sessionInfo)),
uintptr(unsafe.Pointer(&count)),
)
if r == 0 || count == 0 {
return getConsoleSessionID()
}
defer procWTSFreeMemory.Call(sessionInfo)
type wtsSession struct {
SessionID uint32
Station *uint16
State uint32
}
sessions := unsafe.Slice((*wtsSession)(unsafe.Pointer(sessionInfo)), count)
// Find the first active session (not session 0, which is the services session).
var bestID uint32
found := false
for _, s := range sessions {
if s.SessionID == 0 {
continue
}
if s.State == wtsActive {
bestID = s.SessionID
found = true
break
}
}
if !found {
return getConsoleSessionID()
}
return bestID
}
// getSystemTokenForSession duplicates the current SYSTEM token and sets its
// session ID so the spawned process runs in the target session. Using a SYSTEM
// token gives access to both Default and Winlogon desktops plus UIPI bypass.
func getSystemTokenForSession(sessionID uint32) (windows.Token, error) {
var cur windows.Token
if err := windows.OpenProcessToken(windows.CurrentProcess(),
windows.MAXIMUM_ALLOWED, &cur); err != nil {
return 0, fmt.Errorf("OpenProcessToken: %w", err)
}
defer cur.Close()
var dup windows.Token
if err := windows.DuplicateTokenEx(cur, windows.MAXIMUM_ALLOWED, nil,
securityImpersonation, tokenPrimary, &dup); err != nil {
return 0, fmt.Errorf("DuplicateTokenEx: %w", err)
}
sid := sessionID
r, _, err := procSetTokenInformation.Call(
uintptr(dup),
uintptr(tokenSessionID),
uintptr(unsafe.Pointer(&sid)),
unsafe.Sizeof(sid),
)
if r == 0 {
dup.Close()
return 0, fmt.Errorf("SetTokenInformation(SessionId=%d): %w", sessionID, err)
}
return dup, nil
}
const agentTokenEnvVar = "NB_VNC_AGENT_TOKEN"
// injectEnvVar appends a KEY=VALUE entry to a Unicode environment block.
// The block is a sequence of null-terminated UTF-16 strings, terminated by
// an extra null. Returns a new block pointer with the entry added.
func injectEnvVar(envBlock uintptr, key, value string) uintptr {
entry := key + "=" + value
// Walk the existing block to find its total length.
ptr := (*uint16)(unsafe.Pointer(envBlock))
var totalChars int
for {
ch := *(*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + uintptr(totalChars)*2))
if ch == 0 {
// Check for double-null terminator.
next := *(*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + uintptr(totalChars+1)*2))
totalChars++
if next == 0 {
// End of block (don't count the final null yet, we'll rebuild).
break
}
} else {
totalChars++
}
}
entryUTF16, _ := windows.UTF16FromString(entry)
// New block: existing entries + new entry (null-terminated) + final null.
newLen := totalChars + len(entryUTF16) + 1
newBlock := make([]uint16, newLen)
// Copy existing entries (up to but not including the final null).
for i := range totalChars {
newBlock[i] = *(*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + uintptr(i)*2))
}
copy(newBlock[totalChars:], entryUTF16)
newBlock[newLen-1] = 0 // final null terminator
return uintptr(unsafe.Pointer(&newBlock[0]))
}
func spawnAgentInSession(sessionID uint32, port string, authToken string) (windows.Handle, error) {
token, err := getSystemTokenForSession(sessionID)
if err != nil {
return 0, fmt.Errorf("get SYSTEM token for session %d: %w", sessionID, err)
}
defer token.Close()
var envBlock uintptr
r, _, _ := procCreateEnvironmentBlock.Call(
uintptr(unsafe.Pointer(&envBlock)),
uintptr(token),
0,
)
if r != 0 {
defer procDestroyEnvironmentBlock.Call(envBlock)
}
// Inject the auth token into the environment block so it doesn't appear
// in the process command line (visible via tasklist/wmic).
if r != 0 {
envBlock = injectEnvVar(envBlock, agentTokenEnvVar, authToken)
}
exePath, err := os.Executable()
if err != nil {
return 0, fmt.Errorf("get executable path: %w", err)
}
cmdLine := fmt.Sprintf(`"%s" vnc-agent --port %s`, exePath, port)
cmdLineW, err := windows.UTF16PtrFromString(cmdLine)
if err != nil {
return 0, fmt.Errorf("UTF16 cmdline: %w", err)
}
// Create an inheritable pipe for the agent's stderr so we can relog
// its output in the service process.
var sa windows.SecurityAttributes
sa.Length = uint32(unsafe.Sizeof(sa))
sa.InheritHandle = 1
var stderrRead, stderrWrite windows.Handle
if err := windows.CreatePipe(&stderrRead, &stderrWrite, &sa, 0); err != nil {
return 0, fmt.Errorf("create stderr pipe: %w", err)
}
// The read end must NOT be inherited by the child.
windows.SetHandleInformation(stderrRead, windows.HANDLE_FLAG_INHERIT, 0)
desktop, _ := windows.UTF16PtrFromString(`WinSta0\Default`)
si := windows.StartupInfo{
Cb: uint32(unsafe.Sizeof(windows.StartupInfo{})),
Desktop: desktop,
Flags: windows.STARTF_USESHOWWINDOW | windows.STARTF_USESTDHANDLES,
ShowWindow: 0,
StdErr: stderrWrite,
StdOutput: stderrWrite,
}
var pi windows.ProcessInformation
var envPtr *uint16
if envBlock != 0 {
envPtr = (*uint16)(unsafe.Pointer(envBlock))
}
err = windows.CreateProcessAsUser(
token, nil, cmdLineW,
nil, nil, true, // inheritHandles=true for the pipe
createUnicodeEnvironment|createNoWindow,
envPtr, nil, &si, &pi,
)
// Close the write end in the parent so reads will get EOF when the child exits.
windows.CloseHandle(stderrWrite)
if err != nil {
windows.CloseHandle(stderrRead)
return 0, fmt.Errorf("CreateProcessAsUser: %w", err)
}
windows.CloseHandle(pi.Thread)
// Relog agent output in the service with a [vnc-agent] prefix.
go relogAgentOutput(stderrRead)
log.Infof("spawned agent PID=%d in session %d on port %s", pi.ProcessId, sessionID, port)
return pi.Process, nil
}
// sessionManager monitors the active console session and ensures a VNC agent
// process is running in it. When the session changes (e.g., user switch, RDP
// connect/disconnect), it kills the old agent and spawns a new one.
type sessionManager struct {
port string
mu sync.Mutex
agentProc windows.Handle
sessionID uint32
authToken string
done chan struct{}
}
func newSessionManager(port string) *sessionManager {
return &sessionManager{port: port, sessionID: ^uint32(0), done: make(chan struct{})}
}
// generateAuthToken creates a new random hex token for agent authentication.
func generateAuthToken() string {
b := make([]byte, agentTokenLen)
if _, err := crand.Read(b); err != nil {
log.Warnf("generate agent auth token: %v", err)
return ""
}
return hex.EncodeToString(b)
}
// AuthToken returns the current agent authentication token.
func (m *sessionManager) AuthToken() string {
m.mu.Lock()
defer m.mu.Unlock()
return m.authToken
}
// Stop signals the session manager to exit its polling loop.
func (m *sessionManager) Stop() {
select {
case <-m.done:
default:
close(m.done)
}
}
func (m *sessionManager) run() {
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
for {
sid := getActiveSessionID()
m.mu.Lock()
if sid != m.sessionID {
log.Infof("active session changed: %d -> %d", m.sessionID, sid)
m.killAgent()
m.sessionID = sid
}
if m.agentProc != 0 {
var code uint32
_ = windows.GetExitCodeProcess(m.agentProc, &code)
if code != stillActive {
log.Infof("agent exited (code=%d), respawning", code)
windows.CloseHandle(m.agentProc)
m.agentProc = 0
}
}
if m.agentProc == 0 && sid != 0xFFFFFFFF {
m.authToken = generateAuthToken()
h, err := spawnAgentInSession(sid, m.port, m.authToken)
if err != nil {
log.Warnf("spawn agent in session %d: %v", sid, err)
m.authToken = ""
} else {
m.agentProc = h
}
}
m.mu.Unlock()
select {
case <-m.done:
m.mu.Lock()
m.killAgent()
m.mu.Unlock()
return
case <-ticker.C:
}
}
}
func (m *sessionManager) killAgent() {
if m.agentProc != 0 {
_ = windows.TerminateProcess(m.agentProc, 0)
windows.CloseHandle(m.agentProc)
m.agentProc = 0
log.Info("killed old agent")
}
}
// relogAgentOutput reads JSON log lines from the agent's stderr pipe and
// relogs them at the correct level with the service's formatter.
func relogAgentOutput(pipe windows.Handle) {
defer windows.CloseHandle(pipe)
f := os.NewFile(uintptr(pipe), "vnc-agent-stderr")
defer f.Close()
entry := log.WithField("component", "vnc-agent")
dec := json.NewDecoder(f)
for dec.More() {
var m map[string]any
if err := dec.Decode(&m); err != nil {
break
}
msg, _ := m["msg"].(string)
if msg == "" {
continue
}
// Forward extra fields from the agent (skip standard logrus fields).
// Remap "caller" to "source" so it doesn't conflict with logrus internals
// but still shows the original file/line from the agent process.
fields := make(log.Fields)
for k, v := range m {
switch k {
case "msg", "level", "time", "func":
continue
case "caller":
fields["source"] = v
default:
fields[k] = v
}
}
e := entry.WithFields(fields)
switch m["level"] {
case "error":
e.Error(msg)
case "warning":
e.Warn(msg)
case "debug":
e.Debug(msg)
case "trace":
e.Trace(msg)
default:
e.Info(msg)
}
}
}
// proxyToAgent connects to the agent, sends the auth token, then proxies
// the VNC client connection bidirectionally.
func proxyToAgent(client net.Conn, port string, authToken string) {
defer client.Close()
addr := "127.0.0.1:" + port
var agentConn net.Conn
var err error
for range 50 {
agentConn, err = net.DialTimeout("tcp", addr, time.Second)
if err == nil {
break
}
time.Sleep(200 * time.Millisecond)
}
if err != nil {
log.Warnf("proxy cannot reach agent at %s: %v", addr, err)
return
}
defer agentConn.Close()
// Send the auth token so the agent can verify this connection
// comes from the trusted service process.
tokenBytes, _ := hex.DecodeString(authToken)
if _, err := agentConn.Write(tokenBytes); err != nil {
log.Warnf("send auth token to agent: %v", err)
return
}
log.Debugf("proxy connected to agent, starting bidirectional copy")
done := make(chan struct{}, 2)
cp := func(label string, dst, src net.Conn) {
n, err := io.Copy(dst, src)
log.Debugf("proxy %s: %d bytes, err=%v", label, n, err)
done <- struct{}{}
}
go cp("client→agent", agentConn, client)
go cp("agent→client", client, agentConn)
<-done
}

View File

@@ -0,0 +1,486 @@
//go:build darwin && !ios
package server
import (
"errors"
"fmt"
"hash/maphash"
"image"
"os"
"runtime"
"strconv"
"sync"
"time"
"unsafe"
"github.com/ebitengine/purego"
log "github.com/sirupsen/logrus"
)
var darwinCaptureOnce sync.Once
var (
cgMainDisplayID func() uint32
cgDisplayPixelsWide func(uint32) uintptr
cgDisplayPixelsHigh func(uint32) uintptr
cgDisplayCreateImage func(uint32) uintptr
cgImageGetWidth func(uintptr) uintptr
cgImageGetHeight func(uintptr) uintptr
cgImageGetBytesPerRow func(uintptr) uintptr
cgImageGetBitsPerPixel func(uintptr) uintptr
cgImageGetDataProvider func(uintptr) uintptr
cgDataProviderCopyData func(uintptr) uintptr
cgImageRelease func(uintptr)
cfDataGetLength func(uintptr) int64
cfDataGetBytePtr func(uintptr) uintptr
cfRelease func(uintptr)
cgPreflightScreenCaptureAccess func() bool
cgRequestScreenCaptureAccess func() bool
darwinCaptureReady bool
)
func initDarwinCapture() {
darwinCaptureOnce.Do(func() {
cg, err := purego.Dlopen("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics", purego.RTLD_NOW|purego.RTLD_GLOBAL)
if err != nil {
log.Debugf("load CoreGraphics: %v", err)
return
}
cf, err := purego.Dlopen("/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation", purego.RTLD_NOW|purego.RTLD_GLOBAL)
if err != nil {
log.Debugf("load CoreFoundation: %v", err)
return
}
purego.RegisterLibFunc(&cgMainDisplayID, cg, "CGMainDisplayID")
purego.RegisterLibFunc(&cgDisplayPixelsWide, cg, "CGDisplayPixelsWide")
purego.RegisterLibFunc(&cgDisplayPixelsHigh, cg, "CGDisplayPixelsHigh")
purego.RegisterLibFunc(&cgDisplayCreateImage, cg, "CGDisplayCreateImage")
purego.RegisterLibFunc(&cgImageGetWidth, cg, "CGImageGetWidth")
purego.RegisterLibFunc(&cgImageGetHeight, cg, "CGImageGetHeight")
purego.RegisterLibFunc(&cgImageGetBytesPerRow, cg, "CGImageGetBytesPerRow")
purego.RegisterLibFunc(&cgImageGetBitsPerPixel, cg, "CGImageGetBitsPerPixel")
purego.RegisterLibFunc(&cgImageGetDataProvider, cg, "CGImageGetDataProvider")
purego.RegisterLibFunc(&cgDataProviderCopyData, cg, "CGDataProviderCopyData")
purego.RegisterLibFunc(&cgImageRelease, cg, "CGImageRelease")
purego.RegisterLibFunc(&cfDataGetLength, cf, "CFDataGetLength")
purego.RegisterLibFunc(&cfDataGetBytePtr, cf, "CFDataGetBytePtr")
purego.RegisterLibFunc(&cfRelease, cf, "CFRelease")
// Screen capture permission APIs (macOS 11+). Might not exist on older versions.
if sym, err := purego.Dlsym(cg, "CGPreflightScreenCaptureAccess"); err == nil {
purego.RegisterFunc(&cgPreflightScreenCaptureAccess, sym)
}
if sym, err := purego.Dlsym(cg, "CGRequestScreenCaptureAccess"); err == nil {
purego.RegisterFunc(&cgRequestScreenCaptureAccess, sym)
}
darwinCaptureReady = true
})
}
// errFrameUnchanged signals that the raw capture bytes matched the previous
// frame, so the caller can skip the expensive BGRA to RGBA conversion.
var errFrameUnchanged = errors.New("frame unchanged")
// CGCapturer captures the macOS main display using Core Graphics.
type CGCapturer struct {
displayID uint32
w, h int
// downscale is 1 for pixel-perfect, 2 for Retina 2:1 box-filter downscale.
downscale int
hashSeed maphash.Seed
lastHash uint64
hasHash bool
}
// NewCGCapturer creates a screen capturer for the main display.
func NewCGCapturer() (*CGCapturer, error) {
initDarwinCapture()
if !darwinCaptureReady {
return nil, fmt.Errorf("CoreGraphics not available")
}
// Request Screen Recording permission (shows system dialog on macOS 11+).
if cgPreflightScreenCaptureAccess != nil && !cgPreflightScreenCaptureAccess() {
if cgRequestScreenCaptureAccess != nil {
cgRequestScreenCaptureAccess()
}
openPrivacyPane("Privacy_ScreenCapture")
log.Warn("Screen Recording permission not granted. " +
"Opened System Settings > Privacy & Security > Screen Recording; enable netbird and restart.")
}
displayID := cgMainDisplayID()
c := &CGCapturer{displayID: displayID, downscale: 1, hashSeed: maphash.MakeSeed()}
// Probe actual pixel dimensions via a test capture. CGDisplayPixelsWide/High
// returns logical points on Retina, but CGDisplayCreateImage produces native
// pixels (often 2x), so probing the image is the only reliable source.
img, err := c.Capture()
if err != nil {
return nil, fmt.Errorf("probe capture: %w", err)
}
nativeW := img.Rect.Dx()
nativeH := img.Rect.Dy()
c.hasHash = false
if nativeW == 0 || nativeH == 0 {
return nil, errors.New("display dimensions are zero")
}
logicalW := int(cgDisplayPixelsWide(displayID))
logicalH := int(cgDisplayPixelsHigh(displayID))
// Enable 2:1 downscale on Retina unless explicitly disabled. Cuts pixel
// count 4x, shrinking convert, diff, and wire data proportionally.
if !retinaDownscaleDisabled() && nativeW >= 2*logicalW && nativeH >= 2*logicalH && nativeW%2 == 0 && nativeH%2 == 0 {
c.downscale = 2
}
c.w = nativeW / c.downscale
c.h = nativeH / c.downscale
log.Infof("macOS capturer ready: %dx%d (native %dx%d, logical %dx%d, downscale=%d, display=%d)",
c.w, c.h, nativeW, nativeH, logicalW, logicalH, c.downscale, displayID)
return c, nil
}
func retinaDownscaleDisabled() bool {
v := os.Getenv(EnvVNCDisableDownscale)
if v == "" {
return false
}
disabled, err := strconv.ParseBool(v)
if err != nil {
log.Warnf("parse %s: %v", EnvVNCDisableDownscale, err)
return false
}
return disabled
}
// Width returns the screen width.
func (c *CGCapturer) Width() int { return c.w }
// Height returns the screen height.
func (c *CGCapturer) Height() int { return c.h }
// Capture returns the current screen as an RGBA image.
func (c *CGCapturer) Capture() (*image.RGBA, error) {
cgImage := cgDisplayCreateImage(c.displayID)
if cgImage == 0 {
return nil, fmt.Errorf("CGDisplayCreateImage returned nil (screen recording permission?)")
}
defer cgImageRelease(cgImage)
w := int(cgImageGetWidth(cgImage))
h := int(cgImageGetHeight(cgImage))
bytesPerRow := int(cgImageGetBytesPerRow(cgImage))
bpp := int(cgImageGetBitsPerPixel(cgImage))
provider := cgImageGetDataProvider(cgImage)
if provider == 0 {
return nil, fmt.Errorf("CGImageGetDataProvider returned nil")
}
cfData := cgDataProviderCopyData(provider)
if cfData == 0 {
return nil, fmt.Errorf("CGDataProviderCopyData returned nil")
}
defer cfRelease(cfData)
dataLen := int(cfDataGetLength(cfData))
dataPtr := cfDataGetBytePtr(cfData)
if dataPtr == 0 || dataLen == 0 {
return nil, fmt.Errorf("empty image data")
}
src := unsafe.Slice((*byte)(unsafe.Pointer(dataPtr)), dataLen)
hash := maphash.Bytes(c.hashSeed, src)
if c.hasHash && hash == c.lastHash {
return nil, errFrameUnchanged
}
c.lastHash = hash
c.hasHash = true
ds := c.downscale
if ds < 1 {
ds = 1
}
outW := w / ds
outH := h / ds
img := image.NewRGBA(image.Rect(0, 0, outW, outH))
bytesPerPixel := bpp / 8
if bytesPerPixel == 4 && ds == 1 {
convertBGRAToRGBA(img.Pix, img.Stride, src, bytesPerRow, w, h)
} else if bytesPerPixel == 4 && ds == 2 {
convertBGRAToRGBADownscale2(img.Pix, img.Stride, src, bytesPerRow, outW, outH)
} else {
for row := 0; row < outH; row++ {
srcOff := row * ds * bytesPerRow
dstOff := row * img.Stride
for col := 0; col < outW; col++ {
si := srcOff + col*ds*bytesPerPixel
di := dstOff + col*4
img.Pix[di+0] = src[si+2]
img.Pix[di+1] = src[si+1]
img.Pix[di+2] = src[si+0]
img.Pix[di+3] = 0xff
}
}
}
return img, nil
}
// convertBGRAToRGBADownscale2 averages every 2x2 BGRA block into one RGBA
// output pixel, parallelised across GOMAXPROCS cores. outW and outH are the
// destination dimensions (source is 2*outW by 2*outH).
func convertBGRAToRGBADownscale2(dst []byte, dstStride int, src []byte, srcStride, outW, outH int) {
workers := runtime.GOMAXPROCS(0)
if workers > outH {
workers = outH
}
if workers < 1 || outH < 32 {
workers = 1
}
convertRows := func(y0, y1 int) {
for row := y0; row < y1; row++ {
srcRow0 := 2 * row * srcStride
srcRow1 := srcRow0 + srcStride
dstOff := row * dstStride
for col := 0; col < outW; col++ {
s0 := srcRow0 + col*8
s1 := srcRow1 + col*8
b := (uint32(src[s0]) + uint32(src[s0+4]) + uint32(src[s1]) + uint32(src[s1+4])) >> 2
g := (uint32(src[s0+1]) + uint32(src[s0+5]) + uint32(src[s1+1]) + uint32(src[s1+5])) >> 2
r := (uint32(src[s0+2]) + uint32(src[s0+6]) + uint32(src[s1+2]) + uint32(src[s1+6])) >> 2
di := dstOff + col*4
dst[di+0] = byte(r)
dst[di+1] = byte(g)
dst[di+2] = byte(b)
dst[di+3] = 0xff
}
}
}
if workers == 1 {
convertRows(0, outH)
return
}
var wg sync.WaitGroup
chunk := (outH + workers - 1) / workers
for i := 0; i < workers; i++ {
y0 := i * chunk
y1 := y0 + chunk
if y1 > outH {
y1 = outH
}
if y0 >= y1 {
break
}
wg.Add(1)
go func(y0, y1 int) {
defer wg.Done()
convertRows(y0, y1)
}(y0, y1)
}
wg.Wait()
}
// convertBGRAToRGBA swaps R/B channels using uint32 word operations, and
// parallelises across GOMAXPROCS cores for large images.
func convertBGRAToRGBA(dst []byte, dstStride int, src []byte, srcStride, w, h int) {
workers := runtime.GOMAXPROCS(0)
if workers > h {
workers = h
}
if workers < 1 || h < 64 {
workers = 1
}
convertRows := func(y0, y1 int) {
rowBytes := w * 4
for row := y0; row < y1; row++ {
dstRow := dst[row*dstStride : row*dstStride+rowBytes]
srcRow := src[row*srcStride : row*srcStride+rowBytes]
dstU := unsafe.Slice((*uint32)(unsafe.Pointer(&dstRow[0])), w)
srcU := unsafe.Slice((*uint32)(unsafe.Pointer(&srcRow[0])), w)
for i, p := range srcU {
dstU[i] = (p & 0xff00ff00) | ((p & 0x000000ff) << 16) | ((p & 0x00ff0000) >> 16) | 0xff000000
}
}
}
if workers == 1 {
convertRows(0, h)
return
}
var wg sync.WaitGroup
chunk := (h + workers - 1) / workers
for i := 0; i < workers; i++ {
y0 := i * chunk
y1 := y0 + chunk
if y1 > h {
y1 = h
}
if y0 >= y1 {
break
}
wg.Add(1)
go func(y0, y1 int) {
defer wg.Done()
convertRows(y0, y1)
}(y0, y1)
}
wg.Wait()
}
// MacPoller wraps CGCapturer in a continuous capture loop.
type MacPoller struct {
mu sync.Mutex
frame *image.RGBA
w, h int
done chan struct{}
// wake shortens the init-retry backoff when a client is trying to connect,
// so granting Screen Recording mid-session takes effect immediately.
wake chan struct{}
}
// NewMacPoller creates a capturer that continuously grabs the macOS display.
func NewMacPoller() *MacPoller {
p := &MacPoller{
done: make(chan struct{}),
wake: make(chan struct{}, 1),
}
go p.loop()
return p
}
// Wake pokes the init-retry loop so it doesn't wait out the full backoff
// before trying again. Safe to call from any goroutine; extra calls while a
// wake is pending are dropped.
func (p *MacPoller) Wake() {
select {
case p.wake <- struct{}{}:
default:
}
}
// Close stops the capture loop.
func (p *MacPoller) Close() {
select {
case <-p.done:
default:
close(p.done)
}
}
// Width returns the screen width.
func (p *MacPoller) Width() int {
p.mu.Lock()
defer p.mu.Unlock()
return p.w
}
// Height returns the screen height.
func (p *MacPoller) Height() int {
p.mu.Lock()
defer p.mu.Unlock()
return p.h
}
// Capture returns the most recent frame.
func (p *MacPoller) Capture() (*image.RGBA, error) {
p.mu.Lock()
img := p.frame
p.mu.Unlock()
if img != nil {
return img, nil
}
return nil, fmt.Errorf("no frame available yet")
}
func (p *MacPoller) loop() {
var capturer *CGCapturer
var initFails int
for {
select {
case <-p.done:
return
default:
}
if capturer == nil {
var err error
capturer, err = NewCGCapturer()
if err != nil {
initFails++
// Retry forever with backoff: the user may grant Screen
// Recording after the server started, and we need to pick it
// up whenever that happens.
delay := 2 * time.Second
if initFails > 15 {
delay = 30 * time.Second
} else if initFails > 5 {
delay = 10 * time.Second
}
if initFails == 1 || initFails%10 == 0 {
log.Warnf("macOS capturer: %v (attempt %d, retrying every %s)", err, initFails, delay)
} else {
log.Debugf("macOS capturer: %v (attempt %d)", err, initFails)
}
select {
case <-p.done:
return
case <-p.wake:
// Client is trying to connect, retry now.
case <-time.After(delay):
}
continue
}
initFails = 0
p.mu.Lock()
p.w, p.h = capturer.Width(), capturer.Height()
p.mu.Unlock()
}
img, err := capturer.Capture()
if errors.Is(err, errFrameUnchanged) {
select {
case <-p.done:
return
case <-time.After(33 * time.Millisecond):
}
continue
}
if err != nil {
log.Debugf("macOS capture: %v", err)
capturer = nil
select {
case <-p.done:
return
case <-time.After(500 * time.Millisecond):
}
continue
}
p.mu.Lock()
p.frame = img
p.mu.Unlock()
select {
case <-p.done:
return
case <-time.After(33 * time.Millisecond): // ~30 fps
}
}
}
var _ ScreenCapturer = (*MacPoller)(nil)

View File

@@ -0,0 +1,99 @@
//go:build windows
package server
import (
"errors"
"fmt"
"image"
"github.com/kirides/go-d3d/d3d11"
"github.com/kirides/go-d3d/outputduplication"
)
// dxgiCapturer captures the desktop using DXGI Desktop Duplication.
// Provides GPU-accelerated capture with native dirty rect tracking.
// Only works from the interactive user session, not Session 0.
//
// Uses a double-buffer: DXGI writes into img, then we copy to the current
// output buffer and hand it out. Alternating between two output buffers
// avoids allocating a new image.RGBA per frame (~8MB at 1080p, 30fps).
type dxgiCapturer struct {
dup *outputduplication.OutputDuplicator
device *d3d11.ID3D11Device
ctx *d3d11.ID3D11DeviceContext
img *image.RGBA
out [2]*image.RGBA
outIdx int
width int
height int
}
func newDXGICapturer() (*dxgiCapturer, error) {
device, deviceCtx, err := d3d11.NewD3D11Device()
if err != nil {
return nil, fmt.Errorf("create D3D11 device: %w", err)
}
dup, err := outputduplication.NewIDXGIOutputDuplication(device, deviceCtx, 0)
if err != nil {
device.Release()
deviceCtx.Release()
return nil, fmt.Errorf("create output duplication: %w", err)
}
w, h := screenSize()
if w == 0 || h == 0 {
dup.Release()
device.Release()
deviceCtx.Release()
return nil, fmt.Errorf("screen dimensions are zero")
}
rect := image.Rect(0, 0, w, h)
c := &dxgiCapturer{
dup: dup,
device: device,
ctx: deviceCtx,
img: image.NewRGBA(rect),
out: [2]*image.RGBA{image.NewRGBA(rect), image.NewRGBA(rect)},
width: w,
height: h,
}
// Grab the initial frame with a longer timeout to ensure we have
// a valid image before returning.
_ = dup.GetImage(c.img, 2000)
return c, nil
}
func (c *dxgiCapturer) capture() (*image.RGBA, error) {
err := c.dup.GetImage(c.img, 100)
if err != nil && !errors.Is(err, outputduplication.ErrNoImageYet) {
return nil, err
}
// Copy into the next output buffer. The DesktopCapturer hands out the
// returned pointer to VNC sessions that read pixels concurrently, so we
// alternate between two pre-allocated buffers instead of allocating per frame.
out := c.out[c.outIdx]
c.outIdx ^= 1
copy(out.Pix, c.img.Pix)
return out, nil
}
func (c *dxgiCapturer) close() {
if c.dup != nil {
c.dup.Release()
c.dup = nil
}
if c.ctx != nil {
c.ctx.Release()
c.ctx = nil
}
if c.device != nil {
c.device.Release()
c.device = nil
}
}

View File

@@ -0,0 +1,461 @@
//go:build windows
package server
import (
"fmt"
"image"
"runtime"
"sync"
"sync/atomic"
"time"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
var (
gdi32 = windows.NewLazySystemDLL("gdi32.dll")
user32 = windows.NewLazySystemDLL("user32.dll")
procGetDC = user32.NewProc("GetDC")
procReleaseDC = user32.NewProc("ReleaseDC")
procCreateCompatDC = gdi32.NewProc("CreateCompatibleDC")
procCreateDIBSection = gdi32.NewProc("CreateDIBSection")
procSelectObject = gdi32.NewProc("SelectObject")
procDeleteObject = gdi32.NewProc("DeleteObject")
procDeleteDC = gdi32.NewProc("DeleteDC")
procBitBlt = gdi32.NewProc("BitBlt")
procGetSystemMetrics = user32.NewProc("GetSystemMetrics")
// Desktop switching for service/Session 0 capture.
procOpenInputDesktop = user32.NewProc("OpenInputDesktop")
procSetThreadDesktop = user32.NewProc("SetThreadDesktop")
procCloseDesktop = user32.NewProc("CloseDesktop")
procOpenWindowStation = user32.NewProc("OpenWindowStationW")
procSetProcessWindowStation = user32.NewProc("SetProcessWindowStation")
procCloseWindowStation = user32.NewProc("CloseWindowStation")
procGetUserObjectInformationW = user32.NewProc("GetUserObjectInformationW")
)
const uoiName = 2
const (
smCxScreen = 0
smCyScreen = 1
srccopy = 0x00CC0020
dibRgbColors = 0
)
type bitmapInfoHeader struct {
Size uint32
Width int32
Height int32
Planes uint16
BitCount uint16
Compression uint32
SizeImage uint32
XPelsPerMeter int32
YPelsPerMeter int32
ClrUsed uint32
ClrImportant uint32
}
type bitmapInfo struct {
Header bitmapInfoHeader
}
// setupInteractiveWindowStation associates the current process with WinSta0,
// the interactive window station. This is required for a SYSTEM service in
// Session 0 to call OpenInputDesktop for screen capture and input injection.
func setupInteractiveWindowStation() error {
name, err := windows.UTF16PtrFromString("WinSta0")
if err != nil {
return fmt.Errorf("UTF16 WinSta0: %w", err)
}
hWinSta, _, err := procOpenWindowStation.Call(
uintptr(unsafe.Pointer(name)),
0,
uintptr(windows.MAXIMUM_ALLOWED),
)
if hWinSta == 0 {
return fmt.Errorf("OpenWindowStation(WinSta0): %w", err)
}
r, _, err := procSetProcessWindowStation.Call(hWinSta)
if r == 0 {
procCloseWindowStation.Call(hWinSta)
return fmt.Errorf("SetProcessWindowStation: %w", err)
}
log.Info("process window station set to WinSta0 (interactive)")
return nil
}
func screenSize() (int, int) {
w, _, _ := procGetSystemMetrics.Call(uintptr(smCxScreen))
h, _, _ := procGetSystemMetrics.Call(uintptr(smCyScreen))
return int(w), int(h)
}
func getDesktopName(hDesk uintptr) string {
var buf [256]uint16
var needed uint32
procGetUserObjectInformationW.Call(hDesk, uoiName,
uintptr(unsafe.Pointer(&buf[0])), 512,
uintptr(unsafe.Pointer(&needed)))
return windows.UTF16ToString(buf[:])
}
// switchToInputDesktop opens the desktop currently receiving user input
// and sets it as the calling OS thread's desktop. Must be called from a
// goroutine locked to its OS thread via runtime.LockOSThread().
func switchToInputDesktop() (bool, string) {
hDesk, _, _ := procOpenInputDesktop.Call(0, 0, uintptr(windows.MAXIMUM_ALLOWED))
if hDesk == 0 {
return false, ""
}
name := getDesktopName(hDesk)
ret, _, _ := procSetThreadDesktop.Call(hDesk)
procCloseDesktop.Call(hDesk)
return ret != 0, name
}
// gdiCapturer captures the desktop screen using GDI BitBlt.
// GDI objects (DC, DIBSection) are allocated once and reused across frames.
type gdiCapturer struct {
mu sync.Mutex
width int
height int
// Pre-allocated GDI resources, reused across captures.
memDC uintptr
bmp uintptr
bits uintptr
}
func newGDICapturer() (*gdiCapturer, error) {
w, h := screenSize()
if w == 0 || h == 0 {
return nil, fmt.Errorf("screen dimensions are zero")
}
c := &gdiCapturer{width: w, height: h}
if err := c.allocGDI(); err != nil {
return nil, err
}
return c, nil
}
// allocGDI pre-allocates the compatible DC and DIB section for reuse.
func (c *gdiCapturer) allocGDI() error {
screenDC, _, _ := procGetDC.Call(0)
if screenDC == 0 {
return fmt.Errorf("GetDC returned 0")
}
defer procReleaseDC.Call(0, screenDC)
memDC, _, _ := procCreateCompatDC.Call(screenDC)
if memDC == 0 {
return fmt.Errorf("CreateCompatibleDC returned 0")
}
bi := bitmapInfo{
Header: bitmapInfoHeader{
Size: uint32(unsafe.Sizeof(bitmapInfoHeader{})),
Width: int32(c.width),
Height: -int32(c.height), // negative = top-down DIB
Planes: 1,
BitCount: 32,
},
}
var bits uintptr
bmp, _, _ := procCreateDIBSection.Call(
screenDC,
uintptr(unsafe.Pointer(&bi)),
dibRgbColors,
uintptr(unsafe.Pointer(&bits)),
0, 0,
)
if bmp == 0 || bits == 0 {
procDeleteDC.Call(memDC)
return fmt.Errorf("CreateDIBSection returned 0")
}
procSelectObject.Call(memDC, bmp)
c.memDC = memDC
c.bmp = bmp
c.bits = bits
return nil
}
func (c *gdiCapturer) close() { c.freeGDI() }
// freeGDI releases pre-allocated GDI resources.
func (c *gdiCapturer) freeGDI() {
if c.bmp != 0 {
procDeleteObject.Call(c.bmp)
c.bmp = 0
}
if c.memDC != 0 {
procDeleteDC.Call(c.memDC)
c.memDC = 0
}
c.bits = 0
}
func (c *gdiCapturer) capture() (*image.RGBA, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.memDC == 0 {
return nil, fmt.Errorf("GDI resources not allocated")
}
screenDC, _, _ := procGetDC.Call(0)
if screenDC == 0 {
return nil, fmt.Errorf("GetDC returned 0")
}
defer procReleaseDC.Call(0, screenDC)
ret, _, _ := procBitBlt.Call(c.memDC, 0, 0, uintptr(c.width), uintptr(c.height),
screenDC, 0, 0, srccopy)
if ret == 0 {
return nil, fmt.Errorf("BitBlt returned 0")
}
n := c.width * c.height * 4
raw := unsafe.Slice((*byte)(unsafe.Pointer(c.bits)), n)
// GDI gives BGRA, the RFB encoder expects RGBA (img.Pix layout).
// Swap R and B in bulk using uint32 operations (one load + mask + shift
// per pixel instead of three separate byte assignments).
img := image.NewRGBA(image.Rect(0, 0, c.width, c.height))
pix := img.Pix
copy(pix, raw)
swizzleBGRAtoRGBA(pix)
return img, nil
}
// DesktopCapturer captures the interactive desktop, handling desktop transitions
// (login screen, UAC prompts). A dedicated OS-locked goroutine continuously
// captures frames, which are retrieved by the VNC session on demand.
// Capture pauses automatically when no clients are connected.
type DesktopCapturer struct {
mu sync.Mutex
frame *image.RGBA
w, h int
// clients tracks the number of active VNC sessions. When zero, the
// capture loop idles instead of grabbing frames.
clients atomic.Int32
// wake is signaled when a client connects and the loop should resume.
wake chan struct{}
// done is closed when Close is called, terminating the capture loop.
done chan struct{}
}
// NewDesktopCapturer creates a capturer that continuously grabs the active desktop.
func NewDesktopCapturer() *DesktopCapturer {
c := &DesktopCapturer{
wake: make(chan struct{}, 1),
done: make(chan struct{}),
}
go c.loop()
return c
}
// ClientConnect increments the active client count, resuming capture if needed.
func (c *DesktopCapturer) ClientConnect() {
c.clients.Add(1)
select {
case c.wake <- struct{}{}:
default:
}
}
// ClientDisconnect decrements the active client count.
func (c *DesktopCapturer) ClientDisconnect() {
c.clients.Add(-1)
}
// Close stops the capture loop and releases resources.
func (c *DesktopCapturer) Close() {
select {
case <-c.done:
default:
close(c.done)
}
}
// Width returns the current screen width.
func (c *DesktopCapturer) Width() int {
c.mu.Lock()
defer c.mu.Unlock()
return c.w
}
// Height returns the current screen height.
func (c *DesktopCapturer) Height() int {
c.mu.Lock()
defer c.mu.Unlock()
return c.h
}
// Capture returns the most recent desktop frame.
func (c *DesktopCapturer) Capture() (*image.RGBA, error) {
c.mu.Lock()
img := c.frame
c.mu.Unlock()
if img != nil {
return img, nil
}
return nil, fmt.Errorf("no frame available yet")
}
// waitForClient blocks until a client connects or the capturer is closed.
func (c *DesktopCapturer) waitForClient() bool {
if c.clients.Load() > 0 {
return true
}
select {
case <-c.wake:
return true
case <-c.done:
return false
}
}
func (c *DesktopCapturer) loop() {
runtime.LockOSThread()
// When running as a Windows service (Session 0), we need to attach to the
// interactive window station before OpenInputDesktop will succeed.
if err := setupInteractiveWindowStation(); err != nil {
log.Warnf("attach to interactive window station: %v", err)
}
frameTicker := time.NewTicker(33 * time.Millisecond) // ~30 fps
defer frameTicker.Stop()
retryTimer := time.NewTimer(0)
retryTimer.Stop()
defer retryTimer.Stop()
type frameCapturer interface {
capture() (*image.RGBA, error)
close()
}
var cap frameCapturer
var desktopFails int
var lastDesktop string
createCapturer := func() (frameCapturer, error) {
dc, err := newDXGICapturer()
if err == nil {
log.Info("using DXGI Desktop Duplication for capture")
return dc, nil
}
log.Debugf("DXGI unavailable (%v), falling back to GDI", err)
gc, err := newGDICapturer()
if err != nil {
return nil, err
}
log.Info("using GDI BitBlt for capture")
return gc, nil
}
for {
if !c.waitForClient() {
if cap != nil {
cap.close()
}
return
}
// No clients: release the capturer and wait.
if c.clients.Load() <= 0 {
if cap != nil {
cap.close()
cap = nil
}
continue
}
ok, desk := switchToInputDesktop()
if !ok {
desktopFails++
if desktopFails == 1 || desktopFails%100 == 0 {
log.Warnf("switchToInputDesktop failed (count=%d), no interactive desktop session?", desktopFails)
}
retryTimer.Reset(100 * time.Millisecond)
select {
case <-retryTimer.C:
case <-c.done:
return
}
continue
}
if desktopFails > 0 {
log.Infof("switchToInputDesktop recovered after %d failures, desktop=%q", desktopFails, desk)
desktopFails = 0
}
if desk != lastDesktop {
log.Infof("desktop changed: %q -> %q", lastDesktop, desk)
lastDesktop = desk
if cap != nil {
cap.close()
}
cap = nil
}
if cap == nil {
fc, err := createCapturer()
if err != nil {
log.Warnf("create capturer: %v", err)
retryTimer.Reset(500 * time.Millisecond)
select {
case <-retryTimer.C:
case <-c.done:
return
}
continue
}
cap = fc
w, h := screenSize()
c.mu.Lock()
c.w, c.h = w, h
c.mu.Unlock()
log.Infof("screen capturer ready: %dx%d", w, h)
}
img, err := cap.capture()
if err != nil {
log.Debugf("capture: %v", err)
cap.close()
cap = nil
retryTimer.Reset(100 * time.Millisecond)
select {
case <-retryTimer.C:
case <-c.done:
return
}
continue
}
c.mu.Lock()
c.frame = img
c.mu.Unlock()
select {
case <-frameTicker.C:
case <-c.done:
if cap != nil {
cap.close()
}
return
}
}
}

View File

@@ -0,0 +1,385 @@
//go:build (linux && !android) || freebsd
package server
import (
"fmt"
"image"
"os"
"os/exec"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/jezek/xgb"
"github.com/jezek/xgb/xproto"
)
// X11Capturer captures the screen from an X11 display using the MIT-SHM extension.
type X11Capturer struct {
mu sync.Mutex
conn *xgb.Conn
screen *xproto.ScreenInfo
w, h int
shmID int
shmAddr []byte
shmSeg uint32 // shm.Seg
useSHM bool
}
// detectX11Display finds the active X11 display and sets DISPLAY/XAUTHORITY
// environment variables if needed. This is required when running as a system
// service where these vars aren't set.
func detectX11Display() {
if os.Getenv("DISPLAY") != "" {
return
}
// Try /proc first (Linux), then ps fallback (FreeBSD and others).
if detectX11FromProc() {
return
}
if detectX11FromSockets() {
return
}
}
// detectX11FromProc scans /proc/*/cmdline for Xorg (Linux).
func detectX11FromProc() bool {
entries, err := os.ReadDir("/proc")
if err != nil {
return false
}
for _, e := range entries {
if !e.IsDir() {
continue
}
cmdline, err := os.ReadFile("/proc/" + e.Name() + "/cmdline")
if err != nil {
continue
}
if display, auth := parseXorgArgs(splitCmdline(cmdline)); display != "" {
setDisplayEnv(display, auth)
return true
}
}
return false
}
// detectX11FromSockets checks /tmp/.X11-unix/ for X sockets and uses ps
// to find the auth file. Works on FreeBSD and other systems without /proc.
func detectX11FromSockets() bool {
entries, err := os.ReadDir("/tmp/.X11-unix")
if err != nil {
return false
}
// Find the lowest display number.
for _, e := range entries {
name := e.Name()
if len(name) < 2 || name[0] != 'X' {
continue
}
display := ":" + name[1:]
os.Setenv("DISPLAY", display)
log.Infof("auto-detected DISPLAY=%s (from socket)", display)
// Try to find -auth from ps output.
if auth := findXorgAuthFromPS(); auth != "" {
os.Setenv("XAUTHORITY", auth)
log.Infof("auto-detected XAUTHORITY=%s (from ps)", auth)
}
return true
}
return false
}
// findXorgAuthFromPS runs ps to find Xorg and extract its -auth argument.
func findXorgAuthFromPS() string {
out, err := exec.Command("ps", "auxww").Output()
if err != nil {
return ""
}
for _, line := range strings.Split(string(out), "\n") {
if !strings.Contains(line, "Xorg") && !strings.Contains(line, "/X ") {
continue
}
fields := strings.Fields(line)
for i, f := range fields {
if f == "-auth" && i+1 < len(fields) {
return fields[i+1]
}
}
}
return ""
}
func parseXorgArgs(args []string) (display, auth string) {
if len(args) == 0 {
return "", ""
}
base := args[0]
if !(base == "Xorg" || base == "X" || len(base) > 0 && base[len(base)-1] == 'X' ||
strings.Contains(base, "/Xorg") || strings.Contains(base, "/X")) {
return "", ""
}
for i, arg := range args[1:] {
if len(arg) > 0 && arg[0] == ':' {
display = arg
}
if arg == "-auth" && i+2 < len(args) {
auth = args[i+2]
}
}
return display, auth
}
func setDisplayEnv(display, auth string) {
os.Setenv("DISPLAY", display)
log.Infof("auto-detected DISPLAY=%s", display)
if auth != "" {
os.Setenv("XAUTHORITY", auth)
log.Infof("auto-detected XAUTHORITY=%s", auth)
}
}
func splitCmdline(data []byte) []string {
var args []string
for _, b := range splitNull(data) {
if len(b) > 0 {
args = append(args, string(b))
}
}
return args
}
func splitNull(data []byte) [][]byte {
var parts [][]byte
start := 0
for i, b := range data {
if b == 0 {
parts = append(parts, data[start:i])
start = i + 1
}
}
if start < len(data) {
parts = append(parts, data[start:])
}
return parts
}
// NewX11Capturer connects to the X11 display and sets up shared memory capture.
func NewX11Capturer(display string) (*X11Capturer, error) {
detectX11Display()
if display == "" {
display = os.Getenv("DISPLAY")
}
if display == "" {
return nil, fmt.Errorf("DISPLAY not set and no Xorg process found")
}
conn, err := xgb.NewConnDisplay(display)
if err != nil {
return nil, fmt.Errorf("connect to X11 display %s: %w", display, err)
}
setup := xproto.Setup(conn)
if len(setup.Roots) == 0 {
conn.Close()
return nil, fmt.Errorf("no X11 screens")
}
screen := setup.Roots[0]
c := &X11Capturer{
conn: conn,
screen: &screen,
w: int(screen.WidthInPixels),
h: int(screen.HeightInPixels),
}
if err := c.initSHM(); err != nil {
log.Debugf("X11 SHM not available, using slow GetImage: %v", err)
}
log.Infof("X11 capturer ready: %dx%d (display=%s, shm=%v)", c.w, c.h, display, c.useSHM)
return c, nil
}
// initSHM is implemented in capture_x11_shm_linux.go (requires SysV SHM).
// On platforms without SysV SHM (FreeBSD), a stub returns an error and
// the capturer falls back to GetImage.
// Width returns the screen width.
func (c *X11Capturer) Width() int { return c.w }
// Height returns the screen height.
func (c *X11Capturer) Height() int { return c.h }
// Capture returns the current screen as an RGBA image.
func (c *X11Capturer) Capture() (*image.RGBA, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.useSHM {
return c.captureSHM()
}
return c.captureGetImage()
}
// captureSHM is implemented in capture_x11_shm_linux.go.
func (c *X11Capturer) captureGetImage() (*image.RGBA, error) {
cookie := xproto.GetImage(c.conn, xproto.ImageFormatZPixmap,
xproto.Drawable(c.screen.Root),
0, 0, uint16(c.w), uint16(c.h), 0xFFFFFFFF)
reply, err := cookie.Reply()
if err != nil {
return nil, fmt.Errorf("GetImage: %w", err)
}
img := image.NewRGBA(image.Rect(0, 0, c.w, c.h))
data := reply.Data
n := c.w * c.h * 4
if len(data) < n {
return nil, fmt.Errorf("GetImage returned %d bytes, expected %d", len(data), n)
}
for i := 0; i < n; i += 4 {
img.Pix[i+0] = data[i+2] // R
img.Pix[i+1] = data[i+1] // G
img.Pix[i+2] = data[i+0] // B
img.Pix[i+3] = 0xff
}
return img, nil
}
// Close releases X11 resources.
func (c *X11Capturer) Close() {
c.closeSHM()
c.conn.Close()
}
// closeSHM is implemented in capture_x11_shm_linux.go.
// X11Poller wraps X11Capturer in a continuous capture loop, matching the
// DesktopCapturer pattern from Windows.
type X11Poller struct {
mu sync.Mutex
frame *image.RGBA
w, h int
display string
done chan struct{}
}
// NewX11Poller creates a capturer that continuously grabs the X11 display.
func NewX11Poller(display string) *X11Poller {
p := &X11Poller{
display: display,
done: make(chan struct{}),
}
go p.loop()
return p
}
// Close stops the capture loop.
func (p *X11Poller) Close() {
select {
case <-p.done:
default:
close(p.done)
}
}
// Width returns the screen width.
func (p *X11Poller) Width() int {
p.mu.Lock()
defer p.mu.Unlock()
return p.w
}
// Height returns the screen height.
func (p *X11Poller) Height() int {
p.mu.Lock()
defer p.mu.Unlock()
return p.h
}
// Capture returns the most recent frame.
func (p *X11Poller) Capture() (*image.RGBA, error) {
p.mu.Lock()
img := p.frame
p.mu.Unlock()
if img != nil {
return img, nil
}
return nil, fmt.Errorf("no frame available yet")
}
func (p *X11Poller) loop() {
var capturer *X11Capturer
var initFails int
defer func() {
if capturer != nil {
capturer.Close()
}
}()
for {
select {
case <-p.done:
return
default:
}
if capturer == nil {
var err error
capturer, err = NewX11Capturer(p.display)
if err != nil {
initFails++
if initFails <= maxCapturerRetries {
log.Debugf("X11 capturer: %v (attempt %d/%d)", err, initFails, maxCapturerRetries)
select {
case <-p.done:
return
case <-time.After(2 * time.Second):
}
continue
}
log.Warnf("X11 capturer unavailable after %d attempts, stopping poller", maxCapturerRetries)
return
}
initFails = 0
p.mu.Lock()
p.w, p.h = capturer.Width(), capturer.Height()
p.mu.Unlock()
}
img, err := capturer.Capture()
if err != nil {
log.Debugf("X11 capture: %v", err)
capturer.Close()
capturer = nil
select {
case <-p.done:
return
case <-time.After(500 * time.Millisecond):
}
continue
}
p.mu.Lock()
p.frame = img
p.mu.Unlock()
select {
case <-p.done:
return
case <-time.After(33 * time.Millisecond): // ~30 fps
}
}
}

View File

@@ -0,0 +1,78 @@
//go:build linux && !android
package server
import (
"fmt"
"image"
"github.com/jezek/xgb/shm"
"github.com/jezek/xgb/xproto"
"golang.org/x/sys/unix"
)
func (c *X11Capturer) initSHM() error {
if err := shm.Init(c.conn); err != nil {
return fmt.Errorf("init SHM extension: %w", err)
}
size := c.w * c.h * 4
id, err := unix.SysvShmGet(unix.IPC_PRIVATE, size, unix.IPC_CREAT|0600)
if err != nil {
return fmt.Errorf("shmget: %w", err)
}
addr, err := unix.SysvShmAttach(id, 0, 0)
if err != nil {
unix.SysvShmCtl(id, unix.IPC_RMID, nil)
return fmt.Errorf("shmat: %w", err)
}
unix.SysvShmCtl(id, unix.IPC_RMID, nil)
seg, err := shm.NewSegId(c.conn)
if err != nil {
unix.SysvShmDetach(addr)
return fmt.Errorf("new SHM seg: %w", err)
}
if err := shm.AttachChecked(c.conn, seg, uint32(id), false).Check(); err != nil {
unix.SysvShmDetach(addr)
return fmt.Errorf("SHM attach to X: %w", err)
}
c.shmID = id
c.shmAddr = addr
c.shmSeg = uint32(seg)
c.useSHM = true
return nil
}
func (c *X11Capturer) captureSHM() (*image.RGBA, error) {
cookie := shm.GetImage(c.conn, xproto.Drawable(c.screen.Root),
0, 0, uint16(c.w), uint16(c.h), 0xFFFFFFFF,
xproto.ImageFormatZPixmap, shm.Seg(c.shmSeg), 0)
_, err := cookie.Reply()
if err != nil {
return nil, fmt.Errorf("SHM GetImage: %w", err)
}
img := image.NewRGBA(image.Rect(0, 0, c.w, c.h))
n := c.w * c.h * 4
for i := 0; i < n; i += 4 {
img.Pix[i+0] = c.shmAddr[i+2] // R
img.Pix[i+1] = c.shmAddr[i+1] // G
img.Pix[i+2] = c.shmAddr[i+0] // B
img.Pix[i+3] = 0xff
}
return img, nil
}
func (c *X11Capturer) closeSHM() {
if c.useSHM {
shm.Detach(c.conn, shm.Seg(c.shmSeg))
unix.SysvShmDetach(c.shmAddr)
}
}

View File

@@ -0,0 +1,18 @@
//go:build freebsd
package server
import (
"fmt"
"image"
)
func (c *X11Capturer) initSHM() error {
return fmt.Errorf("SysV SHM not available on this platform")
}
func (c *X11Capturer) captureSHM() (*image.RGBA, error) {
return nil, fmt.Errorf("SHM capture not available on this platform")
}
func (c *X11Capturer) closeSHM() {}

151
client/vnc/server/crypto.go Normal file
View File

@@ -0,0 +1,151 @@
package server
import (
"crypto/aes"
"crypto/cipher"
"crypto/ecdh"
"crypto/rand"
"encoding/base64"
"encoding/binary"
"fmt"
"io"
"crypto/sha256"
"golang.org/x/crypto/hkdf"
)
const (
aesKeySize = 32 // AES-256
gcmNonceSize = 12
)
// recCrypto holds per-session encryption state.
type recCrypto struct {
gcm cipher.AEAD
frameCounter uint64
// ephemeralPub is stored in the recording header so the admin can derive the same key.
ephemeralPub []byte
}
// newRecCrypto sets up encryption for a new recording session.
// adminPubKeyB64 is the base64-encoded X25519 public key from management settings.
func newRecCrypto(adminPubKeyB64 string) (*recCrypto, error) {
adminPubBytes, err := base64.StdEncoding.DecodeString(adminPubKeyB64)
if err != nil {
return nil, fmt.Errorf("decode admin public key: %w", err)
}
adminPub, err := ecdh.X25519().NewPublicKey(adminPubBytes)
if err != nil {
return nil, fmt.Errorf("parse admin X25519 public key: %w", err)
}
// Generate ephemeral keypair
ephemeral, err := ecdh.X25519().GenerateKey(rand.Reader)
if err != nil {
return nil, fmt.Errorf("generate ephemeral key: %w", err)
}
// ECDH shared secret
shared, err := ephemeral.ECDH(adminPub)
if err != nil {
return nil, fmt.Errorf("ECDH: %w", err)
}
// Derive AES-256 key via HKDF
aesKey, err := deriveKey(shared, ephemeral.PublicKey().Bytes())
if err != nil {
return nil, fmt.Errorf("derive key: %w", err)
}
block, err := aes.NewCipher(aesKey)
if err != nil {
return nil, fmt.Errorf("create AES cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("create GCM: %w", err)
}
return &recCrypto{
gcm: gcm,
ephemeralPub: ephemeral.PublicKey().Bytes(),
}, nil
}
// encrypt encrypts plaintext using a counter-based nonce. Each call increments the counter.
func (c *recCrypto) encrypt(plaintext []byte) []byte {
nonce := make([]byte, gcmNonceSize)
binary.LittleEndian.PutUint64(nonce, c.frameCounter)
c.frameCounter++
return c.gcm.Seal(nil, nonce, plaintext, nil)
}
// DecryptRecording creates a decryptor from the admin's private key and the ephemeral public key from the header.
func DecryptRecording(adminPrivKeyB64 string, ephemeralPubB64 string) (*recDecryptor, error) {
adminPrivBytes, err := base64.StdEncoding.DecodeString(adminPrivKeyB64)
if err != nil {
return nil, fmt.Errorf("decode admin private key: %w", err)
}
adminPriv, err := ecdh.X25519().NewPrivateKey(adminPrivBytes)
if err != nil {
return nil, fmt.Errorf("parse admin X25519 private key: %w", err)
}
ephPubBytes, err := base64.StdEncoding.DecodeString(ephemeralPubB64)
if err != nil {
return nil, fmt.Errorf("decode ephemeral public key: %w", err)
}
ephPub, err := ecdh.X25519().NewPublicKey(ephPubBytes)
if err != nil {
return nil, fmt.Errorf("parse ephemeral public key: %w", err)
}
shared, err := adminPriv.ECDH(ephPub)
if err != nil {
return nil, fmt.Errorf("ECDH: %w", err)
}
aesKey, err := deriveKey(shared, ephPubBytes)
if err != nil {
return nil, fmt.Errorf("derive key: %w", err)
}
block, err := aes.NewCipher(aesKey)
if err != nil {
return nil, fmt.Errorf("create AES cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("create GCM: %w", err)
}
return &recDecryptor{gcm: gcm}, nil
}
type recDecryptor struct {
gcm cipher.AEAD
frameCounter uint64
}
// Decrypt decrypts a frame. Must be called in the same order as encryption.
func (d *recDecryptor) Decrypt(ciphertext []byte) ([]byte, error) {
nonce := make([]byte, gcmNonceSize)
binary.LittleEndian.PutUint64(nonce, d.frameCounter)
d.frameCounter++
return d.gcm.Open(nil, nonce, ciphertext, nil)
}
func deriveKey(shared, ephemeralPub []byte) ([]byte, error) {
hkdfReader := hkdf.New(sha256.New, shared, ephemeralPub, []byte("netbird-recording"))
key := make([]byte, aesKeySize)
if _, err := io.ReadFull(hkdfReader, key); err != nil {
return nil, err
}
return key, nil
}

View File

@@ -0,0 +1,129 @@
package server
import (
"crypto/ecdh"
"crypto/rand"
"encoding/base64"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCryptoRoundtrip(t *testing.T) {
// Generate admin keypair
adminPriv, err := ecdh.X25519().GenerateKey(rand.Reader)
require.NoError(t, err)
adminPubB64 := base64.StdEncoding.EncodeToString(adminPriv.PublicKey().Bytes())
adminPrivB64 := base64.StdEncoding.EncodeToString(adminPriv.Bytes())
// Create encryptor (recording side)
enc, err := newRecCrypto(adminPubB64)
require.NoError(t, err)
assert.Len(t, enc.ephemeralPub, 32)
ephPubB64 := base64.StdEncoding.EncodeToString(enc.ephemeralPub)
// Encrypt some frames
plaintext1 := []byte("frame data one - PNG bytes would go here")
plaintext2 := []byte("frame data two - different content")
plaintext3 := make([]byte, 1024*100) // 100KB frame
rand.Read(plaintext3)
ct1 := enc.encrypt(plaintext1)
ct2 := enc.encrypt(plaintext2)
ct3 := enc.encrypt(plaintext3)
// Ciphertext should differ from plaintext
assert.NotEqual(t, plaintext1, ct1)
// Ciphertext is larger (GCM tag overhead)
assert.Greater(t, len(ct1), len(plaintext1))
// Create decryptor (playback side)
dec, err := DecryptRecording(adminPrivB64, ephPubB64)
require.NoError(t, err)
// Decrypt in same order
got1, err := dec.Decrypt(ct1)
require.NoError(t, err)
assert.Equal(t, plaintext1, got1)
got2, err := dec.Decrypt(ct2)
require.NoError(t, err)
assert.Equal(t, plaintext2, got2)
got3, err := dec.Decrypt(ct3)
require.NoError(t, err)
assert.Equal(t, plaintext3, got3)
}
func TestCryptoWrongKey(t *testing.T) {
// Admin key
adminPriv, err := ecdh.X25519().GenerateKey(rand.Reader)
require.NoError(t, err)
adminPubB64 := base64.StdEncoding.EncodeToString(adminPriv.PublicKey().Bytes())
// Encrypt with admin's public key
enc, err := newRecCrypto(adminPubB64)
require.NoError(t, err)
ephPubB64 := base64.StdEncoding.EncodeToString(enc.ephemeralPub)
ct := enc.encrypt([]byte("secret frame data"))
// Try to decrypt with a different private key
wrongPriv, err := ecdh.X25519().GenerateKey(rand.Reader)
require.NoError(t, err)
wrongPrivB64 := base64.StdEncoding.EncodeToString(wrongPriv.Bytes())
dec, err := DecryptRecording(wrongPrivB64, ephPubB64)
require.NoError(t, err)
_, err = dec.Decrypt(ct)
assert.Error(t, err, "decryption with wrong key should fail")
}
func TestCryptoInvalidKey(t *testing.T) {
_, err := newRecCrypto("")
assert.Error(t, err, "empty key should fail")
_, err = newRecCrypto("not-base64!!!")
assert.Error(t, err, "invalid base64 should fail")
_, err = newRecCrypto(base64.StdEncoding.EncodeToString([]byte("too-short")))
assert.Error(t, err, "wrong-length key should fail")
_, err = DecryptRecording("", "validbutirrelevant")
assert.Error(t, err, "empty private key should fail")
_, err = DecryptRecording("not-base64!!!", base64.StdEncoding.EncodeToString(make([]byte, 32)))
assert.Error(t, err, "invalid base64 private key should fail")
}
func TestCryptoOutOfOrderFails(t *testing.T) {
adminPriv, err := ecdh.X25519().GenerateKey(rand.Reader)
require.NoError(t, err)
adminPubB64 := base64.StdEncoding.EncodeToString(adminPriv.PublicKey().Bytes())
adminPrivB64 := base64.StdEncoding.EncodeToString(adminPriv.Bytes())
enc, err := newRecCrypto(adminPubB64)
require.NoError(t, err)
ephPubB64 := base64.StdEncoding.EncodeToString(enc.ephemeralPub)
ct0 := enc.encrypt([]byte("frame 0"))
ct1 := enc.encrypt([]byte("frame 1"))
dec, err := DecryptRecording(adminPrivB64, ephPubB64)
require.NoError(t, err)
// Skip frame 0, try to decrypt frame 1 first (wrong nonce)
_, err = dec.Decrypt(ct1)
assert.Error(t, err, "out-of-order decryption should fail due to nonce mismatch")
// But frame 0 with a fresh decryptor should work
dec2, err := DecryptRecording(adminPrivB64, ephPubB64)
require.NoError(t, err)
got, err := dec2.Decrypt(ct0)
require.NoError(t, err)
assert.Equal(t, []byte("frame 0"), got)
}

View File

@@ -0,0 +1,540 @@
//go:build darwin && !ios
package server
import (
"fmt"
"os/exec"
"strings"
"sync"
"github.com/ebitengine/purego"
log "github.com/sirupsen/logrus"
)
// Core Graphics event constants.
const (
kCGEventSourceStateCombinedSessionState int32 = 0
kCGEventLeftMouseDown int32 = 1
kCGEventLeftMouseUp int32 = 2
kCGEventRightMouseDown int32 = 3
kCGEventRightMouseUp int32 = 4
kCGEventMouseMoved int32 = 5
kCGEventLeftMouseDragged int32 = 6
kCGEventRightMouseDragged int32 = 7
kCGEventKeyDown int32 = 10
kCGEventKeyUp int32 = 11
kCGEventOtherMouseDown int32 = 25
kCGEventOtherMouseUp int32 = 26
kCGMouseButtonLeft int32 = 0
kCGMouseButtonRight int32 = 1
kCGMouseButtonCenter int32 = 2
kCGHIDEventTap int32 = 0
// IOKit power management constants.
kIOPMUserActiveLocal int32 = 0
kIOPMAssertionLevelOn uint32 = 255
kCFStringEncodingUTF8 uint32 = 0x08000100
)
var darwinInputOnce sync.Once
var (
cgEventSourceCreate func(int32) uintptr
cgEventCreateKeyboardEvent func(uintptr, uint16, bool) uintptr
// CGEventCreateMouseEvent takes CGPoint as two separate float64 args.
// purego can't handle array/struct types but individual float64s work.
cgEventCreateMouseEvent func(uintptr, int32, float64, float64, int32) uintptr
cgEventPost func(int32, uintptr)
// CGEventCreateScrollWheelEvent is variadic, call via SyscallN.
cgEventCreateScrollWheelEventAddr uintptr
axIsProcessTrusted func() bool
// IOKit power-management bindings used to wake the display and inhibit
// idle sleep while a VNC client is driving input.
iopmAssertionDeclareUserActivity func(uintptr, int32, *uint32) int32
iopmAssertionCreateWithName func(uintptr, uint32, uintptr, *uint32) int32
iopmAssertionRelease func(uint32) int32
cfStringCreateWithCString func(uintptr, string, uint32) uintptr
// Cached CFStrings for assertion name and idle-sleep type.
pmAssertionNameCFStr uintptr
pmPreventIdleDisplayCFStr uintptr
// Assertion IDs. userActivityID is reused across input events so repeated
// calls refresh the same assertion rather than create new ones.
pmMu sync.Mutex
userActivityID uint32
preventSleepID uint32
preventSleepHeld bool
darwinInputReady bool
darwinEventSource uintptr
)
func initDarwinInput() {
darwinInputOnce.Do(func() {
cg, err := purego.Dlopen("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics", purego.RTLD_NOW|purego.RTLD_GLOBAL)
if err != nil {
log.Debugf("load CoreGraphics for input: %v", err)
return
}
purego.RegisterLibFunc(&cgEventSourceCreate, cg, "CGEventSourceCreate")
purego.RegisterLibFunc(&cgEventCreateKeyboardEvent, cg, "CGEventCreateKeyboardEvent")
purego.RegisterLibFunc(&cgEventCreateMouseEvent, cg, "CGEventCreateMouseEvent")
purego.RegisterLibFunc(&cgEventPost, cg, "CGEventPost")
sym, err := purego.Dlsym(cg, "CGEventCreateScrollWheelEvent")
if err == nil {
cgEventCreateScrollWheelEventAddr = sym
}
if ax, err := purego.Dlopen("/System/Library/Frameworks/ApplicationServices.framework/ApplicationServices", purego.RTLD_NOW|purego.RTLD_GLOBAL); err == nil {
if sym, err := purego.Dlsym(ax, "AXIsProcessTrusted"); err == nil {
purego.RegisterFunc(&axIsProcessTrusted, sym)
}
}
initPowerAssertions()
darwinInputReady = true
})
}
func initPowerAssertions() {
iokit, err := purego.Dlopen("/System/Library/Frameworks/IOKit.framework/IOKit", purego.RTLD_NOW|purego.RTLD_GLOBAL)
if err != nil {
log.Debugf("load IOKit: %v", err)
return
}
cf, err := purego.Dlopen("/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation", purego.RTLD_NOW|purego.RTLD_GLOBAL)
if err != nil {
log.Debugf("load CoreFoundation for power assertions: %v", err)
return
}
purego.RegisterLibFunc(&cfStringCreateWithCString, cf, "CFStringCreateWithCString")
purego.RegisterLibFunc(&iopmAssertionDeclareUserActivity, iokit, "IOPMAssertionDeclareUserActivity")
purego.RegisterLibFunc(&iopmAssertionCreateWithName, iokit, "IOPMAssertionCreateWithName")
purego.RegisterLibFunc(&iopmAssertionRelease, iokit, "IOPMAssertionRelease")
pmAssertionNameCFStr = cfStringCreateWithCString(0, "NetBird VNC input", kCFStringEncodingUTF8)
pmPreventIdleDisplayCFStr = cfStringCreateWithCString(0, "PreventUserIdleDisplaySleep", kCFStringEncodingUTF8)
}
// wakeDisplay declares user activity so macOS treats the synthesized input as
// real HID activity, waking the display if it is asleep. Called on every key
// and pointer event; the kernel coalesces repeated calls cheaply.
func wakeDisplay() {
if iopmAssertionDeclareUserActivity == nil || pmAssertionNameCFStr == 0 {
return
}
pmMu.Lock()
id := userActivityID
pmMu.Unlock()
r := iopmAssertionDeclareUserActivity(pmAssertionNameCFStr, kIOPMUserActiveLocal, &id)
if r != 0 {
log.Tracef("IOPMAssertionDeclareUserActivity returned %d", r)
return
}
pmMu.Lock()
userActivityID = id
pmMu.Unlock()
}
// holdPreventIdleSleep creates an assertion that keeps the display from going
// idle-to-sleep while a VNC session is active. Safe to call repeatedly.
func holdPreventIdleSleep() {
if iopmAssertionCreateWithName == nil || pmPreventIdleDisplayCFStr == 0 || pmAssertionNameCFStr == 0 {
return
}
pmMu.Lock()
defer pmMu.Unlock()
if preventSleepHeld {
return
}
var id uint32
r := iopmAssertionCreateWithName(pmPreventIdleDisplayCFStr, kIOPMAssertionLevelOn, pmAssertionNameCFStr, &id)
if r != 0 {
log.Debugf("IOPMAssertionCreateWithName returned %d", r)
return
}
preventSleepID = id
preventSleepHeld = true
}
// releasePreventIdleSleep drops the idle-sleep assertion.
func releasePreventIdleSleep() {
if iopmAssertionRelease == nil {
return
}
pmMu.Lock()
defer pmMu.Unlock()
if !preventSleepHeld {
return
}
if r := iopmAssertionRelease(preventSleepID); r != 0 {
log.Debugf("IOPMAssertionRelease returned %d", r)
}
preventSleepHeld = false
preventSleepID = 0
}
func ensureEventSource() uintptr {
if darwinEventSource != 0 {
return darwinEventSource
}
darwinEventSource = cgEventSourceCreate(kCGEventSourceStateCombinedSessionState)
return darwinEventSource
}
// MacInputInjector injects keyboard and mouse events via Core Graphics.
type MacInputInjector struct {
lastButtons uint8
pbcopyPath string
pbpastePath string
}
// NewMacInputInjector creates a macOS input injector.
func NewMacInputInjector() (*MacInputInjector, error) {
initDarwinInput()
if !darwinInputReady {
return nil, fmt.Errorf("CoreGraphics not available for input injection")
}
checkMacPermissions()
m := &MacInputInjector{}
if path, err := exec.LookPath("pbcopy"); err == nil {
m.pbcopyPath = path
}
if path, err := exec.LookPath("pbpaste"); err == nil {
m.pbpastePath = path
}
if m.pbcopyPath == "" || m.pbpastePath == "" {
log.Debugf("clipboard tools not found (pbcopy=%q, pbpaste=%q)", m.pbcopyPath, m.pbpastePath)
}
holdPreventIdleSleep()
log.Info("macOS input injector ready")
return m, nil
}
// checkMacPermissions warns and opens the Privacy pane if Accessibility is
// missing. Uses AXIsProcessTrusted which returns immediately; the previous
// osascript probe blocked for 120s (AppleEvent timeout) when access was
// denied, which delayed VNC server startup past client deadlines.
func checkMacPermissions() {
if axIsProcessTrusted != nil && !axIsProcessTrusted() {
openPrivacyPane("Privacy_Accessibility")
log.Warn("Accessibility permission not granted. Input injection will not work. " +
"Opened System Settings > Privacy & Security > Accessibility; enable netbird.")
}
log.Info("Screen Recording permission is required for screen capture. " +
"If the screen appears black, grant in System Settings > Privacy & Security > Screen Recording.")
}
// openPrivacyPane opens the given Privacy pane in System Settings so the user
// can toggle the permission without navigating manually.
func openPrivacyPane(pane string) {
url := "x-apple.systempreferences:com.apple.preference.security?" + pane
if err := exec.Command("open", url).Start(); err != nil {
log.Debugf("open privacy pane %s: %v", pane, err)
}
}
// InjectKey simulates a key press or release.
func (m *MacInputInjector) InjectKey(keysym uint32, down bool) {
wakeDisplay()
src := ensureEventSource()
if src == 0 {
return
}
keycode := keysymToMacKeycode(keysym)
if keycode == 0xFFFF {
return
}
event := cgEventCreateKeyboardEvent(src, keycode, down)
if event == 0 {
return
}
cgEventPost(kCGHIDEventTap, event)
cfRelease(event)
}
// InjectPointer simulates mouse movement and button events.
func (m *MacInputInjector) InjectPointer(buttonMask uint8, px, py, serverW, serverH int) {
wakeDisplay()
if serverW == 0 || serverH == 0 {
return
}
src := ensureEventSource()
if src == 0 {
return
}
// Framebuffer is in physical pixels (Retina). CGEventCreateMouseEvent
// expects logical points, so scale down by the display's pixel/point ratio.
x := float64(px)
y := float64(py)
if cgDisplayPixelsWide != nil && cgMainDisplayID != nil {
displayID := cgMainDisplayID()
logicalW := int(cgDisplayPixelsWide(displayID))
logicalH := int(cgDisplayPixelsHigh(displayID))
if logicalW > 0 && logicalH > 0 {
x = float64(px) * float64(logicalW) / float64(serverW)
y = float64(py) * float64(logicalH) / float64(serverH)
}
}
leftDown := buttonMask&0x01 != 0
rightDown := buttonMask&0x04 != 0
middleDown := buttonMask&0x02 != 0
scrollUp := buttonMask&0x08 != 0
scrollDown := buttonMask&0x10 != 0
wasLeft := m.lastButtons&0x01 != 0
wasRight := m.lastButtons&0x04 != 0
wasMiddle := m.lastButtons&0x02 != 0
if leftDown {
m.postMouse(src, kCGEventLeftMouseDragged, x, y, kCGMouseButtonLeft)
} else if rightDown {
m.postMouse(src, kCGEventRightMouseDragged, x, y, kCGMouseButtonRight)
} else {
m.postMouse(src, kCGEventMouseMoved, x, y, kCGMouseButtonLeft)
}
if leftDown && !wasLeft {
m.postMouse(src, kCGEventLeftMouseDown, x, y, kCGMouseButtonLeft)
} else if !leftDown && wasLeft {
m.postMouse(src, kCGEventLeftMouseUp, x, y, kCGMouseButtonLeft)
}
if rightDown && !wasRight {
m.postMouse(src, kCGEventRightMouseDown, x, y, kCGMouseButtonRight)
} else if !rightDown && wasRight {
m.postMouse(src, kCGEventRightMouseUp, x, y, kCGMouseButtonRight)
}
if middleDown && !wasMiddle {
m.postMouse(src, kCGEventOtherMouseDown, x, y, kCGMouseButtonCenter)
} else if !middleDown && wasMiddle {
m.postMouse(src, kCGEventOtherMouseUp, x, y, kCGMouseButtonCenter)
}
if scrollUp {
m.postScroll(src, 3)
}
if scrollDown {
m.postScroll(src, -3)
}
m.lastButtons = buttonMask
}
func (m *MacInputInjector) postMouse(src uintptr, eventType int32, x, y float64, button int32) {
if cgEventCreateMouseEvent == nil {
return
}
event := cgEventCreateMouseEvent(src, eventType, x, y, button)
if event == 0 {
return
}
cgEventPost(kCGHIDEventTap, event)
cfRelease(event)
}
func (m *MacInputInjector) postScroll(src uintptr, deltaY int32) {
if cgEventCreateScrollWheelEventAddr == 0 {
return
}
// CGEventCreateScrollWheelEvent(source, units, wheelCount, wheel1delta)
// units=0 (pixel), wheelCount=1, wheel1delta=deltaY
// Variadic C function: pass args as uintptr via SyscallN.
r1, _, _ := purego.SyscallN(cgEventCreateScrollWheelEventAddr,
src, 0, 1, uintptr(uint32(deltaY)))
if r1 == 0 {
return
}
cgEventPost(kCGHIDEventTap, r1)
cfRelease(r1)
}
// SetClipboard sets the macOS clipboard using pbcopy.
func (m *MacInputInjector) SetClipboard(text string) {
if m.pbcopyPath == "" {
return
}
cmd := exec.Command(m.pbcopyPath)
cmd.Stdin = strings.NewReader(text)
if err := cmd.Run(); err != nil {
log.Tracef("set clipboard via pbcopy: %v", err)
}
}
// GetClipboard reads the macOS clipboard using pbpaste.
func (m *MacInputInjector) GetClipboard() string {
if m.pbpastePath == "" {
return ""
}
out, err := exec.Command(m.pbpastePath).Output()
if err != nil {
log.Tracef("get clipboard via pbpaste: %v", err)
return ""
}
return string(out)
}
// Close releases the idle-sleep assertion held for the injector's lifetime.
func (m *MacInputInjector) Close() {
releasePreventIdleSleep()
}
func keysymToMacKeycode(keysym uint32) uint16 {
if keysym >= 0x61 && keysym <= 0x7a {
return asciiToMacKey[keysym-0x61]
}
if keysym >= 0x41 && keysym <= 0x5a {
return asciiToMacKey[keysym-0x41]
}
if keysym >= 0x30 && keysym <= 0x39 {
return digitToMacKey[keysym-0x30]
}
if code, ok := specialKeyMap[keysym]; ok {
return code
}
return 0xFFFF
}
var asciiToMacKey = [26]uint16{
0x00, 0x0B, 0x08, 0x02, 0x0E, 0x03, 0x05, 0x04,
0x22, 0x26, 0x28, 0x25, 0x2E, 0x2D, 0x1F, 0x23,
0x0C, 0x0F, 0x01, 0x11, 0x20, 0x09, 0x0D, 0x07,
0x10, 0x06,
}
var digitToMacKey = [10]uint16{
0x1D, 0x12, 0x13, 0x14, 0x15, 0x17, 0x16, 0x1A, 0x1C, 0x19,
}
var specialKeyMap = map[uint32]uint16{
// Whitespace and editing
0x0020: 0x31, // space
0xff08: 0x33, // BackSpace
0xff09: 0x30, // Tab
0xff0d: 0x24, // Return
0xff1b: 0x35, // Escape
0xffff: 0x75, // Delete (forward)
// Navigation
0xff50: 0x73, // Home
0xff51: 0x7B, // Left
0xff52: 0x7E, // Up
0xff53: 0x7C, // Right
0xff54: 0x7D, // Down
0xff55: 0x74, // Page_Up
0xff56: 0x79, // Page_Down
0xff57: 0x77, // End
0xff63: 0x72, // Insert (Help on Mac)
// Modifiers
0xffe1: 0x38, // Shift_L
0xffe2: 0x3C, // Shift_R
0xffe3: 0x3B, // Control_L
0xffe4: 0x3E, // Control_R
0xffe5: 0x39, // Caps_Lock
0xffe9: 0x3A, // Alt_L (Option)
0xffea: 0x3D, // Alt_R (Option)
0xffe7: 0x37, // Meta_L (Command)
0xffe8: 0x36, // Meta_R (Command)
0xffeb: 0x37, // Super_L (Command) - noVNC sends this
0xffec: 0x36, // Super_R (Command)
// Mode_switch / ISO_Level3_Shift (sent by noVNC for macOS Option remap)
0xff7e: 0x3A, // Mode_switch -> Option
0xfe03: 0x3D, // ISO_Level3_Shift -> Right Option
// Function keys
0xffbe: 0x7A, // F1
0xffbf: 0x78, // F2
0xffc0: 0x63, // F3
0xffc1: 0x76, // F4
0xffc2: 0x60, // F5
0xffc3: 0x61, // F6
0xffc4: 0x62, // F7
0xffc5: 0x64, // F8
0xffc6: 0x65, // F9
0xffc7: 0x6D, // F10
0xffc8: 0x67, // F11
0xffc9: 0x6F, // F12
0xffca: 0x69, // F13
0xffcb: 0x6B, // F14
0xffcc: 0x71, // F15
0xffcd: 0x6A, // F16
0xffce: 0x40, // F17
0xffcf: 0x4F, // F18
0xffd0: 0x50, // F19
0xffd1: 0x5A, // F20
// Punctuation (US keyboard layout, keysym = ASCII code)
0x002d: 0x1B, // minus -
0x003d: 0x18, // equal =
0x005b: 0x21, // bracketleft [
0x005d: 0x1E, // bracketright ]
0x005c: 0x2A, // backslash
0x003b: 0x29, // semicolon ;
0x0027: 0x27, // apostrophe '
0x0060: 0x32, // grave `
0x002c: 0x2B, // comma ,
0x002e: 0x2F, // period .
0x002f: 0x2C, // slash /
// Shifted punctuation (noVNC sends these as separate keysyms)
0x005f: 0x1B, // underscore _ (shift+minus)
0x002b: 0x18, // plus + (shift+equal)
0x007b: 0x21, // braceleft { (shift+[)
0x007d: 0x1E, // braceright } (shift+])
0x007c: 0x2A, // bar | (shift+\)
0x003a: 0x29, // colon : (shift+;)
0x0022: 0x27, // quotedbl " (shift+')
0x007e: 0x32, // tilde ~ (shift+`)
0x003c: 0x2B, // less < (shift+,)
0x003e: 0x2F, // greater > (shift+.)
0x003f: 0x2C, // question ? (shift+/)
0x0021: 0x12, // exclam ! (shift+1)
0x0040: 0x13, // at @ (shift+2)
0x0023: 0x14, // numbersign # (shift+3)
0x0024: 0x15, // dollar $ (shift+4)
0x0025: 0x17, // percent % (shift+5)
0x005e: 0x16, // asciicircum ^ (shift+6)
0x0026: 0x1A, // ampersand & (shift+7)
0x002a: 0x1C, // asterisk * (shift+8)
0x0028: 0x19, // parenleft ( (shift+9)
0x0029: 0x1D, // parenright ) (shift+0)
// Numpad
0xffb0: 0x52, // KP_0
0xffb1: 0x53, // KP_1
0xffb2: 0x54, // KP_2
0xffb3: 0x55, // KP_3
0xffb4: 0x56, // KP_4
0xffb5: 0x57, // KP_5
0xffb6: 0x58, // KP_6
0xffb7: 0x59, // KP_7
0xffb8: 0x5B, // KP_8
0xffb9: 0x5C, // KP_9
0xffae: 0x41, // KP_Decimal
0xffaa: 0x43, // KP_Multiply
0xffab: 0x45, // KP_Add
0xffad: 0x4E, // KP_Subtract
0xffaf: 0x4B, // KP_Divide
0xff8d: 0x4C, // KP_Enter
0xffbd: 0x51, // KP_Equal
}
var _ InputInjector = (*MacInputInjector)(nil)

View File

@@ -0,0 +1,398 @@
//go:build windows
package server
import (
"runtime"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
var (
procOpenEventW = kernel32.NewProc("OpenEventW")
procSendInput = user32.NewProc("SendInput")
procVkKeyScanA = user32.NewProc("VkKeyScanA")
)
const eventModifyState = 0x0002
const (
inputMouse = 0
inputKeyboard = 1
mouseeventfMove = 0x0001
mouseeventfLeftDown = 0x0002
mouseeventfLeftUp = 0x0004
mouseeventfRightDown = 0x0008
mouseeventfRightUp = 0x0010
mouseeventfMiddleDown = 0x0020
mouseeventfMiddleUp = 0x0040
mouseeventfWheel = 0x0800
mouseeventfAbsolute = 0x8000
wheelDelta = 120
keyeventfKeyUp = 0x0002
keyeventfScanCode = 0x0008
)
type mouseInput struct {
Dx int32
Dy int32
MouseData uint32
DwFlags uint32
Time uint32
DwExtraInfo uintptr
}
type keybdInput struct {
WVk uint16
WScan uint16
DwFlags uint32
Time uint32
DwExtraInfo uintptr
_ [8]byte
}
type inputUnion [32]byte
type winInput struct {
Type uint32
_ [4]byte
Data inputUnion
}
func sendMouseInput(flags uint32, dx, dy int32, mouseData uint32) {
mi := mouseInput{
Dx: dx,
Dy: dy,
MouseData: mouseData,
DwFlags: flags,
}
inp := winInput{Type: inputMouse}
copy(inp.Data[:], (*[unsafe.Sizeof(mi)]byte)(unsafe.Pointer(&mi))[:])
r, _, err := procSendInput.Call(1, uintptr(unsafe.Pointer(&inp)), unsafe.Sizeof(inp))
if r == 0 {
log.Tracef("SendInput(mouse flags=0x%x): %v", flags, err)
}
}
func sendKeyInput(vk uint16, scanCode uint16, flags uint32) {
ki := keybdInput{
WVk: vk,
WScan: scanCode,
DwFlags: flags,
}
inp := winInput{Type: inputKeyboard}
copy(inp.Data[:], (*[unsafe.Sizeof(ki)]byte)(unsafe.Pointer(&ki))[:])
r, _, err := procSendInput.Call(1, uintptr(unsafe.Pointer(&inp)), unsafe.Sizeof(inp))
if r == 0 {
log.Tracef("SendInput(key vk=0x%x): %v", vk, err)
}
}
const sasEventName = `Global\NetBirdVNC_SAS`
type inputCmd struct {
isKey bool
keysym uint32
down bool
buttonMask uint8
x, y int
serverW int
serverH int
}
// WindowsInputInjector delivers input events from a dedicated OS thread that
// calls switchToInputDesktop before each injection. SendInput targets the
// calling thread's desktop, so the injection thread must be on the same
// desktop the user sees.
type WindowsInputInjector struct {
ch chan inputCmd
prevButtonMask uint8
ctrlDown bool
altDown bool
}
// NewWindowsInputInjector creates a desktop-aware input injector.
func NewWindowsInputInjector() *WindowsInputInjector {
w := &WindowsInputInjector{ch: make(chan inputCmd, 64)}
go w.loop()
return w
}
func (w *WindowsInputInjector) loop() {
runtime.LockOSThread()
for cmd := range w.ch {
// Switch to the current input desktop so SendInput reaches the right target.
switchToInputDesktop()
if cmd.isKey {
w.doInjectKey(cmd.keysym, cmd.down)
} else {
w.doInjectPointer(cmd.buttonMask, cmd.x, cmd.y, cmd.serverW, cmd.serverH)
}
}
}
// InjectKey queues a key event for injection on the input desktop thread.
func (w *WindowsInputInjector) InjectKey(keysym uint32, down bool) {
w.ch <- inputCmd{isKey: true, keysym: keysym, down: down}
}
// InjectPointer queues a pointer event for injection on the input desktop thread.
func (w *WindowsInputInjector) InjectPointer(buttonMask uint8, x, y, serverW, serverH int) {
w.ch <- inputCmd{buttonMask: buttonMask, x: x, y: y, serverW: serverW, serverH: serverH}
}
func (w *WindowsInputInjector) doInjectKey(keysym uint32, down bool) {
switch keysym {
case 0xffe3, 0xffe4:
w.ctrlDown = down
case 0xffe9, 0xffea:
w.altDown = down
}
if (keysym == 0xff9f || keysym == 0xffff) && w.ctrlDown && w.altDown && down {
signalSAS()
return
}
vk, _, extended := keysym2VK(keysym)
if vk == 0 {
return
}
var flags uint32
if !down {
flags |= keyeventfKeyUp
}
if extended {
flags |= keyeventfScanCode
}
sendKeyInput(vk, 0, flags)
}
// signalSAS signals the SAS named event. A listener in Session 0
// (startSASListener) calls SendSAS to trigger the Secure Attention Sequence.
func signalSAS() {
namePtr, err := windows.UTF16PtrFromString(sasEventName)
if err != nil {
log.Warnf("SAS UTF16: %v", err)
return
}
h, _, lerr := procOpenEventW.Call(
uintptr(eventModifyState),
0,
uintptr(unsafe.Pointer(namePtr)),
)
if h == 0 {
log.Warnf("OpenEvent(%s): %v", sasEventName, lerr)
return
}
ev := windows.Handle(h)
defer windows.CloseHandle(ev)
if err := windows.SetEvent(ev); err != nil {
log.Warnf("SetEvent SAS: %v", err)
} else {
log.Info("SAS event signaled")
}
}
func (w *WindowsInputInjector) doInjectPointer(buttonMask uint8, x, y, serverW, serverH int) {
if serverW == 0 || serverH == 0 {
return
}
absX := int32(x * 65535 / serverW)
absY := int32(y * 65535 / serverH)
sendMouseInput(mouseeventfMove|mouseeventfAbsolute, absX, absY, 0)
changed := buttonMask ^ w.prevButtonMask
w.prevButtonMask = buttonMask
type btnMap struct {
bit uint8
down uint32
up uint32
}
buttons := [...]btnMap{
{0x01, mouseeventfLeftDown, mouseeventfLeftUp},
{0x02, mouseeventfMiddleDown, mouseeventfMiddleUp},
{0x04, mouseeventfRightDown, mouseeventfRightUp},
}
for _, b := range buttons {
if changed&b.bit == 0 {
continue
}
var flags uint32
if buttonMask&b.bit != 0 {
flags = b.down
} else {
flags = b.up
}
sendMouseInput(flags|mouseeventfAbsolute, absX, absY, 0)
}
negWheelDelta := ^uint32(wheelDelta - 1)
if changed&0x08 != 0 && buttonMask&0x08 != 0 {
sendMouseInput(mouseeventfWheel|mouseeventfAbsolute, absX, absY, wheelDelta)
}
if changed&0x10 != 0 && buttonMask&0x10 != 0 {
sendMouseInput(mouseeventfWheel|mouseeventfAbsolute, absX, absY, negWheelDelta)
}
}
// keysym2VK converts an X11 keysym to a Windows virtual key code.
func keysym2VK(keysym uint32) (vk uint16, scan uint16, extended bool) {
if keysym >= 0x20 && keysym <= 0x7e {
r, _, _ := procVkKeyScanA.Call(uintptr(keysym))
vk = uint16(r & 0xff)
return
}
if keysym >= 0xffbe && keysym <= 0xffc9 {
vk = uint16(0x70 + keysym - 0xffbe)
return
}
switch keysym {
case 0xff08:
vk = 0x08 // Backspace
case 0xff09:
vk = 0x09 // Tab
case 0xff0d:
vk = 0x0d // Return
case 0xff1b:
vk = 0x1b // Escape
case 0xff63:
vk, extended = 0x2d, true // Insert
case 0xff9f, 0xffff:
vk, extended = 0x2e, true // Delete
case 0xff50:
vk, extended = 0x24, true // Home
case 0xff57:
vk, extended = 0x23, true // End
case 0xff55:
vk, extended = 0x21, true // PageUp
case 0xff56:
vk, extended = 0x22, true // PageDown
case 0xff51:
vk, extended = 0x25, true // Left
case 0xff52:
vk, extended = 0x26, true // Up
case 0xff53:
vk, extended = 0x27, true // Right
case 0xff54:
vk, extended = 0x28, true // Down
case 0xffe1, 0xffe2:
vk = 0x10 // Shift
case 0xffe3, 0xffe4:
vk = 0x11 // Control
case 0xffe9, 0xffea:
vk = 0x12 // Alt
case 0xffe5:
vk = 0x14 // CapsLock
case 0xffe7, 0xffeb:
vk, extended = 0x5B, true // Meta_L / Super_L -> Left Windows
case 0xffe8, 0xffec:
vk, extended = 0x5C, true // Meta_R / Super_R -> Right Windows
case 0xff61:
vk = 0x2c // PrintScreen
case 0xff13:
vk = 0x13 // Pause
case 0xff14:
vk = 0x91 // ScrollLock
}
return
}
var (
procOpenClipboard = user32.NewProc("OpenClipboard")
procCloseClipboard = user32.NewProc("CloseClipboard")
procEmptyClipboard = user32.NewProc("EmptyClipboard")
procSetClipboardData = user32.NewProc("SetClipboardData")
procGetClipboardData = user32.NewProc("GetClipboardData")
procIsClipboardFormatAvailable = user32.NewProc("IsClipboardFormatAvailable")
procGlobalAlloc = kernel32.NewProc("GlobalAlloc")
procGlobalLock = kernel32.NewProc("GlobalLock")
procGlobalUnlock = kernel32.NewProc("GlobalUnlock")
)
const (
cfUnicodeText = 13
gmemMoveable = 0x0002
)
// SetClipboard sets the Windows clipboard to the given UTF-8 text.
func (w *WindowsInputInjector) SetClipboard(text string) {
utf16, err := windows.UTF16FromString(text)
if err != nil {
log.Tracef("clipboard UTF16 encode: %v", err)
return
}
size := uintptr(len(utf16) * 2)
hMem, _, _ := procGlobalAlloc.Call(gmemMoveable, size)
if hMem == 0 {
log.Tracef("GlobalAlloc for clipboard: allocation returned nil")
return
}
ptr, _, _ := procGlobalLock.Call(hMem)
if ptr == 0 {
log.Tracef("GlobalLock for clipboard: lock returned nil")
return
}
copy(unsafe.Slice((*uint16)(unsafe.Pointer(ptr)), len(utf16)), utf16)
procGlobalUnlock.Call(hMem)
r, _, lerr := procOpenClipboard.Call(0)
if r == 0 {
log.Tracef("OpenClipboard: %v", lerr)
return
}
defer procCloseClipboard.Call()
procEmptyClipboard.Call()
r, _, lerr = procSetClipboardData.Call(cfUnicodeText, hMem)
if r == 0 {
log.Tracef("SetClipboardData: %v", lerr)
}
}
// GetClipboard reads the Windows clipboard as UTF-8 text.
func (w *WindowsInputInjector) GetClipboard() string {
r, _, _ := procIsClipboardFormatAvailable.Call(cfUnicodeText)
if r == 0 {
return ""
}
r, _, lerr := procOpenClipboard.Call(0)
if r == 0 {
log.Tracef("OpenClipboard for read: %v", lerr)
return ""
}
defer procCloseClipboard.Call()
hData, _, _ := procGetClipboardData.Call(cfUnicodeText)
if hData == 0 {
return ""
}
ptr, _, _ := procGlobalLock.Call(hData)
if ptr == 0 {
return ""
}
defer procGlobalUnlock.Call(hData)
return windows.UTF16PtrToString((*uint16)(unsafe.Pointer(ptr)))
}
var _ InputInjector = (*WindowsInputInjector)(nil)
var _ ScreenCapturer = (*DesktopCapturer)(nil)

View File

@@ -0,0 +1,242 @@
//go:build (linux && !android) || freebsd
package server
import (
"fmt"
"os"
"os/exec"
"strings"
log "github.com/sirupsen/logrus"
"github.com/jezek/xgb"
"github.com/jezek/xgb/xproto"
"github.com/jezek/xgb/xtest"
)
// X11InputInjector injects keyboard and mouse events via the XTest extension.
type X11InputInjector struct {
conn *xgb.Conn
root xproto.Window
screen *xproto.ScreenInfo
display string
keysymMap map[uint32]byte
lastButtons uint8
clipboardTool string
clipboardToolName string
}
// NewX11InputInjector connects to the X11 display and initializes XTest.
func NewX11InputInjector(display string) (*X11InputInjector, error) {
detectX11Display()
if display == "" {
display = os.Getenv("DISPLAY")
}
if display == "" {
return nil, fmt.Errorf("DISPLAY not set and no Xorg process found")
}
conn, err := xgb.NewConnDisplay(display)
if err != nil {
return nil, fmt.Errorf("connect to X11 display %s: %w", display, err)
}
if err := xtest.Init(conn); err != nil {
conn.Close()
return nil, fmt.Errorf("init XTest extension: %w", err)
}
setup := xproto.Setup(conn)
if len(setup.Roots) == 0 {
conn.Close()
return nil, fmt.Errorf("no X11 screens")
}
screen := setup.Roots[0]
inj := &X11InputInjector{
conn: conn,
root: screen.Root,
screen: &screen,
display: display,
}
inj.cacheKeyboardMapping()
inj.resolveClipboardTool()
log.Infof("X11 input injector ready (display=%s)", display)
return inj, nil
}
// InjectKey simulates a key press or release. keysym is an X11 KeySym.
func (x *X11InputInjector) InjectKey(keysym uint32, down bool) {
keycode := x.keysymToKeycode(keysym)
if keycode == 0 {
return
}
var eventType byte
if down {
eventType = xproto.KeyPress
} else {
eventType = xproto.KeyRelease
}
xtest.FakeInput(x.conn, eventType, keycode, 0, x.root, 0, 0, 0)
}
// InjectPointer simulates mouse movement and button events.
func (x *X11InputInjector) InjectPointer(buttonMask uint8, px, py, serverW, serverH int) {
if serverW == 0 || serverH == 0 {
return
}
// Scale to actual screen coordinates.
screenW := int(x.screen.WidthInPixels)
screenH := int(x.screen.HeightInPixels)
absX := px * screenW / serverW
absY := py * screenH / serverH
// Move pointer.
xtest.FakeInput(x.conn, xproto.MotionNotify, 0, 0, x.root, int16(absX), int16(absY), 0)
// Handle button events. RFB button mask: bit0=left, bit1=middle, bit2=right,
// bit3=scrollUp, bit4=scrollDown. X11 buttons: 1=left, 2=middle, 3=right,
// 4=scrollUp, 5=scrollDown.
type btnMap struct {
rfbBit uint8
x11Btn byte
}
buttons := [...]btnMap{
{0x01, 1}, // left
{0x02, 2}, // middle
{0x04, 3}, // right
{0x08, 4}, // scroll up
{0x10, 5}, // scroll down
}
for _, b := range buttons {
pressed := buttonMask&b.rfbBit != 0
wasPressed := x.lastButtons&b.rfbBit != 0
if b.x11Btn >= 4 {
// Scroll: send press+release on each scroll event.
if pressed {
xtest.FakeInput(x.conn, xproto.ButtonPress, b.x11Btn, 0, x.root, 0, 0, 0)
xtest.FakeInput(x.conn, xproto.ButtonRelease, b.x11Btn, 0, x.root, 0, 0, 0)
}
} else {
if pressed && !wasPressed {
xtest.FakeInput(x.conn, xproto.ButtonPress, b.x11Btn, 0, x.root, 0, 0, 0)
} else if !pressed && wasPressed {
xtest.FakeInput(x.conn, xproto.ButtonRelease, b.x11Btn, 0, x.root, 0, 0, 0)
}
}
}
x.lastButtons = buttonMask
}
// cacheKeyboardMapping fetches the X11 keyboard mapping once and stores it
// as a keysym-to-keycode map, avoiding a round-trip per keystroke.
func (x *X11InputInjector) cacheKeyboardMapping() {
setup := xproto.Setup(x.conn)
minKeycode := setup.MinKeycode
maxKeycode := setup.MaxKeycode
reply, err := xproto.GetKeyboardMapping(x.conn, minKeycode,
byte(maxKeycode-minKeycode+1)).Reply()
if err != nil {
log.Debugf("cache keyboard mapping: %v", err)
x.keysymMap = make(map[uint32]byte)
return
}
m := make(map[uint32]byte, int(maxKeycode-minKeycode+1)*int(reply.KeysymsPerKeycode))
keysymsPerKeycode := int(reply.KeysymsPerKeycode)
for i := int(minKeycode); i <= int(maxKeycode); i++ {
offset := (i - int(minKeycode)) * keysymsPerKeycode
for j := 0; j < keysymsPerKeycode; j++ {
ks := uint32(reply.Keysyms[offset+j])
if ks != 0 {
if _, exists := m[ks]; !exists {
m[ks] = byte(i)
}
}
}
}
x.keysymMap = m
}
// keysymToKeycode looks up a cached keysym-to-keycode mapping.
// Returns 0 if the keysym is not mapped.
func (x *X11InputInjector) keysymToKeycode(keysym uint32) byte {
return x.keysymMap[keysym]
}
// SetClipboard sets the X11 clipboard using xclip or xsel.
func (x *X11InputInjector) SetClipboard(text string) {
if x.clipboardTool == "" {
return
}
var cmd *exec.Cmd
if x.clipboardToolName == "xclip" {
cmd = exec.Command(x.clipboardTool, "-selection", "clipboard")
} else {
cmd = exec.Command(x.clipboardTool, "--clipboard", "--input")
}
cmd.Env = x.clipboardEnv()
cmd.Stdin = strings.NewReader(text)
if err := cmd.Run(); err != nil {
log.Debugf("set clipboard via %s: %v", x.clipboardToolName, err)
}
}
func (x *X11InputInjector) resolveClipboardTool() {
for _, name := range []string{"xclip", "xsel"} {
path, err := exec.LookPath(name)
if err == nil {
x.clipboardTool = path
x.clipboardToolName = name
log.Debugf("clipboard tool resolved to %s", path)
return
}
}
log.Debugf("no clipboard tool (xclip/xsel) found, clipboard sync disabled")
}
// GetClipboard reads the X11 clipboard using xclip or xsel.
func (x *X11InputInjector) GetClipboard() string {
if x.clipboardTool == "" {
return ""
}
var cmd *exec.Cmd
if x.clipboardToolName == "xclip" {
cmd = exec.Command(x.clipboardTool, "-selection", "clipboard", "-o")
} else {
cmd = exec.Command(x.clipboardTool, "--clipboard", "--output")
}
cmd.Env = x.clipboardEnv()
out, err := cmd.Output()
if err != nil {
log.Tracef("get clipboard via %s: %v", x.clipboardToolName, err)
return ""
}
return string(out)
}
func (x *X11InputInjector) clipboardEnv() []string {
env := []string{"DISPLAY=" + x.display}
if auth := os.Getenv("XAUTHORITY"); auth != "" {
env = append(env, "XAUTHORITY="+auth)
}
return env
}
// Close releases X11 resources.
func (x *X11InputInjector) Close() {
x.conn.Close()
}
var _ InputInjector = (*X11InputInjector)(nil)
var _ ScreenCapturer = (*X11Poller)(nil)

Some files were not shown because too many files have changed in this diff Show More