mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-01 22:56:41 +00:00
Compare commits
25 Commits
claude/rdp
...
vnc-server
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b754df1171 | ||
|
|
3098f48b25 | ||
|
|
7f023ce801 | ||
|
|
e361126515 | ||
|
|
95213f7157 | ||
|
|
2e0e3a3601 | ||
|
|
8ae8f2098f | ||
|
|
a39787d679 | ||
|
|
53b04e512a | ||
|
|
633dde8d1f | ||
|
|
7e4542adde | ||
|
|
d4c61ed38b | ||
|
|
6b540d145c | ||
|
|
08f624507d | ||
|
|
95bc01e48f | ||
|
|
0d86de47df | ||
|
|
e804a705b7 | ||
|
|
46fc8c9f65 | ||
|
|
d7ad908962 | ||
|
|
c5623307cc | ||
|
|
7f666b8022 | ||
|
|
0a30b9b275 | ||
|
|
4eed459f27 | ||
|
|
13539543af | ||
|
|
7483fec048 |
62
.github/workflows/proto-version-check.yml
vendored
Normal file
62
.github/workflows/proto-version-check.yml
vendored
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
name: Proto Version Check
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
paths:
|
||||||
|
- "**/*.pb.go"
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
check-proto-versions:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Check for proto tool version changes
|
||||||
|
uses: actions/github-script@v7
|
||||||
|
with:
|
||||||
|
script: |
|
||||||
|
const files = await github.paginate(github.rest.pulls.listFiles, {
|
||||||
|
owner: context.repo.owner,
|
||||||
|
repo: context.repo.repo,
|
||||||
|
pull_number: context.issue.number,
|
||||||
|
per_page: 100,
|
||||||
|
});
|
||||||
|
|
||||||
|
const pbFiles = files.filter(f => f.filename.endsWith('.pb.go'));
|
||||||
|
const missingPatch = pbFiles.filter(f => !f.patch).map(f => f.filename);
|
||||||
|
if (missingPatch.length > 0) {
|
||||||
|
core.setFailed(
|
||||||
|
`Cannot inspect patch data for:\n` +
|
||||||
|
missingPatch.map(f => `- ${f}`).join('\n') +
|
||||||
|
`\nThis can happen with very large PRs. Verify proto versions manually.`
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const versionPattern = /^[+-]\s*\/\/\s+protoc(?:-gen-go)?\s+v[\d.]+/;
|
||||||
|
const violations = [];
|
||||||
|
|
||||||
|
for (const file of pbFiles) {
|
||||||
|
const changed = file.patch
|
||||||
|
.split('\n')
|
||||||
|
.filter(line => versionPattern.test(line));
|
||||||
|
if (changed.length > 0) {
|
||||||
|
violations.push({
|
||||||
|
file: file.filename,
|
||||||
|
lines: changed,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (violations.length > 0) {
|
||||||
|
const details = violations.map(v =>
|
||||||
|
`${v.file}:\n${v.lines.map(l => ' ' + l).join('\n')}`
|
||||||
|
).join('\n\n');
|
||||||
|
|
||||||
|
core.setFailed(
|
||||||
|
`Proto version strings changed in generated files.\n` +
|
||||||
|
`This usually means the wrong protoc or protoc-gen-go version was used.\n` +
|
||||||
|
`Regenerate with the matching tool versions.\n\n` +
|
||||||
|
details
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
console.log('No proto version string changes detected');
|
||||||
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
|
|
||||||
env:
|
env:
|
||||||
SIGN_PIPE_VER: "v0.1.1"
|
SIGN_PIPE_VER: "v0.1.2"
|
||||||
GORELEASER_VER: "v2.14.3"
|
GORELEASER_VER: "v2.14.3"
|
||||||
PRODUCT_NAME: "NetBird"
|
PRODUCT_NAME: "NetBird"
|
||||||
COPYRIGHT: "NetBird GmbH"
|
COPYRIGHT: "NetBird GmbH"
|
||||||
|
|||||||
2
Makefile
2
Makefile
@@ -5,7 +5,7 @@ GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
|
|||||||
$(GOLANGCI_LINT):
|
$(GOLANGCI_LINT):
|
||||||
@echo "Installing golangci-lint..."
|
@echo "Installing golangci-lint..."
|
||||||
@mkdir -p ./bin
|
@mkdir -p ./bin
|
||||||
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
|
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@latest
|
||||||
|
|
||||||
# Lint only changed files (fast, for pre-push)
|
# Lint only changed files (fast, for pre-push)
|
||||||
lint: $(GOLANGCI_LINT)
|
lint: $(GOLANGCI_LINT)
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"slices"
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
@@ -15,6 +16,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/debug"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
@@ -26,6 +28,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/formatter"
|
"github.com/netbirdio/netbird/formatter"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
|
types "github.com/netbirdio/netbird/upload-server/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConnectionListener export internal Listener for mobile
|
// ConnectionListener export internal Listener for mobile
|
||||||
@@ -68,7 +71,30 @@ type Client struct {
|
|||||||
uiVersion string
|
uiVersion string
|
||||||
networkChangeListener listener.NetworkChangeListener
|
networkChangeListener listener.NetworkChangeListener
|
||||||
|
|
||||||
|
stateMu sync.RWMutex
|
||||||
connectClient *internal.ConnectClient
|
connectClient *internal.ConnectClient
|
||||||
|
config *profilemanager.Config
|
||||||
|
cacheDir string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) setState(cfg *profilemanager.Config, cacheDir string, cc *internal.ConnectClient) {
|
||||||
|
c.stateMu.Lock()
|
||||||
|
defer c.stateMu.Unlock()
|
||||||
|
c.config = cfg
|
||||||
|
c.cacheDir = cacheDir
|
||||||
|
c.connectClient = cc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) stateSnapshot() (*profilemanager.Config, string, *internal.ConnectClient) {
|
||||||
|
c.stateMu.RLock()
|
||||||
|
defer c.stateMu.RUnlock()
|
||||||
|
return c.config, c.cacheDir, c.connectClient
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) getConnectClient() *internal.ConnectClient {
|
||||||
|
c.stateMu.RLock()
|
||||||
|
defer c.stateMu.RUnlock()
|
||||||
|
return c.connectClient
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient instantiate a new Client
|
// NewClient instantiate a new Client
|
||||||
@@ -93,6 +119,7 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
|
|||||||
|
|
||||||
cfgFile := platformFiles.ConfigurationFilePath()
|
cfgFile := platformFiles.ConfigurationFilePath()
|
||||||
stateFile := platformFiles.StateFilePath()
|
stateFile := platformFiles.StateFilePath()
|
||||||
|
cacheDir := platformFiles.CacheDir()
|
||||||
|
|
||||||
log.Infof("Starting client with config: %s, state: %s", cfgFile, stateFile)
|
log.Infof("Starting client with config: %s, state: %s", cfgFile, stateFile)
|
||||||
|
|
||||||
@@ -124,8 +151,9 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
|
|||||||
|
|
||||||
// todo do not throw error in case of cancelled context
|
// todo do not throw error in case of cancelled context
|
||||||
ctx = internal.CtxInitState(ctx)
|
ctx = internal.CtxInitState(ctx)
|
||||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
|
c.setState(cfg, cacheDir, connectClient)
|
||||||
|
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
||||||
@@ -135,6 +163,7 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
|
|||||||
|
|
||||||
cfgFile := platformFiles.ConfigurationFilePath()
|
cfgFile := platformFiles.ConfigurationFilePath()
|
||||||
stateFile := platformFiles.StateFilePath()
|
stateFile := platformFiles.StateFilePath()
|
||||||
|
cacheDir := platformFiles.CacheDir()
|
||||||
|
|
||||||
log.Infof("Starting client without login with config: %s, state: %s", cfgFile, stateFile)
|
log.Infof("Starting client without login with config: %s, state: %s", cfgFile, stateFile)
|
||||||
|
|
||||||
@@ -157,8 +186,9 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
|
|||||||
|
|
||||||
// todo do not throw error in case of cancelled context
|
// todo do not throw error in case of cancelled context
|
||||||
ctx = internal.CtxInitState(ctx)
|
ctx = internal.CtxInitState(ctx)
|
||||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
|
c.setState(cfg, cacheDir, connectClient)
|
||||||
|
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop the internal client and free the resources
|
// Stop the internal client and free the resources
|
||||||
@@ -173,11 +203,12 @@ func (c *Client) Stop() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) RenewTun(fd int) error {
|
func (c *Client) RenewTun(fd int) error {
|
||||||
if c.connectClient == nil {
|
cc := c.getConnectClient()
|
||||||
|
if cc == nil {
|
||||||
return fmt.Errorf("engine not running")
|
return fmt.Errorf("engine not running")
|
||||||
}
|
}
|
||||||
|
|
||||||
e := c.connectClient.Engine()
|
e := cc.Engine()
|
||||||
if e == nil {
|
if e == nil {
|
||||||
return fmt.Errorf("engine not initialized")
|
return fmt.Errorf("engine not initialized")
|
||||||
}
|
}
|
||||||
@@ -185,6 +216,73 @@ func (c *Client) RenewTun(fd int) error {
|
|||||||
return e.RenewTun(fd)
|
return e.RenewTun(fd)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DebugBundle generates a debug bundle, uploads it, and returns the upload key.
|
||||||
|
// It works both with and without a running engine.
|
||||||
|
func (c *Client) DebugBundle(platformFiles PlatformFiles, anonymize bool) (string, error) {
|
||||||
|
cfg, cacheDir, cc := c.stateSnapshot()
|
||||||
|
|
||||||
|
// If the engine hasn't been started, load config from disk
|
||||||
|
if cfg == nil {
|
||||||
|
var err error
|
||||||
|
cfg, err = profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||||
|
ConfigPath: platformFiles.ConfigurationFilePath(),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("load config: %w", err)
|
||||||
|
}
|
||||||
|
cacheDir = platformFiles.CacheDir()
|
||||||
|
}
|
||||||
|
|
||||||
|
deps := debug.GeneratorDependencies{
|
||||||
|
InternalConfig: cfg,
|
||||||
|
StatusRecorder: c.recorder,
|
||||||
|
TempDir: cacheDir,
|
||||||
|
}
|
||||||
|
|
||||||
|
if cc != nil {
|
||||||
|
resp, err := cc.GetLatestSyncResponse()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("get latest sync response: %v", err)
|
||||||
|
}
|
||||||
|
deps.SyncResponse = resp
|
||||||
|
|
||||||
|
if e := cc.Engine(); e != nil {
|
||||||
|
if cm := e.GetClientMetrics(); cm != nil {
|
||||||
|
deps.ClientMetrics = cm
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bundleGenerator := debug.NewBundleGenerator(
|
||||||
|
deps,
|
||||||
|
debug.BundleConfig{
|
||||||
|
Anonymize: anonymize,
|
||||||
|
IncludeSystemInfo: true,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
path, err := bundleGenerator.Generate()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("generate debug bundle: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := os.Remove(path); err != nil {
|
||||||
|
log.Errorf("failed to remove debug bundle file: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
uploadCtx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
key, err := debug.UploadDebugBundle(uploadCtx, types.DefaultBundleURL, cfg.ManagementURL.String(), path)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("upload debug bundle: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("debug bundle uploaded with key %s", key)
|
||||||
|
return key, nil
|
||||||
|
}
|
||||||
|
|
||||||
// SetTraceLogLevel configure the logger to trace level
|
// SetTraceLogLevel configure the logger to trace level
|
||||||
func (c *Client) SetTraceLogLevel() {
|
func (c *Client) SetTraceLogLevel() {
|
||||||
log.SetLevel(log.TraceLevel)
|
log.SetLevel(log.TraceLevel)
|
||||||
@@ -214,12 +312,13 @@ func (c *Client) PeersList() *PeerInfoArray {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Networks() *NetworkArray {
|
func (c *Client) Networks() *NetworkArray {
|
||||||
if c.connectClient == nil {
|
cc := c.getConnectClient()
|
||||||
|
if cc == nil {
|
||||||
log.Error("not connected")
|
log.Error("not connected")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
engine := c.connectClient.Engine()
|
engine := cc.Engine()
|
||||||
if engine == nil {
|
if engine == nil {
|
||||||
log.Error("could not get engine")
|
log.Error("could not get engine")
|
||||||
return nil
|
return nil
|
||||||
@@ -300,7 +399,7 @@ func (c *Client) toggleRoute(command routeCommand) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) getRouteManager() (routemanager.Manager, error) {
|
func (c *Client) getRouteManager() (routemanager.Manager, error) {
|
||||||
client := c.connectClient
|
client := c.getConnectClient()
|
||||||
if client == nil {
|
if client == nil {
|
||||||
return nil, fmt.Errorf("not connected")
|
return nil, fmt.Errorf("not connected")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,4 +7,5 @@ package android
|
|||||||
type PlatformFiles interface {
|
type PlatformFiles interface {
|
||||||
ConfigurationFilePath() string
|
ConfigurationFilePath() string
|
||||||
StateFilePath() string
|
StateFilePath() string
|
||||||
|
CacheDir() string
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -75,6 +75,7 @@ var (
|
|||||||
mtu uint16
|
mtu uint16
|
||||||
profilesDisabled bool
|
profilesDisabled bool
|
||||||
updateSettingsDisabled bool
|
updateSettingsDisabled bool
|
||||||
|
networksDisabled bool
|
||||||
|
|
||||||
rootCmd = &cobra.Command{
|
rootCmd = &cobra.Command{
|
||||||
Use: "netbird",
|
Use: "netbird",
|
||||||
@@ -150,6 +151,7 @@ func init() {
|
|||||||
rootCmd.AddCommand(logoutCmd)
|
rootCmd.AddCommand(logoutCmd)
|
||||||
rootCmd.AddCommand(versionCmd)
|
rootCmd.AddCommand(versionCmd)
|
||||||
rootCmd.AddCommand(sshCmd)
|
rootCmd.AddCommand(sshCmd)
|
||||||
|
rootCmd.AddCommand(vncCmd)
|
||||||
rootCmd.AddCommand(networksCMD)
|
rootCmd.AddCommand(networksCMD)
|
||||||
rootCmd.AddCommand(forwardingRulesCmd)
|
rootCmd.AddCommand(forwardingRulesCmd)
|
||||||
rootCmd.AddCommand(debugCmd)
|
rootCmd.AddCommand(debugCmd)
|
||||||
|
|||||||
@@ -44,10 +44,13 @@ func init() {
|
|||||||
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd, resetParamsCmd)
|
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd, resetParamsCmd)
|
||||||
serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles")
|
serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles")
|
||||||
serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings")
|
serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings")
|
||||||
|
serviceCmd.PersistentFlags().BoolVar(&networksDisabled, "disable-networks", false, "Disables network selection. If enabled, the client will not allow listing, selecting, or deselecting networks. To persist, use: netbird service install --disable-networks")
|
||||||
|
|
||||||
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
|
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
|
||||||
serviceEnvDesc := `Sets extra environment variables for the service. ` +
|
serviceEnvDesc := `Sets extra environment variables for the service. ` +
|
||||||
`You can specify a comma-separated list of KEY=VALUE pairs. ` +
|
`You can specify a comma-separated list of KEY=VALUE pairs. ` +
|
||||||
|
`New keys are merged with previously saved env vars; existing keys are overwritten. ` +
|
||||||
|
`Use --service-env "" to clear all saved env vars. ` +
|
||||||
`E.g. --service-env NB_LOG_LEVEL=debug,CUSTOM_VAR=value`
|
`E.g. --service-env NB_LOG_LEVEL=debug,CUSTOM_VAR=value`
|
||||||
|
|
||||||
installCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)
|
installCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ func (p *program) Start(svc service.Service) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled)
|
serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled, networksDisabled)
|
||||||
if err := serverInstance.Start(); err != nil {
|
if err := serverInstance.Start(); err != nil {
|
||||||
log.Fatalf("failed to start daemon: %v", err)
|
log.Fatalf("failed to start daemon: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -59,6 +59,10 @@ func buildServiceArguments() []string {
|
|||||||
args = append(args, "--disable-update-settings")
|
args = append(args, "--disable-update-settings")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if networksDisabled {
|
||||||
|
args = append(args, "--disable-networks")
|
||||||
|
}
|
||||||
|
|
||||||
return args
|
return args
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ type serviceParams struct {
|
|||||||
LogFiles []string `json:"log_files,omitempty"`
|
LogFiles []string `json:"log_files,omitempty"`
|
||||||
DisableProfiles bool `json:"disable_profiles,omitempty"`
|
DisableProfiles bool `json:"disable_profiles,omitempty"`
|
||||||
DisableUpdateSettings bool `json:"disable_update_settings,omitempty"`
|
DisableUpdateSettings bool `json:"disable_update_settings,omitempty"`
|
||||||
|
DisableNetworks bool `json:"disable_networks,omitempty"`
|
||||||
ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"`
|
ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -78,11 +79,12 @@ func currentServiceParams() *serviceParams {
|
|||||||
LogFiles: logFiles,
|
LogFiles: logFiles,
|
||||||
DisableProfiles: profilesDisabled,
|
DisableProfiles: profilesDisabled,
|
||||||
DisableUpdateSettings: updateSettingsDisabled,
|
DisableUpdateSettings: updateSettingsDisabled,
|
||||||
|
DisableNetworks: networksDisabled,
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(serviceEnvVars) > 0 {
|
if len(serviceEnvVars) > 0 {
|
||||||
parsed, err := parseServiceEnvVars(serviceEnvVars)
|
parsed, err := parseServiceEnvVars(serviceEnvVars)
|
||||||
if err == nil && len(parsed) > 0 {
|
if err == nil {
|
||||||
params.ServiceEnvVars = parsed
|
params.ServiceEnvVars = parsed
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -142,31 +144,46 @@ func applyServiceParams(cmd *cobra.Command, params *serviceParams) {
|
|||||||
updateSettingsDisabled = params.DisableUpdateSettings
|
updateSettingsDisabled = params.DisableUpdateSettings
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !serviceCmd.PersistentFlags().Changed("disable-networks") {
|
||||||
|
networksDisabled = params.DisableNetworks
|
||||||
|
}
|
||||||
|
|
||||||
applyServiceEnvParams(cmd, params)
|
applyServiceEnvParams(cmd, params)
|
||||||
}
|
}
|
||||||
|
|
||||||
// applyServiceEnvParams merges saved service environment variables.
|
// applyServiceEnvParams merges saved service environment variables.
|
||||||
// If --service-env was explicitly set, explicit values win on key conflict
|
// If --service-env was explicitly set with values, explicit values win on key
|
||||||
// but saved keys not in the explicit set are carried over.
|
// conflict but saved keys not in the explicit set are carried over.
|
||||||
|
// If --service-env was explicitly set to empty, all saved env vars are cleared.
|
||||||
// If --service-env was not set, saved env vars are used entirely.
|
// If --service-env was not set, saved env vars are used entirely.
|
||||||
func applyServiceEnvParams(cmd *cobra.Command, params *serviceParams) {
|
func applyServiceEnvParams(cmd *cobra.Command, params *serviceParams) {
|
||||||
if len(params.ServiceEnvVars) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if !cmd.Flags().Changed("service-env") {
|
if !cmd.Flags().Changed("service-env") {
|
||||||
|
if len(params.ServiceEnvVars) > 0 {
|
||||||
// No explicit env vars: rebuild serviceEnvVars from saved params.
|
// No explicit env vars: rebuild serviceEnvVars from saved params.
|
||||||
serviceEnvVars = envMapToSlice(params.ServiceEnvVars)
|
serviceEnvVars = envMapToSlice(params.ServiceEnvVars)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Explicit env vars were provided: merge saved values underneath.
|
// Flag was explicitly set: parse what the user provided.
|
||||||
explicit, err := parseServiceEnvVars(serviceEnvVars)
|
explicit, err := parseServiceEnvVars(serviceEnvVars)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cmd.PrintErrf("Warning: parse explicit service env vars for merge: %v\n", err)
|
cmd.PrintErrf("Warning: parse explicit service env vars for merge: %v\n", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If the user passed an empty value (e.g. --service-env ""), clear all
|
||||||
|
// saved env vars rather than merging.
|
||||||
|
if len(explicit) == 0 {
|
||||||
|
serviceEnvVars = nil
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(params.ServiceEnvVars) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge saved values underneath explicit ones.
|
||||||
merged := make(map[string]string, len(params.ServiceEnvVars)+len(explicit))
|
merged := make(map[string]string, len(params.ServiceEnvVars)+len(explicit))
|
||||||
maps.Copy(merged, params.ServiceEnvVars)
|
maps.Copy(merged, params.ServiceEnvVars)
|
||||||
maps.Copy(merged, explicit) // explicit wins on conflict
|
maps.Copy(merged, explicit) // explicit wins on conflict
|
||||||
|
|||||||
@@ -327,6 +327,41 @@ func TestApplyServiceEnvParams_NotChanged(t *testing.T) {
|
|||||||
assert.Equal(t, map[string]string{"FROM_SAVED": "val"}, result)
|
assert.Equal(t, map[string]string{"FROM_SAVED": "val"}, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestApplyServiceEnvParams_ExplicitEmptyClears(t *testing.T) {
|
||||||
|
origServiceEnvVars := serviceEnvVars
|
||||||
|
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
|
||||||
|
|
||||||
|
// Simulate --service-env "" which produces [""] in the slice.
|
||||||
|
serviceEnvVars = []string{""}
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.Flags().StringSlice("service-env", nil, "")
|
||||||
|
require.NoError(t, cmd.Flags().Set("service-env", ""))
|
||||||
|
|
||||||
|
saved := &serviceParams{
|
||||||
|
ServiceEnvVars: map[string]string{"OLD_VAR": "should_be_cleared"},
|
||||||
|
}
|
||||||
|
|
||||||
|
applyServiceEnvParams(cmd, saved)
|
||||||
|
|
||||||
|
assert.Nil(t, serviceEnvVars, "explicit empty --service-env should clear all saved env vars")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCurrentServiceParams_EmptyEnvVarsAfterParse(t *testing.T) {
|
||||||
|
origServiceEnvVars := serviceEnvVars
|
||||||
|
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
|
||||||
|
|
||||||
|
// Simulate --service-env "" which produces [""] in the slice.
|
||||||
|
serviceEnvVars = []string{""}
|
||||||
|
|
||||||
|
params := currentServiceParams()
|
||||||
|
|
||||||
|
// After parsing, the empty string is skipped, resulting in an empty map.
|
||||||
|
// The map should still be set (not nil) so it overwrites saved values.
|
||||||
|
assert.NotNil(t, params.ServiceEnvVars, "empty env vars should produce empty map, not nil")
|
||||||
|
assert.Empty(t, params.ServiceEnvVars, "no valid env vars should be parsed from empty string")
|
||||||
|
}
|
||||||
|
|
||||||
// TestServiceParams_FieldsCoveredInFunctions ensures that all serviceParams fields are
|
// TestServiceParams_FieldsCoveredInFunctions ensures that all serviceParams fields are
|
||||||
// referenced in both currentServiceParams() and applyServiceParams(). If a new field is
|
// referenced in both currentServiceParams() and applyServiceParams(). If a new field is
|
||||||
// added to serviceParams but not wired into these functions, this test fails.
|
// added to serviceParams but not wired into these functions, this test fails.
|
||||||
@@ -500,6 +535,7 @@ func fieldToGlobalVar(field string) string {
|
|||||||
"LogFiles": "logFiles",
|
"LogFiles": "logFiles",
|
||||||
"DisableProfiles": "profilesDisabled",
|
"DisableProfiles": "profilesDisabled",
|
||||||
"DisableUpdateSettings": "updateSettingsDisabled",
|
"DisableUpdateSettings": "updateSettingsDisabled",
|
||||||
|
"DisableNetworks": "networksDisabled",
|
||||||
"ServiceEnvVars": "serviceEnvVars",
|
"ServiceEnvVars": "serviceEnvVars",
|
||||||
}
|
}
|
||||||
if v, ok := m[field]; ok {
|
if v, ok := m[field]; ok {
|
||||||
|
|||||||
@@ -36,6 +36,9 @@ const (
|
|||||||
enableSSHLocalPortForwardFlag = "enable-ssh-local-port-forwarding"
|
enableSSHLocalPortForwardFlag = "enable-ssh-local-port-forwarding"
|
||||||
enableSSHRemotePortForwardFlag = "enable-ssh-remote-port-forwarding"
|
enableSSHRemotePortForwardFlag = "enable-ssh-remote-port-forwarding"
|
||||||
disableSSHAuthFlag = "disable-ssh-auth"
|
disableSSHAuthFlag = "disable-ssh-auth"
|
||||||
|
jwtCacheTTLFlag = "jwt-cache-ttl"
|
||||||
|
|
||||||
|
// Alias for backward compatibility.
|
||||||
sshJWTCacheTTLFlag = "ssh-jwt-cache-ttl"
|
sshJWTCacheTTLFlag = "ssh-jwt-cache-ttl"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -61,7 +64,7 @@ var (
|
|||||||
enableSSHLocalPortForward bool
|
enableSSHLocalPortForward bool
|
||||||
enableSSHRemotePortForward bool
|
enableSSHRemotePortForward bool
|
||||||
disableSSHAuth bool
|
disableSSHAuth bool
|
||||||
sshJWTCacheTTL int
|
jwtCacheTTL int
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
@@ -71,7 +74,9 @@ func init() {
|
|||||||
upCmd.PersistentFlags().BoolVar(&enableSSHLocalPortForward, enableSSHLocalPortForwardFlag, false, "Enable local port forwarding for SSH server")
|
upCmd.PersistentFlags().BoolVar(&enableSSHLocalPortForward, enableSSHLocalPortForwardFlag, false, "Enable local port forwarding for SSH server")
|
||||||
upCmd.PersistentFlags().BoolVar(&enableSSHRemotePortForward, enableSSHRemotePortForwardFlag, false, "Enable remote port forwarding for SSH server")
|
upCmd.PersistentFlags().BoolVar(&enableSSHRemotePortForward, enableSSHRemotePortForwardFlag, false, "Enable remote port forwarding for SSH server")
|
||||||
upCmd.PersistentFlags().BoolVar(&disableSSHAuth, disableSSHAuthFlag, false, "Disable SSH authentication")
|
upCmd.PersistentFlags().BoolVar(&disableSSHAuth, disableSSHAuthFlag, false, "Disable SSH authentication")
|
||||||
upCmd.PersistentFlags().IntVar(&sshJWTCacheTTL, sshJWTCacheTTLFlag, 0, "SSH JWT token cache TTL in seconds (0=disabled)")
|
upCmd.PersistentFlags().IntVar(&jwtCacheTTL, jwtCacheTTLFlag, 0, "JWT token cache TTL in seconds (0=disabled)")
|
||||||
|
upCmd.PersistentFlags().IntVar(&jwtCacheTTL, sshJWTCacheTTLFlag, 0, "JWT token cache TTL in seconds (alias for --jwt-cache-ttl)")
|
||||||
|
_ = upCmd.PersistentFlags().MarkDeprecated(sshJWTCacheTTLFlag, "use --jwt-cache-ttl instead")
|
||||||
|
|
||||||
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", sshserver.DefaultSSHPort, "Remote SSH port")
|
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", sshserver.DefaultSSHPort, "Remote SSH port")
|
||||||
sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", sshUsernameDesc)
|
sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", sshUsernameDesc)
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
|
|
||||||
|
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||||
@@ -100,9 +102,16 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
|||||||
|
|
||||||
jobManager := job.NewJobManager(nil, store, peersmanager)
|
jobManager := job.NewJobManager(nil, store, peersmanager)
|
||||||
|
|
||||||
iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, settingsManagerMock, eventStore)
|
ctx := context.Background()
|
||||||
|
|
||||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
iv, _ := integrations.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore)
|
||||||
|
|
||||||
|
metrics, err := telemetry.NewDefaultAppMetrics(ctx)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
settingsMockManager := settings.NewMockManager(ctrl)
|
settingsMockManager := settings.NewMockManager(ctrl)
|
||||||
@@ -113,12 +122,11 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
|||||||
Return(&types.Settings{}, nil).
|
Return(&types.Settings{}, nil).
|
||||||
AnyTimes()
|
AnyTimes()
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||||
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
|
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
|
||||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config)
|
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config)
|
||||||
|
|
||||||
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
accountManager, err := mgmt.BuildManager(ctx, config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -152,7 +160,7 @@ func startClientDaemon(
|
|||||||
s := grpc.NewServer()
|
s := grpc.NewServer()
|
||||||
|
|
||||||
server := client.New(ctx,
|
server := client.New(ctx,
|
||||||
"", "", false, false)
|
"", "", false, false, false)
|
||||||
if err := server.Start(); err != nil {
|
if err := server.Start(); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -356,6 +356,9 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
|
|||||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||||
req.ServerSSHAllowed = &serverSSHAllowed
|
req.ServerSSHAllowed = &serverSSHAllowed
|
||||||
}
|
}
|
||||||
|
if cmd.Flag(serverVNCAllowedFlag).Changed {
|
||||||
|
req.ServerVNCAllowed = &serverVNCAllowed
|
||||||
|
}
|
||||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||||
req.EnableSSHRoot = &enableSSHRoot
|
req.EnableSSHRoot = &enableSSHRoot
|
||||||
}
|
}
|
||||||
@@ -371,9 +374,12 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
|
|||||||
if cmd.Flag(disableSSHAuthFlag).Changed {
|
if cmd.Flag(disableSSHAuthFlag).Changed {
|
||||||
req.DisableSSHAuth = &disableSSHAuth
|
req.DisableSSHAuth = &disableSSHAuth
|
||||||
}
|
}
|
||||||
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
if cmd.Flag(disableVNCAuthFlag).Changed {
|
||||||
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
|
req.DisableVNCAuth = &disableVNCAuth
|
||||||
req.SshJWTCacheTTL = &sshJWTCacheTTL32
|
}
|
||||||
|
if cmd.Flag(jwtCacheTTLFlag).Changed || cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||||
|
jwtCacheTTL32 := int32(jwtCacheTTL)
|
||||||
|
req.SshJWTCacheTTL = &jwtCacheTTL32
|
||||||
}
|
}
|
||||||
if cmd.Flag(interfaceNameFlag).Changed {
|
if cmd.Flag(interfaceNameFlag).Changed {
|
||||||
if err := parseInterfaceName(interfaceName); err != nil {
|
if err := parseInterfaceName(interfaceName); err != nil {
|
||||||
@@ -458,6 +464,9 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
|
|||||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||||
ic.ServerSSHAllowed = &serverSSHAllowed
|
ic.ServerSSHAllowed = &serverSSHAllowed
|
||||||
}
|
}
|
||||||
|
if cmd.Flag(serverVNCAllowedFlag).Changed {
|
||||||
|
ic.ServerVNCAllowed = &serverVNCAllowed
|
||||||
|
}
|
||||||
|
|
||||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||||
ic.EnableSSHRoot = &enableSSHRoot
|
ic.EnableSSHRoot = &enableSSHRoot
|
||||||
@@ -479,8 +488,12 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
|
|||||||
ic.DisableSSHAuth = &disableSSHAuth
|
ic.DisableSSHAuth = &disableSSHAuth
|
||||||
}
|
}
|
||||||
|
|
||||||
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
if cmd.Flag(disableVNCAuthFlag).Changed {
|
||||||
ic.SSHJWTCacheTTL = &sshJWTCacheTTL
|
ic.DisableVNCAuth = &disableVNCAuth
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(jwtCacheTTLFlag).Changed || cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||||
|
ic.SSHJWTCacheTTL = &jwtCacheTTL
|
||||||
}
|
}
|
||||||
|
|
||||||
if cmd.Flag(interfaceNameFlag).Changed {
|
if cmd.Flag(interfaceNameFlag).Changed {
|
||||||
@@ -582,6 +595,9 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
|
|||||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||||
loginRequest.ServerSSHAllowed = &serverSSHAllowed
|
loginRequest.ServerSSHAllowed = &serverSSHAllowed
|
||||||
}
|
}
|
||||||
|
if cmd.Flag(serverVNCAllowedFlag).Changed {
|
||||||
|
loginRequest.ServerVNCAllowed = &serverVNCAllowed
|
||||||
|
}
|
||||||
|
|
||||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||||
loginRequest.EnableSSHRoot = &enableSSHRoot
|
loginRequest.EnableSSHRoot = &enableSSHRoot
|
||||||
@@ -603,9 +619,13 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
|
|||||||
loginRequest.DisableSSHAuth = &disableSSHAuth
|
loginRequest.DisableSSHAuth = &disableSSHAuth
|
||||||
}
|
}
|
||||||
|
|
||||||
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
if cmd.Flag(disableVNCAuthFlag).Changed {
|
||||||
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
|
loginRequest.DisableVNCAuth = &disableVNCAuth
|
||||||
loginRequest.SshJWTCacheTTL = &sshJWTCacheTTL32
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(jwtCacheTTLFlag).Changed || cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||||
|
jwtCacheTTL32 := int32(jwtCacheTTL)
|
||||||
|
loginRequest.SshJWTCacheTTL = &jwtCacheTTL32
|
||||||
}
|
}
|
||||||
|
|
||||||
if cmd.Flag(disableAutoConnectFlag).Changed {
|
if cmd.Flag(disableAutoConnectFlag).Changed {
|
||||||
|
|||||||
271
client/cmd/vnc.go
Normal file
271
client/cmd/vnc.go
Normal file
@@ -0,0 +1,271 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"os/user"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
vncUsername string
|
||||||
|
vncHost string
|
||||||
|
vncMode string
|
||||||
|
vncListen string
|
||||||
|
vncNoBrowser bool
|
||||||
|
vncNoCache bool
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
vncCmd.PersistentFlags().StringVar(&vncUsername, "user", "", "OS username for session mode")
|
||||||
|
vncCmd.PersistentFlags().StringVar(&vncMode, "mode", "attach", "Connection mode: attach (view current display) or session (virtual desktop)")
|
||||||
|
vncCmd.PersistentFlags().StringVar(&vncListen, "listen", "", "Start local VNC proxy on this address (e.g., :5900) for external VNC viewers")
|
||||||
|
vncCmd.PersistentFlags().BoolVar(&vncNoBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||||
|
vncCmd.PersistentFlags().BoolVar(&vncNoCache, "no-cache", false, "Skip cached JWT token and force fresh authentication")
|
||||||
|
}
|
||||||
|
|
||||||
|
var vncCmd = &cobra.Command{
|
||||||
|
Use: "vnc [flags] [user@]host",
|
||||||
|
Short: "Connect to a NetBird peer via VNC",
|
||||||
|
Long: `Connect to a NetBird peer using VNC with JWT-based authentication.
|
||||||
|
The target peer must have the VNC server enabled.
|
||||||
|
|
||||||
|
Two modes are available:
|
||||||
|
- attach: view the current physical display (remote support)
|
||||||
|
- session: start a virtual desktop as the specified user (passwordless login)
|
||||||
|
|
||||||
|
Use --listen to start a local proxy for external VNC viewers:
|
||||||
|
netbird vnc --listen :5900 peer-hostname
|
||||||
|
vncviewer localhost:5900
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
netbird vnc peer-hostname
|
||||||
|
netbird vnc --mode session --user alice peer-hostname
|
||||||
|
netbird vnc --listen :5900 peer-hostname`,
|
||||||
|
Args: cobra.MinimumNArgs(1),
|
||||||
|
RunE: vncFn,
|
||||||
|
}
|
||||||
|
|
||||||
|
func vncFn(cmd *cobra.Command, args []string) error {
|
||||||
|
SetFlagsFromEnvVars(rootCmd)
|
||||||
|
SetFlagsFromEnvVars(cmd)
|
||||||
|
cmd.SetOut(cmd.OutOrStdout())
|
||||||
|
|
||||||
|
logOutput := "console"
|
||||||
|
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
|
||||||
|
logOutput = firstLogFile
|
||||||
|
}
|
||||||
|
if err := util.InitLog(logLevel, logOutput); err != nil {
|
||||||
|
return fmt.Errorf("init log: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := parseVNCHostArg(args[0]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := internal.CtxInitState(cmd.Context())
|
||||||
|
sig := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
|
||||||
|
vncCtx, cancel := context.WithCancel(ctx)
|
||||||
|
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
if err := runVNC(vncCtx, cmd); err != nil {
|
||||||
|
errCh <- err
|
||||||
|
}
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-sig:
|
||||||
|
cancel()
|
||||||
|
<-vncCtx.Done()
|
||||||
|
return nil
|
||||||
|
case err := <-errCh:
|
||||||
|
return err
|
||||||
|
case <-vncCtx.Done():
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseVNCHostArg(arg string) error {
|
||||||
|
if strings.Contains(arg, "@") {
|
||||||
|
parts := strings.SplitN(arg, "@", 2)
|
||||||
|
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
|
||||||
|
return fmt.Errorf("invalid user@host format")
|
||||||
|
}
|
||||||
|
if vncUsername == "" {
|
||||||
|
vncUsername = parts[0]
|
||||||
|
}
|
||||||
|
vncHost = parts[1]
|
||||||
|
if vncMode == "attach" {
|
||||||
|
vncMode = "session"
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
vncHost = arg
|
||||||
|
}
|
||||||
|
|
||||||
|
if vncMode == "session" && vncUsername == "" {
|
||||||
|
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
|
||||||
|
vncUsername = sudoUser
|
||||||
|
} else if currentUser, err := user.Current(); err == nil {
|
||||||
|
vncUsername = currentUser.Username
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func runVNC(ctx context.Context, cmd *cobra.Command) error {
|
||||||
|
grpcAddr := strings.TrimPrefix(daemonAddr, "tcp://")
|
||||||
|
grpcConn, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("connect to daemon: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = grpcConn.Close() }()
|
||||||
|
|
||||||
|
daemonClient := proto.NewDaemonServiceClient(grpcConn)
|
||||||
|
|
||||||
|
if vncMode == "session" {
|
||||||
|
cmd.Printf("Connecting to %s@%s [session mode]...\n", vncUsername, vncHost)
|
||||||
|
} else {
|
||||||
|
cmd.Printf("Connecting to %s [attach mode]...\n", vncHost)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Obtain JWT token. If the daemon has no SSO configured, proceed without one
|
||||||
|
// (the server will accept unauthenticated connections if --disable-vnc-auth is set).
|
||||||
|
var jwtToken string
|
||||||
|
hint := profilemanager.GetLoginHint()
|
||||||
|
var browserOpener func(string) error
|
||||||
|
if !vncNoBrowser {
|
||||||
|
browserOpener = util.OpenBrowser
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := nbssh.RequestJWTToken(ctx, daemonClient, nil, cmd.ErrOrStderr(), !vncNoCache, hint, browserOpener)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("JWT authentication unavailable, connecting without token: %v", err)
|
||||||
|
} else {
|
||||||
|
jwtToken = token
|
||||||
|
log.Debug("JWT authentication successful")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect to the VNC server on the standard port (5900). The peer's firewall
|
||||||
|
// DNATs 5900 -> 25900 (internal), so both ports work on the overlay network.
|
||||||
|
vncAddr := net.JoinHostPort(vncHost, "5900")
|
||||||
|
vncConn, err := net.DialTimeout("tcp", vncAddr, vncDialTimeout)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("connect to VNC at %s: %w", vncAddr, err)
|
||||||
|
}
|
||||||
|
defer vncConn.Close()
|
||||||
|
|
||||||
|
// Send session header with mode, username, and JWT.
|
||||||
|
if err := sendVNCHeader(vncConn, vncMode, vncUsername, jwtToken); err != nil {
|
||||||
|
return fmt.Errorf("send VNC header: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Printf("VNC connected to %s\n", vncHost)
|
||||||
|
|
||||||
|
if vncListen != "" {
|
||||||
|
return runVNCLocalProxy(ctx, cmd, vncConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// No --listen flag: inform the user they need to use --listen for external viewers.
|
||||||
|
cmd.Printf("VNC tunnel established. Use --listen :5900 to proxy for local VNC viewers.\n")
|
||||||
|
cmd.Printf("Press Ctrl+C to disconnect.\n")
|
||||||
|
<-ctx.Done()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const vncDialTimeout = 15 * time.Second
|
||||||
|
|
||||||
|
// sendVNCHeader writes the NetBird VNC session header.
|
||||||
|
func sendVNCHeader(conn net.Conn, mode, username, jwt string) error {
|
||||||
|
var modeByte byte
|
||||||
|
if mode == "session" {
|
||||||
|
modeByte = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
usernameBytes := []byte(username)
|
||||||
|
jwtBytes := []byte(jwt)
|
||||||
|
hdr := make([]byte, 3+len(usernameBytes)+2+len(jwtBytes))
|
||||||
|
hdr[0] = modeByte
|
||||||
|
binary.BigEndian.PutUint16(hdr[1:3], uint16(len(usernameBytes)))
|
||||||
|
off := 3
|
||||||
|
copy(hdr[off:], usernameBytes)
|
||||||
|
off += len(usernameBytes)
|
||||||
|
binary.BigEndian.PutUint16(hdr[off:off+2], uint16(len(jwtBytes)))
|
||||||
|
off += 2
|
||||||
|
copy(hdr[off:], jwtBytes)
|
||||||
|
|
||||||
|
_, err := conn.Write(hdr)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// runVNCLocalProxy listens on the given address and proxies incoming
|
||||||
|
// connections to the already-established VNC tunnel.
|
||||||
|
func runVNCLocalProxy(ctx context.Context, cmd *cobra.Command, vncConn net.Conn) error {
|
||||||
|
listener, err := net.Listen("tcp", vncListen)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("listen on %s: %w", vncListen, err)
|
||||||
|
}
|
||||||
|
defer listener.Close()
|
||||||
|
|
||||||
|
cmd.Printf("VNC proxy listening on %s - connect with your VNC viewer\n", listener.Addr())
|
||||||
|
cmd.Printf("Press Ctrl+C to stop.\n")
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
listener.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Accept a single viewer connection. VNC is single-session: the RFB
|
||||||
|
// handshake completes on vncConn for the first viewer, so subsequent
|
||||||
|
// viewers would get a mid-stream connection. The loop handles transient
|
||||||
|
// accept errors until a valid connection arrives.
|
||||||
|
for {
|
||||||
|
clientConn, err := listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
log.Debugf("accept VNC proxy client: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Printf("VNC viewer connected from %s\n", clientConn.RemoteAddr())
|
||||||
|
|
||||||
|
// Bidirectional copy.
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
io.Copy(vncConn, clientConn)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
io.Copy(clientConn, vncConn)
|
||||||
|
<-done
|
||||||
|
clientConn.Close()
|
||||||
|
|
||||||
|
cmd.Printf("VNC viewer disconnected\n")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
62
client/cmd/vnc_agent.go
Normal file
62
client/cmd/vnc_agent.go
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
var vncAgentPort string
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
vncAgentCmd.Flags().StringVar(&vncAgentPort, "port", "15900", "Port for the VNC agent to listen on")
|
||||||
|
rootCmd.AddCommand(vncAgentCmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
// vncAgentCmd runs a VNC server in the current user session, listening on
|
||||||
|
// localhost. It is spawned by the NetBird service (Session 0) via
|
||||||
|
// CreateProcessAsUser into the interactive console session.
|
||||||
|
var vncAgentCmd = &cobra.Command{
|
||||||
|
Use: "vnc-agent",
|
||||||
|
Short: "Run VNC capture agent (internal, spawned by service)",
|
||||||
|
Hidden: true,
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
// Agent's stderr is piped to the service which relogs it.
|
||||||
|
// Use JSON format with caller info for structured parsing.
|
||||||
|
log.SetReportCaller(true)
|
||||||
|
log.SetFormatter(&log.JSONFormatter{})
|
||||||
|
log.SetOutput(os.Stderr)
|
||||||
|
|
||||||
|
sessionID := vncserver.GetCurrentSessionID()
|
||||||
|
log.Infof("VNC agent starting on 127.0.0.1:%s (session %d)", vncAgentPort, sessionID)
|
||||||
|
|
||||||
|
capturer := vncserver.NewDesktopCapturer()
|
||||||
|
injector := vncserver.NewWindowsInputInjector()
|
||||||
|
srv := vncserver.New(capturer, injector, "")
|
||||||
|
// Auth is handled by the service. The agent verifies a token on each
|
||||||
|
// connection to ensure only the service process can connect.
|
||||||
|
// The token is passed via environment variable to avoid exposing it
|
||||||
|
// in the process command line (visible via tasklist/wmic).
|
||||||
|
srv.SetDisableAuth(true)
|
||||||
|
srv.SetAgentToken(os.Getenv("NB_VNC_AGENT_TOKEN"))
|
||||||
|
|
||||||
|
port, err := netip.ParseAddrPort("127.0.0.1:" + vncAgentPort)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
loopback := netip.PrefixFrom(netip.AddrFrom4([4]byte{127, 0, 0, 0}), 8)
|
||||||
|
if err := srv.Start(cmd.Context(), port, loopback); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
<-cmd.Context().Done()
|
||||||
|
return srv.Stop()
|
||||||
|
},
|
||||||
|
}
|
||||||
16
client/cmd/vnc_flags.go
Normal file
16
client/cmd/vnc_flags.go
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
const (
|
||||||
|
serverVNCAllowedFlag = "allow-server-vnc"
|
||||||
|
disableVNCAuthFlag = "disable-vnc-auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
serverVNCAllowed bool
|
||||||
|
disableVNCAuth bool
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
upCmd.PersistentFlags().BoolVar(&serverVNCAllowed, serverVNCAllowedFlag, false, "Allow embedded VNC server on peer")
|
||||||
|
upCmd.PersistentFlags().BoolVar(&disableVNCAuth, disableVNCAuthFlag, false, "Disable JWT authentication for VNC")
|
||||||
|
}
|
||||||
229
client/cmd/vnc_recordings.go
Normal file
229
client/cmd/vnc_recordings.go
Normal file
@@ -0,0 +1,229 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ecdh"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
"text/tabwriter"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
var vncRecDir string
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
vncRecPlayCmd.Flags().StringVar(&vncRecDir, "dir", "", "Recording directory (default: auto-detect)")
|
||||||
|
vncRecListCmd.Flags().StringVar(&vncRecDir, "dir", "", "Recording directory (default: auto-detect)")
|
||||||
|
vncRecCmd.AddCommand(vncRecListCmd)
|
||||||
|
vncRecCmd.AddCommand(vncRecPlayCmd)
|
||||||
|
vncRecCmd.AddCommand(vncRecKeygenCmd)
|
||||||
|
vncCmd.AddCommand(vncRecCmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
var vncRecCmd = &cobra.Command{
|
||||||
|
Use: "rec",
|
||||||
|
Short: "Manage VNC session recordings",
|
||||||
|
}
|
||||||
|
|
||||||
|
var vncRecKeygenCmd = &cobra.Command{
|
||||||
|
Use: "keygen",
|
||||||
|
Short: "Generate an X25519 keypair for recording encryption",
|
||||||
|
Long: `Generates an X25519 keypair. Put the public key in management settings
|
||||||
|
(Session Recording > Encryption Key). Keep the private key safe for decrypting recordings.`,
|
||||||
|
RunE: vncRecKeygenFn,
|
||||||
|
}
|
||||||
|
|
||||||
|
var vncRecListCmd = &cobra.Command{
|
||||||
|
Use: "list",
|
||||||
|
Short: "List VNC session recordings",
|
||||||
|
RunE: vncRecListFn,
|
||||||
|
}
|
||||||
|
|
||||||
|
var vncRecPlayCmd = &cobra.Command{
|
||||||
|
Use: "play <file-or-name>",
|
||||||
|
Short: "Open a VNC recording in the browser",
|
||||||
|
Long: `Opens a browser-based player with playback controls:
|
||||||
|
play/pause, seek, speed (0.25x to 8x), keyboard shortcuts.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
netbird vnc rec play last
|
||||||
|
netbird vnc rec play 20260416-104433_vnc.rec`,
|
||||||
|
Args: cobra.ExactArgs(1),
|
||||||
|
RunE: vncRecPlayFn,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
func vncRecListFn(cmd *cobra.Command, _ []string) error {
|
||||||
|
dir, err := resolveVNCRecDir()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
entries, err := os.ReadDir(dir)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read recording dir %s: %w", dir, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
w := tabwriter.NewWriter(cmd.OutOrStdout(), 0, 0, 2, ' ', 0)
|
||||||
|
fmt.Fprintln(w, "FILE\tSIZE\tDIMENSIONS\tUSER\tREMOTE\tMODE\tDATE")
|
||||||
|
|
||||||
|
for _, entry := range entries {
|
||||||
|
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".rec") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
filePath := filepath.Join(dir, entry.Name())
|
||||||
|
info, err := entry.Info()
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
header, err := vncserver.ReadRecordingHeader(filePath)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(w, "%s\t%s\t?\t?\t?\t?\t?\n", entry.Name(), vncFormatSize(info.Size()))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(w, "%s\t%s\t%dx%d\t%s\t%s\t%s\t%s\n",
|
||||||
|
entry.Name(),
|
||||||
|
vncFormatSize(info.Size()),
|
||||||
|
header.Width, header.Height,
|
||||||
|
header.Meta.User,
|
||||||
|
header.Meta.RemoteAddr,
|
||||||
|
header.Meta.Mode,
|
||||||
|
header.StartTime.Format("2006-01-02 15:04:05"),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return w.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
func vncRecPlayFn(cmd *cobra.Command, args []string) error {
|
||||||
|
filePath, err := resolveVNCRecFile(args[0])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
header, err := vncserver.ReadRecordingHeader(filePath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read recording: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Printf("Recording: %s (%dx%d)\n", filepath.Base(filePath), header.Width, header.Height)
|
||||||
|
|
||||||
|
url, err := vncserver.ServeWebPlayer(filePath, "localhost:0")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("start web player: %w", err)
|
||||||
|
}
|
||||||
|
cmd.Printf("Player: %s\n", url)
|
||||||
|
if err := util.OpenBrowser(url); err != nil {
|
||||||
|
cmd.Printf("Open %s in your browser\n", url)
|
||||||
|
}
|
||||||
|
cmd.Printf("Press Ctrl+C to stop.\n")
|
||||||
|
sig := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
<-sig
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
func vncRecKeygenFn(cmd *cobra.Command, _ []string) error {
|
||||||
|
priv, err := ecdh.X25519().GenerateKey(rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("generate key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
privB64 := base64.StdEncoding.EncodeToString(priv.Bytes())
|
||||||
|
pubB64 := base64.StdEncoding.EncodeToString(priv.PublicKey().Bytes())
|
||||||
|
|
||||||
|
cmd.Printf("Private key (keep secret, for decrypting recordings):\n %s\n\n", privB64)
|
||||||
|
cmd.Printf("Public key (paste into management Settings > Session Recording > Encryption Key):\n %s\n", pubB64)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func vncFormatSize(size int64) string {
|
||||||
|
switch {
|
||||||
|
case size >= 1<<20:
|
||||||
|
return fmt.Sprintf("%.1fM", float64(size)/float64(1<<20))
|
||||||
|
case size >= 1<<10:
|
||||||
|
return fmt.Sprintf("%.1fK", float64(size)/float64(1<<10))
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("%dB", size)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveVNCRecDir() (string, error) {
|
||||||
|
if vncRecDir != "" {
|
||||||
|
return vncRecDir, nil
|
||||||
|
}
|
||||||
|
candidates := []string{
|
||||||
|
"/var/lib/netbird/recordings/vnc",
|
||||||
|
filepath.Join(os.Getenv("HOME"), ".netbird/recordings/vnc"),
|
||||||
|
}
|
||||||
|
for _, dir := range candidates {
|
||||||
|
if fi, err := os.Stat(dir); err == nil && fi.IsDir() {
|
||||||
|
return dir, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("no VNC recording directory found; use --dir to specify")
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveVNCRecFile(arg string) (string, error) {
|
||||||
|
if strings.Contains(arg, "/") || strings.Contains(arg, string(os.PathSeparator)) {
|
||||||
|
return arg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
dir, err := resolveVNCRecDir()
|
||||||
|
if err != nil && arg != "last" {
|
||||||
|
return arg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if arg == "last" {
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return findLatestRec(dir)
|
||||||
|
}
|
||||||
|
|
||||||
|
full := filepath.Join(dir, arg)
|
||||||
|
if _, err := os.Stat(full); err == nil {
|
||||||
|
return full, nil
|
||||||
|
}
|
||||||
|
return arg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func findLatestRec(dir string) (string, error) {
|
||||||
|
entries, err := os.ReadDir(dir)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("read dir: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var latest string
|
||||||
|
var latestTime time.Time
|
||||||
|
for _, entry := range entries {
|
||||||
|
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".rec") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
info, err := entry.Info()
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if info.ModTime().After(latestTime) {
|
||||||
|
latestTime = info.ModTime()
|
||||||
|
latest = filepath.Join(dir, entry.Name())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if latest == "" {
|
||||||
|
return "", fmt.Errorf("no recordings found in %s", dir)
|
||||||
|
}
|
||||||
|
return latest, nil
|
||||||
|
}
|
||||||
@@ -56,6 +56,13 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogg
|
|||||||
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
|
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Native firewall handles packet filtering, but the userspace WireGuard bind
|
||||||
|
// needs a device filter for DNS interception hooks. Install a minimal
|
||||||
|
// hooks-only filter that passes all traffic through to the kernel firewall.
|
||||||
|
if err := iface.SetFilter(&uspfilter.HooksFilter{}); err != nil {
|
||||||
|
log.Warnf("failed to set hooks filter, DNS via memory hooks will not work: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
return fm, nil
|
return fm, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -21,6 +21,10 @@ const (
|
|||||||
|
|
||||||
// rules chains contains the effective ACL rules
|
// rules chains contains the effective ACL rules
|
||||||
chainNameInputRules = "NETBIRD-ACL-INPUT"
|
chainNameInputRules = "NETBIRD-ACL-INPUT"
|
||||||
|
|
||||||
|
// mangleFwdKey is the entries map key for mangle FORWARD guard rules that prevent
|
||||||
|
// external DNAT from bypassing ACL rules.
|
||||||
|
mangleFwdKey = "MANGLE-FORWARD"
|
||||||
)
|
)
|
||||||
|
|
||||||
type aclEntries map[string][][]string
|
type aclEntries map[string][][]string
|
||||||
@@ -274,6 +278,12 @@ func (m *aclManager) cleanChains() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, rule := range m.entries[mangleFwdKey] {
|
||||||
|
if err := m.iptablesClient.DeleteIfExists(tableMangle, chainFORWARD, rule...); err != nil {
|
||||||
|
log.Errorf("failed to delete mangle FORWARD guard rule: %v, %s", rule, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for _, ipsetName := range m.ipsetStore.ipsetNames() {
|
for _, ipsetName := range m.ipsetStore.ipsetNames() {
|
||||||
if err := m.flushIPSet(ipsetName); err != nil {
|
if err := m.flushIPSet(ipsetName); err != nil {
|
||||||
if errors.Is(err, ipset.ErrSetNotExist) {
|
if errors.Is(err, ipset.ErrSetNotExist) {
|
||||||
@@ -303,6 +313,10 @@ func (m *aclManager) createDefaultChains() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for chainName, rules := range m.entries {
|
for chainName, rules := range m.entries {
|
||||||
|
// mangle FORWARD guard rules are handled separately below
|
||||||
|
if chainName == mangleFwdKey {
|
||||||
|
continue
|
||||||
|
}
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
|
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
|
||||||
log.Debugf("failed to create input chain jump rule: %s", err)
|
log.Debugf("failed to create input chain jump rule: %s", err)
|
||||||
@@ -322,6 +336,13 @@ func (m *aclManager) createDefaultChains() error {
|
|||||||
}
|
}
|
||||||
clear(m.optionalEntries)
|
clear(m.optionalEntries)
|
||||||
|
|
||||||
|
// Insert mangle FORWARD guard rules to prevent external DNAT bypass.
|
||||||
|
for _, rule := range m.entries[mangleFwdKey] {
|
||||||
|
if err := m.iptablesClient.AppendUnique(tableMangle, chainFORWARD, rule...); err != nil {
|
||||||
|
log.Errorf("failed to add mangle FORWARD guard rule: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -343,6 +364,22 @@ func (m *aclManager) seedInitialEntries() {
|
|||||||
|
|
||||||
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", chainRTFWDOUT})
|
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", chainRTFWDOUT})
|
||||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainRTFWDIN})
|
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainRTFWDIN})
|
||||||
|
|
||||||
|
// Mangle FORWARD guard: when external DNAT redirects traffic from the wg interface, it
|
||||||
|
// traverses FORWARD instead of INPUT, bypassing ACL rules. ACCEPT rules in filter FORWARD
|
||||||
|
// can be inserted above ours. Mangle runs before filter, so these guard rules enforce the
|
||||||
|
// ACL mark check where it cannot be overridden.
|
||||||
|
m.appendToEntries(mangleFwdKey, []string{
|
||||||
|
"-i", m.wgIface.Name(),
|
||||||
|
"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED",
|
||||||
|
"-j", "ACCEPT",
|
||||||
|
})
|
||||||
|
m.appendToEntries(mangleFwdKey, []string{
|
||||||
|
"-i", m.wgIface.Name(),
|
||||||
|
"-m", "conntrack", "--ctstate", "DNAT",
|
||||||
|
"-m", "mark", "!", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
|
||||||
|
"-j", "DROP",
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *aclManager) seedInitialOptionalEntries() {
|
func (m *aclManager) seedInitialOptionalEntries() {
|
||||||
|
|||||||
37
client/firewall/uspfilter/common/hooks.go
Normal file
37
client/firewall/uspfilter/common/hooks.go
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"sync/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PacketHook stores a registered hook for a specific IP:port.
|
||||||
|
type PacketHook struct {
|
||||||
|
IP netip.Addr
|
||||||
|
Port uint16
|
||||||
|
Fn func([]byte) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// HookMatches checks if a packet's destination matches the hook and invokes it.
|
||||||
|
func HookMatches(h *PacketHook, dstIP netip.Addr, dport uint16, packetData []byte) bool {
|
||||||
|
if h == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if h.IP == dstIP && h.Port == dport {
|
||||||
|
return h.Fn(packetData)
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetHook atomically stores a hook, handling nil removal.
|
||||||
|
func SetHook(ptr *atomic.Pointer[PacketHook], ip netip.Addr, dPort uint16, hook func([]byte) bool) {
|
||||||
|
if hook == nil {
|
||||||
|
ptr.Store(nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ptr.Store(&PacketHook{
|
||||||
|
IP: ip,
|
||||||
|
Port: dPort,
|
||||||
|
Fn: hook,
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -142,15 +142,8 @@ type Manager struct {
|
|||||||
mssClampEnabled bool
|
mssClampEnabled bool
|
||||||
|
|
||||||
// Only one hook per protocol is supported. Outbound direction only.
|
// Only one hook per protocol is supported. Outbound direction only.
|
||||||
udpHookOut atomic.Pointer[packetHook]
|
udpHookOut atomic.Pointer[common.PacketHook]
|
||||||
tcpHookOut atomic.Pointer[packetHook]
|
tcpHookOut atomic.Pointer[common.PacketHook]
|
||||||
}
|
|
||||||
|
|
||||||
// packetHook stores a registered hook for a specific IP:port.
|
|
||||||
type packetHook struct {
|
|
||||||
ip netip.Addr
|
|
||||||
port uint16
|
|
||||||
fn func([]byte) bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// decoder for packages
|
// decoder for packages
|
||||||
@@ -912,21 +905,11 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
|
func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
|
||||||
return hookMatches(m.udpHookOut.Load(), dstIP, dport, packetData)
|
return common.HookMatches(m.udpHookOut.Load(), dstIP, dport, packetData)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) tcpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
|
func (m *Manager) tcpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
|
||||||
return hookMatches(m.tcpHookOut.Load(), dstIP, dport, packetData)
|
return common.HookMatches(m.tcpHookOut.Load(), dstIP, dport, packetData)
|
||||||
}
|
|
||||||
|
|
||||||
func hookMatches(h *packetHook, dstIP netip.Addr, dport uint16, packetData []byte) bool {
|
|
||||||
if h == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if h.ip == dstIP && h.port == dport {
|
|
||||||
return h.fn(packetData)
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// filterInbound implements filtering logic for incoming packets.
|
// filterInbound implements filtering logic for incoming packets.
|
||||||
@@ -1337,28 +1320,12 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
|
|||||||
|
|
||||||
// SetUDPPacketHook sets the outbound UDP packet hook. Pass nil hook to remove.
|
// SetUDPPacketHook sets the outbound UDP packet hook. Pass nil hook to remove.
|
||||||
func (m *Manager) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
|
func (m *Manager) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
|
||||||
if hook == nil {
|
common.SetHook(&m.udpHookOut, ip, dPort, hook)
|
||||||
m.udpHookOut.Store(nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
m.udpHookOut.Store(&packetHook{
|
|
||||||
ip: ip,
|
|
||||||
port: dPort,
|
|
||||||
fn: hook,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetTCPPacketHook sets the outbound TCP packet hook. Pass nil hook to remove.
|
// SetTCPPacketHook sets the outbound TCP packet hook. Pass nil hook to remove.
|
||||||
func (m *Manager) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
|
func (m *Manager) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
|
||||||
if hook == nil {
|
common.SetHook(&m.tcpHookOut, ip, dPort, hook)
|
||||||
m.tcpHookOut.Store(nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
m.tcpHookOut.Store(&packetHook{
|
|
||||||
ip: ip,
|
|
||||||
port: dPort,
|
|
||||||
fn: hook,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetLogLevel sets the log level for the firewall manager
|
// SetLogLevel sets the log level for the firewall manager
|
||||||
|
|||||||
@@ -202,9 +202,9 @@ func TestSetUDPPacketHook(t *testing.T) {
|
|||||||
|
|
||||||
h := manager.udpHookOut.Load()
|
h := manager.udpHookOut.Load()
|
||||||
require.NotNil(t, h)
|
require.NotNil(t, h)
|
||||||
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip)
|
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.IP)
|
||||||
assert.Equal(t, uint16(8000), h.port)
|
assert.Equal(t, uint16(8000), h.Port)
|
||||||
assert.True(t, h.fn(nil))
|
assert.True(t, h.Fn(nil))
|
||||||
assert.True(t, called)
|
assert.True(t, called)
|
||||||
|
|
||||||
manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, nil)
|
manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, nil)
|
||||||
@@ -226,9 +226,9 @@ func TestSetTCPPacketHook(t *testing.T) {
|
|||||||
|
|
||||||
h := manager.tcpHookOut.Load()
|
h := manager.tcpHookOut.Load()
|
||||||
require.NotNil(t, h)
|
require.NotNil(t, h)
|
||||||
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip)
|
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.IP)
|
||||||
assert.Equal(t, uint16(53), h.port)
|
assert.Equal(t, uint16(53), h.Port)
|
||||||
assert.True(t, h.fn(nil))
|
assert.True(t, h.Fn(nil))
|
||||||
assert.True(t, called)
|
assert.True(t, called)
|
||||||
|
|
||||||
manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, nil)
|
manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, nil)
|
||||||
|
|||||||
90
client/firewall/uspfilter/hooks_filter.go
Normal file
90
client/firewall/uspfilter/hooks_filter.go
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"net/netip"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ipv4HeaderMinLen = 20
|
||||||
|
ipv4ProtoOffset = 9
|
||||||
|
ipv4FlagsOffset = 6
|
||||||
|
ipv4DstOffset = 16
|
||||||
|
ipProtoUDP = 17
|
||||||
|
ipProtoTCP = 6
|
||||||
|
ipv4FragOffMask = 0x1fff
|
||||||
|
// dstPortOffset is the offset of the destination port within a UDP or TCP header.
|
||||||
|
dstPortOffset = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
// HooksFilter is a minimal packet filter that only handles outbound DNS hooks.
|
||||||
|
// It is installed on the WireGuard interface when the userspace bind is active
|
||||||
|
// but a full firewall filter (Manager) is not needed because a native kernel
|
||||||
|
// firewall (nftables/iptables) handles packet filtering.
|
||||||
|
type HooksFilter struct {
|
||||||
|
udpHook atomic.Pointer[common.PacketHook]
|
||||||
|
tcpHook atomic.Pointer[common.PacketHook]
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ device.PacketFilter = (*HooksFilter)(nil)
|
||||||
|
|
||||||
|
// FilterOutbound checks outbound packets for DNS hook matches.
|
||||||
|
// Only IPv4 packets matching the registered hook IP:port are intercepted.
|
||||||
|
// IPv6 and non-IP packets pass through unconditionally.
|
||||||
|
func (f *HooksFilter) FilterOutbound(packetData []byte, _ int) bool {
|
||||||
|
if len(packetData) < ipv4HeaderMinLen {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only process IPv4 packets, let everything else pass through.
|
||||||
|
if packetData[0]>>4 != 4 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
ihl := int(packetData[0]&0x0f) * 4
|
||||||
|
if ihl < ipv4HeaderMinLen || len(packetData) < ihl+4 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip non-first fragments: they don't carry L4 headers.
|
||||||
|
flagsAndOffset := binary.BigEndian.Uint16(packetData[ipv4FlagsOffset : ipv4FlagsOffset+2])
|
||||||
|
if flagsAndOffset&ipv4FragOffMask != 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
dstIP, ok := netip.AddrFromSlice(packetData[ipv4DstOffset : ipv4DstOffset+4])
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
proto := packetData[ipv4ProtoOffset]
|
||||||
|
dstPort := binary.BigEndian.Uint16(packetData[ihl+dstPortOffset : ihl+dstPortOffset+2])
|
||||||
|
|
||||||
|
switch proto {
|
||||||
|
case ipProtoUDP:
|
||||||
|
return common.HookMatches(f.udpHook.Load(), dstIP, dstPort, packetData)
|
||||||
|
case ipProtoTCP:
|
||||||
|
return common.HookMatches(f.tcpHook.Load(), dstIP, dstPort, packetData)
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FilterInbound allows all inbound packets (native firewall handles filtering).
|
||||||
|
func (f *HooksFilter) FilterInbound([]byte, int) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetUDPPacketHook registers the UDP packet hook.
|
||||||
|
func (f *HooksFilter) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func([]byte) bool) {
|
||||||
|
common.SetHook(&f.udpHook, ip, dPort, hook)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTCPPacketHook registers the TCP packet hook.
|
||||||
|
func (f *HooksFilter) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func([]byte) bool) {
|
||||||
|
common.SetHook(&f.tcpHook, ip, dPort, hook)
|
||||||
|
}
|
||||||
@@ -217,7 +217,6 @@ func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
|
|||||||
// Close closes the tunnel interface
|
// Close closes the tunnel interface
|
||||||
func (w *WGIface) Close() error {
|
func (w *WGIface) Close() error {
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
defer w.mu.Unlock()
|
|
||||||
|
|
||||||
var result *multierror.Error
|
var result *multierror.Error
|
||||||
|
|
||||||
@@ -225,7 +224,15 @@ func (w *WGIface) Close() error {
|
|||||||
result = multierror.Append(result, fmt.Errorf("failed to free WireGuard proxy: %w", err))
|
result = multierror.Append(result, fmt.Errorf("failed to free WireGuard proxy: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := w.tun.Close(); err != nil {
|
// Release w.mu before calling w.tun.Close(): the underlying
|
||||||
|
// wireguard-go device.Close() waits for its send/receive goroutines
|
||||||
|
// to drain. Some of those goroutines re-enter WGIface methods that
|
||||||
|
// take w.mu (e.g. the packet filter DNS hook calls GetDevice()), so
|
||||||
|
// holding the mutex here would deadlock the shutdown path.
|
||||||
|
tun := w.tun
|
||||||
|
w.mu.Unlock()
|
||||||
|
|
||||||
|
if err := tun.Close(); err != nil {
|
||||||
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
|
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
113
client/iface/iface_close_test.go
Normal file
113
client/iface/iface_close_test.go
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package iface
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
|
)
|
||||||
|
|
||||||
|
// fakeTunDevice implements WGTunDevice and lets the test control when
|
||||||
|
// Close() returns. It mimics the wireguard-go shutdown path, which blocks
|
||||||
|
// until its goroutines drain. Some of those goroutines (e.g. the packet
|
||||||
|
// filter DNS hook in client/internal/dns) call back into WGIface, so if
|
||||||
|
// WGIface.Close() held w.mu across tun.Close() the shutdown would
|
||||||
|
// deadlock.
|
||||||
|
type fakeTunDevice struct {
|
||||||
|
closeStarted chan struct{}
|
||||||
|
unblockClose chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeTunDevice) Create() (device.WGConfigurer, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
func (f *fakeTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
func (f *fakeTunDevice) UpdateAddr(wgaddr.Address) error { return nil }
|
||||||
|
func (f *fakeTunDevice) WgAddress() wgaddr.Address { return wgaddr.Address{} }
|
||||||
|
func (f *fakeTunDevice) MTU() uint16 { return DefaultMTU }
|
||||||
|
func (f *fakeTunDevice) DeviceName() string { return "nb-close-test" }
|
||||||
|
func (f *fakeTunDevice) FilteredDevice() *device.FilteredDevice { return nil }
|
||||||
|
func (f *fakeTunDevice) Device() *wgdevice.Device { return nil }
|
||||||
|
func (f *fakeTunDevice) GetNet() *netstack.Net { return nil }
|
||||||
|
func (f *fakeTunDevice) GetICEBind() device.EndpointManager { return nil }
|
||||||
|
|
||||||
|
func (f *fakeTunDevice) Close() error {
|
||||||
|
close(f.closeStarted)
|
||||||
|
<-f.unblockClose
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type fakeProxyFactory struct{}
|
||||||
|
|
||||||
|
func (fakeProxyFactory) GetProxy() wgproxy.Proxy { return nil }
|
||||||
|
func (fakeProxyFactory) GetProxyPort() uint16 { return 0 }
|
||||||
|
func (fakeProxyFactory) Free() error { return nil }
|
||||||
|
|
||||||
|
// TestWGIface_CloseReleasesMutexBeforeTunClose guards against a deadlock
|
||||||
|
// that surfaces as a macOS test-timeout in
|
||||||
|
// TestDNSPermanent_updateUpstream: WGIface.Close() used to hold w.mu
|
||||||
|
// while waiting for the wireguard-go device goroutines to finish, and
|
||||||
|
// one of those goroutines (the DNS filter hook) calls back into
|
||||||
|
// WGIface.GetDevice() which needs the same mutex. The fix is to drop
|
||||||
|
// the lock before tun.Close() returns control.
|
||||||
|
func TestWGIface_CloseReleasesMutexBeforeTunClose(t *testing.T) {
|
||||||
|
tun := &fakeTunDevice{
|
||||||
|
closeStarted: make(chan struct{}),
|
||||||
|
unblockClose: make(chan struct{}),
|
||||||
|
}
|
||||||
|
w := &WGIface{
|
||||||
|
tun: tun,
|
||||||
|
wgProxyFactory: fakeProxyFactory{},
|
||||||
|
}
|
||||||
|
|
||||||
|
closeDone := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
closeDone <- w.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-tun.closeStarted:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
close(tun.unblockClose)
|
||||||
|
t.Fatal("tun.Close() was never invoked")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simulate the WireGuard read goroutine calling back into WGIface
|
||||||
|
// via the packet filter's DNS hook. If Close() still held w.mu
|
||||||
|
// during tun.Close(), this would block until the test timeout.
|
||||||
|
getDeviceDone := make(chan struct{})
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
_ = w.GetDevice()
|
||||||
|
close(getDeviceDone)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-getDeviceDone:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
close(tun.unblockClose)
|
||||||
|
wg.Wait()
|
||||||
|
t.Fatal("GetDevice() deadlocked while WGIface.Close was closing the tun")
|
||||||
|
}
|
||||||
|
|
||||||
|
close(tun.unblockClose)
|
||||||
|
select {
|
||||||
|
case <-closeDone:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("WGIface.Close() never returned after the tun was unblocked")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -171,7 +171,7 @@ func (u *UDPConn) performFilterCheck(addr net.Addr) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if u.address.Network.Contains(a) {
|
if u.address.Network.Contains(a) {
|
||||||
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
log.Warnf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||||
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -181,7 +181,7 @@ func (u *UDPConn) performFilterCheck(addr net.Addr) error {
|
|||||||
u.addrCache.Store(addr.String(), isRouted)
|
u.addrCache.Store(addr.String(), isRouted)
|
||||||
if isRouted {
|
if isRouted {
|
||||||
// Extra log, as the error only shows up with ICE logging enabled
|
// Extra log, as the error only shows up with ICE logging enabled
|
||||||
log.Infof("Address %s is part of routed network %s, refusing to write", addr, prefix)
|
log.Infof("address %s is part of routed network %s, refusing to write", addr, prefix)
|
||||||
return fmt.Errorf("address %s is part of routed network %s, refusing to write", addr, prefix)
|
return fmt.Errorf("address %s is part of routed network %s, refusing to write", addr, prefix)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -315,6 +315,7 @@ func (a *Auth) setSystemInfoFlags(info *system.Info) {
|
|||||||
a.config.RosenpassEnabled,
|
a.config.RosenpassEnabled,
|
||||||
a.config.RosenpassPermissive,
|
a.config.RosenpassPermissive,
|
||||||
a.config.ServerSSHAllowed,
|
a.config.ServerSSHAllowed,
|
||||||
|
a.config.ServerVNCAllowed,
|
||||||
a.config.DisableClientRoutes,
|
a.config.DisableClientRoutes,
|
||||||
a.config.DisableServerRoutes,
|
a.config.DisableServerRoutes,
|
||||||
a.config.DisableDNS,
|
a.config.DisableDNS,
|
||||||
@@ -327,6 +328,7 @@ func (a *Auth) setSystemInfoFlags(info *system.Info) {
|
|||||||
a.config.EnableSSHLocalPortForwarding,
|
a.config.EnableSSHLocalPortForwarding,
|
||||||
a.config.EnableSSHRemotePortForwarding,
|
a.config.EnableSSHRemotePortForwarding,
|
||||||
a.config.DisableSSHAuth,
|
a.config.DisableSSHAuth,
|
||||||
|
a.config.DisableVNCAuth,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -94,6 +94,7 @@ func (c *ConnectClient) RunOnAndroid(
|
|||||||
dnsAddresses []netip.AddrPort,
|
dnsAddresses []netip.AddrPort,
|
||||||
dnsReadyListener dns.ReadyListener,
|
dnsReadyListener dns.ReadyListener,
|
||||||
stateFilePath string,
|
stateFilePath string,
|
||||||
|
cacheDir string,
|
||||||
) error {
|
) error {
|
||||||
// in case of non Android os these variables will be nil
|
// in case of non Android os these variables will be nil
|
||||||
mobileDependency := MobileDependency{
|
mobileDependency := MobileDependency{
|
||||||
@@ -103,6 +104,7 @@ func (c *ConnectClient) RunOnAndroid(
|
|||||||
HostDNSAddresses: dnsAddresses,
|
HostDNSAddresses: dnsAddresses,
|
||||||
DnsReadyListener: dnsReadyListener,
|
DnsReadyListener: dnsReadyListener,
|
||||||
StateFilePath: stateFilePath,
|
StateFilePath: stateFilePath,
|
||||||
|
TempDir: cacheDir,
|
||||||
}
|
}
|
||||||
return c.run(mobileDependency, nil, "")
|
return c.run(mobileDependency, nil, "")
|
||||||
}
|
}
|
||||||
@@ -338,6 +340,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
log.Error(err)
|
log.Error(err)
|
||||||
return wrapErr(err)
|
return wrapErr(err)
|
||||||
}
|
}
|
||||||
|
engineConfig.TempDir = mobileDependency.TempDir
|
||||||
|
|
||||||
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), engineConfig.MTU)
|
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), engineConfig.MTU)
|
||||||
c.statusRecorder.SetRelayMgr(relayManager)
|
c.statusRecorder.SetRelayMgr(relayManager)
|
||||||
@@ -543,11 +546,13 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
|
|||||||
RosenpassEnabled: config.RosenpassEnabled,
|
RosenpassEnabled: config.RosenpassEnabled,
|
||||||
RosenpassPermissive: config.RosenpassPermissive,
|
RosenpassPermissive: config.RosenpassPermissive,
|
||||||
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
|
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
|
||||||
|
ServerVNCAllowed: config.ServerVNCAllowed != nil && *config.ServerVNCAllowed,
|
||||||
EnableSSHRoot: config.EnableSSHRoot,
|
EnableSSHRoot: config.EnableSSHRoot,
|
||||||
EnableSSHSFTP: config.EnableSSHSFTP,
|
EnableSSHSFTP: config.EnableSSHSFTP,
|
||||||
EnableSSHLocalPortForwarding: config.EnableSSHLocalPortForwarding,
|
EnableSSHLocalPortForwarding: config.EnableSSHLocalPortForwarding,
|
||||||
EnableSSHRemotePortForwarding: config.EnableSSHRemotePortForwarding,
|
EnableSSHRemotePortForwarding: config.EnableSSHRemotePortForwarding,
|
||||||
DisableSSHAuth: config.DisableSSHAuth,
|
DisableSSHAuth: config.DisableSSHAuth,
|
||||||
|
DisableVNCAuth: config.DisableVNCAuth,
|
||||||
DNSRouteInterval: config.DNSRouteInterval,
|
DNSRouteInterval: config.DNSRouteInterval,
|
||||||
|
|
||||||
DisableClientRoutes: config.DisableClientRoutes,
|
DisableClientRoutes: config.DisableClientRoutes,
|
||||||
@@ -624,6 +629,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
|
|||||||
config.RosenpassEnabled,
|
config.RosenpassEnabled,
|
||||||
config.RosenpassPermissive,
|
config.RosenpassPermissive,
|
||||||
config.ServerSSHAllowed,
|
config.ServerSSHAllowed,
|
||||||
|
config.ServerVNCAllowed,
|
||||||
config.DisableClientRoutes,
|
config.DisableClientRoutes,
|
||||||
config.DisableServerRoutes,
|
config.DisableServerRoutes,
|
||||||
config.DisableDNS,
|
config.DisableDNS,
|
||||||
@@ -636,6 +642,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
|
|||||||
config.EnableSSHLocalPortForwarding,
|
config.EnableSSHLocalPortForwarding,
|
||||||
config.EnableSSHRemotePortForwarding,
|
config.EnableSSHRemotePortForwarding,
|
||||||
config.DisableSSHAuth,
|
config.DisableSSHAuth,
|
||||||
|
config.DisableVNCAuth,
|
||||||
)
|
)
|
||||||
return client.Login(sysInfo, pubSSHKey, config.DNSLabels)
|
return client.Login(sysInfo, pubSSHKey, config.DNSLabels)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
"runtime/pprof"
|
"runtime/pprof"
|
||||||
"slices"
|
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -31,7 +30,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/updater/installer"
|
"github.com/netbirdio/netbird/client/internal/updater/installer"
|
||||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
"github.com/netbirdio/netbird/util"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const readmeContent = `Netbird debug bundle
|
const readmeContent = `Netbird debug bundle
|
||||||
@@ -234,6 +232,7 @@ type BundleGenerator struct {
|
|||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
syncResponse *mgmProto.SyncResponse
|
syncResponse *mgmProto.SyncResponse
|
||||||
logPath string
|
logPath string
|
||||||
|
tempDir string
|
||||||
cpuProfile []byte
|
cpuProfile []byte
|
||||||
refreshStatus func() // Optional callback to refresh status before bundle generation
|
refreshStatus func() // Optional callback to refresh status before bundle generation
|
||||||
clientMetrics MetricsExporter
|
clientMetrics MetricsExporter
|
||||||
@@ -256,6 +255,7 @@ type GeneratorDependencies struct {
|
|||||||
StatusRecorder *peer.Status
|
StatusRecorder *peer.Status
|
||||||
SyncResponse *mgmProto.SyncResponse
|
SyncResponse *mgmProto.SyncResponse
|
||||||
LogPath string
|
LogPath string
|
||||||
|
TempDir string // Directory for temporary bundle zip files. If empty, os.TempDir() is used.
|
||||||
CPUProfile []byte
|
CPUProfile []byte
|
||||||
RefreshStatus func() // Optional callback to refresh status before bundle generation
|
RefreshStatus func() // Optional callback to refresh status before bundle generation
|
||||||
ClientMetrics MetricsExporter
|
ClientMetrics MetricsExporter
|
||||||
@@ -275,6 +275,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
|
|||||||
statusRecorder: deps.StatusRecorder,
|
statusRecorder: deps.StatusRecorder,
|
||||||
syncResponse: deps.SyncResponse,
|
syncResponse: deps.SyncResponse,
|
||||||
logPath: deps.LogPath,
|
logPath: deps.LogPath,
|
||||||
|
tempDir: deps.TempDir,
|
||||||
cpuProfile: deps.CPUProfile,
|
cpuProfile: deps.CPUProfile,
|
||||||
refreshStatus: deps.RefreshStatus,
|
refreshStatus: deps.RefreshStatus,
|
||||||
clientMetrics: deps.ClientMetrics,
|
clientMetrics: deps.ClientMetrics,
|
||||||
@@ -287,7 +288,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
|
|||||||
|
|
||||||
// Generate creates a debug bundle and returns the location.
|
// Generate creates a debug bundle and returns the location.
|
||||||
func (g *BundleGenerator) Generate() (resp string, err error) {
|
func (g *BundleGenerator) Generate() (resp string, err error) {
|
||||||
bundlePath, err := os.CreateTemp("", "netbird.debug.*.zip")
|
bundlePath, err := os.CreateTemp(g.tempDir, "netbird.debug.*.zip")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("create zip file: %w", err)
|
return "", fmt.Errorf("create zip file: %w", err)
|
||||||
}
|
}
|
||||||
@@ -373,15 +374,8 @@ func (g *BundleGenerator) createArchive() error {
|
|||||||
log.Errorf("failed to add wg show output: %v", err)
|
log.Errorf("failed to add wg show output: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) {
|
if err := g.addPlatformLog(); err != nil {
|
||||||
if err := g.addLogfile(); err != nil {
|
log.Errorf("failed to add logs to debug bundle: %v", err)
|
||||||
log.Errorf("failed to add log file to debug bundle: %v", err)
|
|
||||||
if err := g.trySystemdLogFallback(); err != nil {
|
|
||||||
log.Errorf("failed to add systemd logs as fallback: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if err := g.trySystemdLogFallback(); err != nil {
|
|
||||||
log.Errorf("failed to add systemd logs: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := g.addUpdateLogs(); err != nil {
|
if err := g.addUpdateLogs(); err != nil {
|
||||||
|
|||||||
41
client/internal/debug/debug_android.go
Normal file
41
client/internal/debug/debug_android.go
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
//go:build android
|
||||||
|
|
||||||
|
package debug
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os/exec"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (g *BundleGenerator) addPlatformLog() error {
|
||||||
|
cmd := exec.Command("/system/bin/logcat", "-d")
|
||||||
|
stdout, err := cmd.StdoutPipe()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("logcat stdout pipe: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := cmd.Start(); err != nil {
|
||||||
|
return fmt.Errorf("start logcat: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var logReader io.Reader = stdout
|
||||||
|
if g.anonymize {
|
||||||
|
var pw *io.PipeWriter
|
||||||
|
logReader, pw = io.Pipe()
|
||||||
|
go anonymizeLog(stdout, pw, g.anonymizer)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := g.addFileToZip(logReader, "logcat.txt"); err != nil {
|
||||||
|
return fmt.Errorf("add logcat to zip: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := cmd.Wait(); err != nil {
|
||||||
|
return fmt.Errorf("wait logcat: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("added logcat output to debug bundle")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
25
client/internal/debug/debug_nonandroid.go
Normal file
25
client/internal/debug/debug_nonandroid.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
//go:build !android
|
||||||
|
|
||||||
|
package debug
|
||||||
|
|
||||||
|
import (
|
||||||
|
"slices"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (g *BundleGenerator) addPlatformLog() error {
|
||||||
|
if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) {
|
||||||
|
if err := g.addLogfile(); err != nil {
|
||||||
|
log.Errorf("failed to add log file to debug bundle: %v", err)
|
||||||
|
if err := g.trySystemdLogFallback(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if err := g.trySystemdLogFallback(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -117,11 +117,13 @@ type EngineConfig struct {
|
|||||||
RosenpassPermissive bool
|
RosenpassPermissive bool
|
||||||
|
|
||||||
ServerSSHAllowed bool
|
ServerSSHAllowed bool
|
||||||
|
ServerVNCAllowed bool
|
||||||
EnableSSHRoot *bool
|
EnableSSHRoot *bool
|
||||||
EnableSSHSFTP *bool
|
EnableSSHSFTP *bool
|
||||||
EnableSSHLocalPortForwarding *bool
|
EnableSSHLocalPortForwarding *bool
|
||||||
EnableSSHRemotePortForwarding *bool
|
EnableSSHRemotePortForwarding *bool
|
||||||
DisableSSHAuth *bool
|
DisableSSHAuth *bool
|
||||||
|
DisableVNCAuth *bool
|
||||||
|
|
||||||
DNSRouteInterval time.Duration
|
DNSRouteInterval time.Duration
|
||||||
|
|
||||||
@@ -140,6 +142,7 @@ type EngineConfig struct {
|
|||||||
ProfileConfig *profilemanager.Config
|
ProfileConfig *profilemanager.Config
|
||||||
|
|
||||||
LogPath string
|
LogPath string
|
||||||
|
TempDir string
|
||||||
}
|
}
|
||||||
|
|
||||||
// EngineServices holds the external service dependencies required by the Engine.
|
// EngineServices holds the external service dependencies required by the Engine.
|
||||||
@@ -197,6 +200,7 @@ type Engine struct {
|
|||||||
networkMonitor *networkmonitor.NetworkMonitor
|
networkMonitor *networkmonitor.NetworkMonitor
|
||||||
|
|
||||||
sshServer sshServer
|
sshServer sshServer
|
||||||
|
vncSrv vncServer
|
||||||
|
|
||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
|
|
||||||
@@ -310,6 +314,10 @@ func (e *Engine) Stop() error {
|
|||||||
log.Warnf("failed to stop SSH server: %v", err)
|
log.Warnf("failed to stop SSH server: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := e.stopVNCServer(); err != nil {
|
||||||
|
log.Warnf("failed to stop VNC server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
e.cleanupSSHConfig()
|
e.cleanupSSHConfig()
|
||||||
|
|
||||||
if e.ingressGatewayMgr != nil {
|
if e.ingressGatewayMgr != nil {
|
||||||
@@ -997,6 +1005,7 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
|||||||
e.config.RosenpassEnabled,
|
e.config.RosenpassEnabled,
|
||||||
e.config.RosenpassPermissive,
|
e.config.RosenpassPermissive,
|
||||||
&e.config.ServerSSHAllowed,
|
&e.config.ServerSSHAllowed,
|
||||||
|
&e.config.ServerVNCAllowed,
|
||||||
e.config.DisableClientRoutes,
|
e.config.DisableClientRoutes,
|
||||||
e.config.DisableServerRoutes,
|
e.config.DisableServerRoutes,
|
||||||
e.config.DisableDNS,
|
e.config.DisableDNS,
|
||||||
@@ -1009,6 +1018,7 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
|||||||
e.config.EnableSSHLocalPortForwarding,
|
e.config.EnableSSHLocalPortForwarding,
|
||||||
e.config.EnableSSHRemotePortForwarding,
|
e.config.EnableSSHRemotePortForwarding,
|
||||||
e.config.DisableSSHAuth,
|
e.config.DisableSSHAuth,
|
||||||
|
e.config.DisableVNCAuth,
|
||||||
)
|
)
|
||||||
|
|
||||||
if err := e.mgmClient.SyncMeta(info); err != nil {
|
if err := e.mgmClient.SyncMeta(info); err != nil {
|
||||||
@@ -1036,6 +1046,10 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := e.updateVNC(conf.GetSshConfig()); err != nil {
|
||||||
|
log.Warnf("failed handling VNC server setup: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
state := e.statusRecorder.GetLocalPeerState()
|
state := e.statusRecorder.GetLocalPeerState()
|
||||||
state.IP = e.wgInterface.Address().String()
|
state.IP = e.wgInterface.Address().String()
|
||||||
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
|
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
|
||||||
@@ -1095,6 +1109,7 @@ func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobR
|
|||||||
StatusRecorder: e.statusRecorder,
|
StatusRecorder: e.statusRecorder,
|
||||||
SyncResponse: syncResponse,
|
SyncResponse: syncResponse,
|
||||||
LogPath: e.config.LogPath,
|
LogPath: e.config.LogPath,
|
||||||
|
TempDir: e.config.TempDir,
|
||||||
ClientMetrics: e.clientMetrics,
|
ClientMetrics: e.clientMetrics,
|
||||||
RefreshStatus: func() {
|
RefreshStatus: func() {
|
||||||
e.RunHealthProbes(true)
|
e.RunHealthProbes(true)
|
||||||
@@ -1137,6 +1152,7 @@ func (e *Engine) receiveManagementEvents() {
|
|||||||
e.config.RosenpassEnabled,
|
e.config.RosenpassEnabled,
|
||||||
e.config.RosenpassPermissive,
|
e.config.RosenpassPermissive,
|
||||||
&e.config.ServerSSHAllowed,
|
&e.config.ServerSSHAllowed,
|
||||||
|
&e.config.ServerVNCAllowed,
|
||||||
e.config.DisableClientRoutes,
|
e.config.DisableClientRoutes,
|
||||||
e.config.DisableServerRoutes,
|
e.config.DisableServerRoutes,
|
||||||
e.config.DisableDNS,
|
e.config.DisableDNS,
|
||||||
@@ -1149,6 +1165,7 @@ func (e *Engine) receiveManagementEvents() {
|
|||||||
e.config.EnableSSHLocalPortForwarding,
|
e.config.EnableSSHLocalPortForwarding,
|
||||||
e.config.EnableSSHRemotePortForwarding,
|
e.config.EnableSSHRemotePortForwarding,
|
||||||
e.config.DisableSSHAuth,
|
e.config.DisableSSHAuth,
|
||||||
|
e.config.DisableVNCAuth,
|
||||||
)
|
)
|
||||||
|
|
||||||
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
|
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
|
||||||
@@ -1323,6 +1340,11 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
e.updateSSHServerAuth(networkMap.GetSshAuth())
|
e.updateSSHServerAuth(networkMap.GetSshAuth())
|
||||||
|
|
||||||
|
// VNC auth: use dedicated VNCAuth if present.
|
||||||
|
if vncAuth := networkMap.GetVncAuth(); vncAuth != nil {
|
||||||
|
e.updateVNCServerAuth(vncAuth)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
|
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
|
||||||
@@ -1732,6 +1754,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
|
|||||||
e.config.RosenpassEnabled,
|
e.config.RosenpassEnabled,
|
||||||
e.config.RosenpassPermissive,
|
e.config.RosenpassPermissive,
|
||||||
&e.config.ServerSSHAllowed,
|
&e.config.ServerSSHAllowed,
|
||||||
|
&e.config.ServerVNCAllowed,
|
||||||
e.config.DisableClientRoutes,
|
e.config.DisableClientRoutes,
|
||||||
e.config.DisableServerRoutes,
|
e.config.DisableServerRoutes,
|
||||||
e.config.DisableDNS,
|
e.config.DisableDNS,
|
||||||
@@ -1744,6 +1767,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
|
|||||||
e.config.EnableSSHLocalPortForwarding,
|
e.config.EnableSSHLocalPortForwarding,
|
||||||
e.config.EnableSSHRemotePortForwarding,
|
e.config.EnableSSHRemotePortForwarding,
|
||||||
e.config.DisableSSHAuth,
|
e.config.DisableSSHAuth,
|
||||||
|
e.config.DisableVNCAuth,
|
||||||
)
|
)
|
||||||
|
|
||||||
netMap, err := e.mgmClient.GetNetworkMap(info)
|
netMap, err := e.mgmClient.GetNetworkMap(info)
|
||||||
|
|||||||
@@ -55,6 +55,7 @@ import (
|
|||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
@@ -1634,7 +1635,12 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
|||||||
peersManager := peers.NewManager(store, permissionsManager)
|
peersManager := peers.NewManager(store, permissionsManager)
|
||||||
jobManager := job.NewJobManager(nil, store, peersManager)
|
jobManager := job.NewJobManager(nil, store, peersManager)
|
||||||
|
|
||||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore)
|
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
|
||||||
|
|
||||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -1656,7 +1662,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
|||||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||||
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
||||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|||||||
309
client/internal/engine_vnc.go
Normal file
309
client/internal/engine_vnc.go
Normal file
@@ -0,0 +1,309 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
|
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||||
|
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||||
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
|
||||||
|
)
|
||||||
|
|
||||||
|
const envVNCForceRecording = "NB_VNC_FORCE_RECORDING"
|
||||||
|
|
||||||
|
const (
|
||||||
|
vncExternalPort uint16 = 5900
|
||||||
|
vncInternalPort uint16 = 25900
|
||||||
|
)
|
||||||
|
|
||||||
|
type vncServer interface {
|
||||||
|
Start(ctx context.Context, addr netip.AddrPort, network netip.Prefix) error
|
||||||
|
Stop() error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) setupVNCPortRedirection() error {
|
||||||
|
if e.firewall == nil || e.wgInterface == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
localAddr := e.wgInterface.Address().IP
|
||||||
|
if !localAddr.IsValid() {
|
||||||
|
return errors.New("invalid local NetBird address")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := e.firewall.AddInboundDNAT(localAddr, firewallManager.ProtocolTCP, vncExternalPort, vncInternalPort); err != nil {
|
||||||
|
return fmt.Errorf("add VNC port redirection: %w", err)
|
||||||
|
}
|
||||||
|
log.Infof("VNC port redirection: %s:%d -> %s:%d", localAddr, vncExternalPort, localAddr, vncInternalPort)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) cleanupVNCPortRedirection() error {
|
||||||
|
if e.firewall == nil || e.wgInterface == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
localAddr := e.wgInterface.Address().IP
|
||||||
|
if !localAddr.IsValid() {
|
||||||
|
return errors.New("invalid local NetBird address")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := e.firewall.RemoveInboundDNAT(localAddr, firewallManager.ProtocolTCP, vncExternalPort, vncInternalPort); err != nil {
|
||||||
|
return fmt.Errorf("remove VNC port redirection: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateVNC handles starting/stopping the VNC server based on the config flag.
|
||||||
|
// sshConf provides the JWT identity provider config (shared with SSH).
|
||||||
|
func (e *Engine) updateVNC(sshConf *mgmProto.SSHConfig) error {
|
||||||
|
if !e.config.ServerVNCAllowed {
|
||||||
|
if e.vncSrv != nil {
|
||||||
|
log.Info("VNC server disabled, stopping")
|
||||||
|
}
|
||||||
|
return e.stopVNCServer()
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.config.BlockInbound {
|
||||||
|
log.Info("VNC server disabled because inbound connections are blocked")
|
||||||
|
return e.stopVNCServer()
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.vncSrv != nil {
|
||||||
|
// Update JWT config on existing server in case management sent new config.
|
||||||
|
e.updateVNCServerJWT(sshConf)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return e.startVNCServer(sshConf)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) startVNCServer(sshConf *mgmProto.SSHConfig) error {
|
||||||
|
if e.wgInterface == nil {
|
||||||
|
return errors.New("wg interface not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
capturer, injector := newPlatformVNC()
|
||||||
|
if capturer == nil || injector == nil {
|
||||||
|
log.Debug("VNC server not supported on this platform")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
netbirdIP := e.wgInterface.Address().IP
|
||||||
|
|
||||||
|
srv := vncserver.New(capturer, injector, "")
|
||||||
|
if vncNeedsServiceMode() {
|
||||||
|
log.Info("VNC: running in Session 0, enabling service mode (agent proxy)")
|
||||||
|
srv.SetServiceMode(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure VNC authentication.
|
||||||
|
if e.config.DisableVNCAuth != nil && *e.config.DisableVNCAuth {
|
||||||
|
log.Info("VNC: authentication disabled by config")
|
||||||
|
srv.SetDisableAuth(true)
|
||||||
|
} else if protoJWT := sshConf.GetJwtConfig(); protoJWT != nil {
|
||||||
|
audiences := protoJWT.GetAudiences()
|
||||||
|
if len(audiences) == 0 && protoJWT.GetAudience() != "" {
|
||||||
|
audiences = []string{protoJWT.GetAudience()}
|
||||||
|
}
|
||||||
|
srv.SetJWTConfig(&vncserver.JWTConfig{
|
||||||
|
Issuer: protoJWT.GetIssuer(),
|
||||||
|
Audiences: audiences,
|
||||||
|
KeysLocation: protoJWT.GetKeysLocation(),
|
||||||
|
MaxTokenAge: protoJWT.GetMaxTokenAge(),
|
||||||
|
})
|
||||||
|
log.Debugf("VNC: JWT authentication configured (issuer=%s)", protoJWT.GetIssuer())
|
||||||
|
}
|
||||||
|
|
||||||
|
e.configureVNCRecording(srv, sshConf)
|
||||||
|
|
||||||
|
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||||
|
srv.SetNetstackNet(netstackNet)
|
||||||
|
}
|
||||||
|
|
||||||
|
listenAddr := netip.AddrPortFrom(netbirdIP, vncInternalPort)
|
||||||
|
network := e.wgInterface.Address().Network
|
||||||
|
if err := srv.Start(e.ctx, listenAddr, network); err != nil {
|
||||||
|
return fmt.Errorf("start VNC server: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
e.vncSrv = srv
|
||||||
|
|
||||||
|
if registrar, ok := e.firewall.(interface {
|
||||||
|
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||||
|
}); ok {
|
||||||
|
registrar.RegisterNetstackService(nftypes.TCP, vncInternalPort)
|
||||||
|
log.Debugf("registered VNC service for TCP:%d", vncInternalPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := e.setupVNCPortRedirection(); err != nil {
|
||||||
|
log.Warnf("setup VNC port redirection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("VNC server enabled")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// configureVNCRecording enables session recording on the VNC server from the
|
||||||
|
// management-supplied settings. The env var NB_VNC_FORCE_RECORDING overrides
|
||||||
|
// the API for local development: when set, recording is always enabled and
|
||||||
|
// writes into that directory. Otherwise recordings go next to the state file
|
||||||
|
// under vnc-recordings/.
|
||||||
|
func (e *Engine) configureVNCRecording(srv *vncserver.Server, sshConf *mgmProto.SSHConfig) {
|
||||||
|
recDir := os.Getenv(envVNCForceRecording)
|
||||||
|
apiEnabled := sshConf.GetEnableRecording()
|
||||||
|
|
||||||
|
if recDir == "" && !apiEnabled {
|
||||||
|
log.Debugf("VNC recording disabled (env=%q, api=%v)", recDir, apiEnabled)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if recDir == "" {
|
||||||
|
base := e.defaultRecordingBase()
|
||||||
|
if base == "" {
|
||||||
|
log.Warn("VNC recording requested by management but no state directory is available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
recDir = filepath.Join(base, "vnc-recordings")
|
||||||
|
} else {
|
||||||
|
recDir = filepath.Join(recDir, "vnc")
|
||||||
|
}
|
||||||
|
|
||||||
|
srv.SetRecordingDir(recDir)
|
||||||
|
log.Infof("VNC recording enabled (dir=%s, source=%s)", recDir, recordingSource(apiEnabled))
|
||||||
|
|
||||||
|
encKey := string(sshConf.GetRecordingEncryptionKey())
|
||||||
|
if encKey == "" {
|
||||||
|
encKey = os.Getenv("NB_VNC_RECORDING_ENCRYPTION_KEY")
|
||||||
|
}
|
||||||
|
if encKey != "" {
|
||||||
|
srv.SetRecordingEncryptionKey(encKey)
|
||||||
|
log.Info("VNC recording encryption enabled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) defaultRecordingBase() string {
|
||||||
|
if e.stateManager == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
p := e.stateManager.FilePath()
|
||||||
|
if p == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return filepath.Dir(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func recordingSource(api bool) string {
|
||||||
|
if api {
|
||||||
|
return "management"
|
||||||
|
}
|
||||||
|
return "env"
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateVNCServerJWT configures the JWT validation for the VNC server using
|
||||||
|
// the same JWT config as SSH (same identity provider).
|
||||||
|
func (e *Engine) updateVNCServerJWT(sshConf *mgmProto.SSHConfig) {
|
||||||
|
if e.vncSrv == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
vncSrv, ok := e.vncSrv.(*vncserver.Server)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.config.DisableVNCAuth != nil && *e.config.DisableVNCAuth {
|
||||||
|
vncSrv.SetDisableAuth(true)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
protoJWT := sshConf.GetJwtConfig()
|
||||||
|
if protoJWT == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
audiences := protoJWT.GetAudiences()
|
||||||
|
if len(audiences) == 0 && protoJWT.GetAudience() != "" {
|
||||||
|
audiences = []string{protoJWT.GetAudience()}
|
||||||
|
}
|
||||||
|
|
||||||
|
vncSrv.SetJWTConfig(&vncserver.JWTConfig{
|
||||||
|
Issuer: protoJWT.GetIssuer(),
|
||||||
|
Audiences: audiences,
|
||||||
|
KeysLocation: protoJWT.GetKeysLocation(),
|
||||||
|
MaxTokenAge: protoJWT.GetMaxTokenAge(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateVNCServerAuth updates VNC fine-grained access control from management.
|
||||||
|
func (e *Engine) updateVNCServerAuth(vncAuth *mgmProto.VNCAuth) {
|
||||||
|
if vncAuth == nil || e.vncSrv == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
vncSrv, ok := e.vncSrv.(*vncserver.Server)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
protoUsers := vncAuth.GetAuthorizedUsers()
|
||||||
|
authorizedUsers := make([]sshuserhash.UserIDHash, len(protoUsers))
|
||||||
|
for i, hash := range protoUsers {
|
||||||
|
if len(hash) != 16 {
|
||||||
|
log.Warnf("invalid VNC auth hash length %d, expected 16", len(hash))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
authorizedUsers[i] = sshuserhash.UserIDHash(hash)
|
||||||
|
}
|
||||||
|
|
||||||
|
machineUsers := make(map[string][]uint32)
|
||||||
|
for osUser, indexes := range vncAuth.GetMachineUsers() {
|
||||||
|
machineUsers[osUser] = indexes.GetIndexes()
|
||||||
|
}
|
||||||
|
|
||||||
|
vncSrv.UpdateVNCAuth(&sshauth.Config{
|
||||||
|
UserIDClaim: vncAuth.GetUserIDClaim(),
|
||||||
|
AuthorizedUsers: authorizedUsers,
|
||||||
|
MachineUsers: machineUsers,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetVNCServerStatus returns whether the VNC server is running.
|
||||||
|
func (e *Engine) GetVNCServerStatus() bool {
|
||||||
|
return e.vncSrv != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) stopVNCServer() error {
|
||||||
|
if e.vncSrv == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := e.cleanupVNCPortRedirection(); err != nil {
|
||||||
|
log.Warnf("cleanup VNC port redirection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if registrar, ok := e.firewall.(interface {
|
||||||
|
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||||
|
}); ok {
|
||||||
|
registrar.UnregisterNetstackService(nftypes.TCP, vncInternalPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("stopping VNC server")
|
||||||
|
err := e.vncSrv.Stop()
|
||||||
|
e.vncSrv = nil
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("stop VNC server: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
23
client/internal/engine_vnc_darwin.go
Normal file
23
client/internal/engine_vnc_darwin.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
//go:build darwin && !ios
|
||||||
|
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector) {
|
||||||
|
capturer := vncserver.NewMacPoller()
|
||||||
|
injector, err := vncserver.NewMacInputInjector()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("VNC: macOS input injector: %v", err)
|
||||||
|
return capturer, &vncserver.StubInputInjector{}
|
||||||
|
}
|
||||||
|
return capturer, injector
|
||||||
|
}
|
||||||
|
|
||||||
|
func vncNeedsServiceMode() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
13
client/internal/engine_vnc_stub.go
Normal file
13
client/internal/engine_vnc_stub.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
//go:build !windows && !darwin && !freebsd && !(linux && !android)
|
||||||
|
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||||
|
|
||||||
|
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func vncNeedsServiceMode() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
13
client/internal/engine_vnc_windows.go
Normal file
13
client/internal/engine_vnc_windows.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||||
|
|
||||||
|
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector) {
|
||||||
|
return vncserver.NewDesktopCapturer(), vncserver.NewWindowsInputInjector()
|
||||||
|
}
|
||||||
|
|
||||||
|
func vncNeedsServiceMode() bool {
|
||||||
|
return vncserver.GetCurrentSessionID() == 0
|
||||||
|
}
|
||||||
23
client/internal/engine_vnc_x11.go
Normal file
23
client/internal/engine_vnc_x11.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
//go:build (linux && !android) || freebsd
|
||||||
|
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector) {
|
||||||
|
capturer := vncserver.NewX11Poller("")
|
||||||
|
injector, err := vncserver.NewX11InputInjector("")
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("VNC: X11 input injector: %v", err)
|
||||||
|
return capturer, &vncserver.StubInputInjector{}
|
||||||
|
}
|
||||||
|
return capturer, injector
|
||||||
|
}
|
||||||
|
|
||||||
|
func vncNeedsServiceMode() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
@@ -22,4 +22,8 @@ type MobileDependency struct {
|
|||||||
DnsManager dns.IosDnsManager
|
DnsManager dns.IosDnsManager
|
||||||
FileDescriptor int32
|
FileDescriptor int32
|
||||||
StateFilePath string
|
StateFilePath string
|
||||||
|
|
||||||
|
// TempDir is a writable directory for temporary files (e.g., debug bundle zip).
|
||||||
|
// On Android, this should be set to the app's cache directory.
|
||||||
|
TempDir string
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,7 +7,9 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/cenkalti/backoff/v4"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
nfct "github.com/ti-mo/conntrack"
|
nfct "github.com/ti-mo/conntrack"
|
||||||
@@ -17,31 +19,64 @@ import (
|
|||||||
nbnet "github.com/netbirdio/netbird/client/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
const defaultChannelSize = 100
|
const (
|
||||||
|
defaultChannelSize = 100
|
||||||
|
reconnectInitInterval = 5 * time.Second
|
||||||
|
reconnectMaxInterval = 5 * time.Minute
|
||||||
|
reconnectRandomization = 0.5
|
||||||
|
)
|
||||||
|
|
||||||
|
// listener abstracts a netlink conntrack connection for testability.
|
||||||
|
type listener interface {
|
||||||
|
Listen(evChan chan<- nfct.Event, numWorkers uint8, groups []netfilter.NetlinkGroup) (chan error, error)
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
// ConnTrack manages kernel-based conntrack events
|
// ConnTrack manages kernel-based conntrack events
|
||||||
type ConnTrack struct {
|
type ConnTrack struct {
|
||||||
flowLogger nftypes.FlowLogger
|
flowLogger nftypes.FlowLogger
|
||||||
iface nftypes.IFaceMapper
|
iface nftypes.IFaceMapper
|
||||||
|
|
||||||
conn *nfct.Conn
|
conn listener
|
||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
|
|
||||||
|
dial func() (listener, error)
|
||||||
instanceID uuid.UUID
|
instanceID uuid.UUID
|
||||||
started bool
|
started bool
|
||||||
done chan struct{}
|
done chan struct{}
|
||||||
sysctlModified bool
|
sysctlModified bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DialFunc is a constructor for netlink conntrack connections.
|
||||||
|
type DialFunc func() (listener, error)
|
||||||
|
|
||||||
|
// Option configures a ConnTrack instance.
|
||||||
|
type Option func(*ConnTrack)
|
||||||
|
|
||||||
|
// WithDialer overrides the default netlink dialer, primarily for testing.
|
||||||
|
func WithDialer(dial DialFunc) Option {
|
||||||
|
return func(c *ConnTrack) {
|
||||||
|
c.dial = dial
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultDial() (listener, error) {
|
||||||
|
return nfct.Dial(nil)
|
||||||
|
}
|
||||||
|
|
||||||
// New creates a new connection tracker that interfaces with the kernel's conntrack system
|
// New creates a new connection tracker that interfaces with the kernel's conntrack system
|
||||||
func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper) *ConnTrack {
|
func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper, opts ...Option) *ConnTrack {
|
||||||
return &ConnTrack{
|
ct := &ConnTrack{
|
||||||
flowLogger: flowLogger,
|
flowLogger: flowLogger,
|
||||||
iface: iface,
|
iface: iface,
|
||||||
instanceID: uuid.New(),
|
instanceID: uuid.New(),
|
||||||
started: false,
|
dial: defaultDial,
|
||||||
done: make(chan struct{}, 1),
|
done: make(chan struct{}, 1),
|
||||||
}
|
}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(ct)
|
||||||
|
}
|
||||||
|
return ct
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start begins tracking connections by listening for conntrack events. This method is idempotent.
|
// Start begins tracking connections by listening for conntrack events. This method is idempotent.
|
||||||
@@ -59,8 +94,9 @@ func (c *ConnTrack) Start(enableCounters bool) error {
|
|||||||
c.EnableAccounting()
|
c.EnableAccounting()
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := nfct.Dial(nil)
|
conn, err := c.dial()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
c.RestoreAccounting()
|
||||||
return fmt.Errorf("dial conntrack: %w", err)
|
return fmt.Errorf("dial conntrack: %w", err)
|
||||||
}
|
}
|
||||||
c.conn = conn
|
c.conn = conn
|
||||||
@@ -76,9 +112,16 @@ func (c *ConnTrack) Start(enableCounters bool) error {
|
|||||||
log.Errorf("Error closing conntrack connection: %v", err)
|
log.Errorf("Error closing conntrack connection: %v", err)
|
||||||
}
|
}
|
||||||
c.conn = nil
|
c.conn = nil
|
||||||
|
c.RestoreAccounting()
|
||||||
return fmt.Errorf("start conntrack listener: %w", err)
|
return fmt.Errorf("start conntrack listener: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Drain any stale stop signal from a previous cycle.
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
c.started = true
|
c.started = true
|
||||||
|
|
||||||
go c.receiverRoutine(events, errChan)
|
go c.receiverRoutine(events, errChan)
|
||||||
@@ -92,17 +135,98 @@ func (c *ConnTrack) receiverRoutine(events chan nfct.Event, errChan chan error)
|
|||||||
case event := <-events:
|
case event := <-events:
|
||||||
c.handleEvent(event)
|
c.handleEvent(event)
|
||||||
case err := <-errChan:
|
case err := <-errChan:
|
||||||
log.Errorf("Error from conntrack event listener: %v", err)
|
if events, errChan = c.handleListenerError(err); events == nil {
|
||||||
if err := c.conn.Close(); err != nil {
|
|
||||||
log.Errorf("Error closing conntrack connection: %v", err)
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
|
}
|
||||||
case <-c.done:
|
case <-c.done:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handleListenerError closes the failed connection and attempts to reconnect.
|
||||||
|
// Returns new channels on success, or nil if shutdown was requested.
|
||||||
|
func (c *ConnTrack) handleListenerError(err error) (chan nfct.Event, chan error) {
|
||||||
|
log.Warnf("conntrack event listener failed: %v", err)
|
||||||
|
c.closeConn()
|
||||||
|
return c.reconnect()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ConnTrack) closeConn() {
|
||||||
|
c.mux.Lock()
|
||||||
|
defer c.mux.Unlock()
|
||||||
|
|
||||||
|
if c.conn != nil {
|
||||||
|
if err := c.conn.Close(); err != nil {
|
||||||
|
log.Debugf("close conntrack connection: %v", err)
|
||||||
|
}
|
||||||
|
c.conn = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// reconnect attempts to re-establish the conntrack netlink listener with exponential backoff.
|
||||||
|
// Returns new channels on success, or nil if shutdown was requested.
|
||||||
|
func (c *ConnTrack) reconnect() (chan nfct.Event, chan error) {
|
||||||
|
bo := &backoff.ExponentialBackOff{
|
||||||
|
InitialInterval: reconnectInitInterval,
|
||||||
|
RandomizationFactor: reconnectRandomization,
|
||||||
|
Multiplier: backoff.DefaultMultiplier,
|
||||||
|
MaxInterval: reconnectMaxInterval,
|
||||||
|
MaxElapsedTime: 0, // retry indefinitely
|
||||||
|
Clock: backoff.SystemClock,
|
||||||
|
}
|
||||||
|
bo.Reset()
|
||||||
|
|
||||||
|
for {
|
||||||
|
delay := bo.NextBackOff()
|
||||||
|
log.Infof("reconnecting conntrack listener in %s", delay)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
c.mux.Lock()
|
||||||
|
c.started = false
|
||||||
|
c.mux.Unlock()
|
||||||
|
return nil, nil
|
||||||
|
case <-time.After(delay):
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := c.dial()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("reconnect conntrack dial: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
events := make(chan nfct.Event, defaultChannelSize)
|
||||||
|
errChan, err := conn.Listen(events, 1, []netfilter.NetlinkGroup{
|
||||||
|
netfilter.GroupCTNew,
|
||||||
|
netfilter.GroupCTDestroy,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("reconnect conntrack listen: %v", err)
|
||||||
|
if closeErr := conn.Close(); closeErr != nil {
|
||||||
|
log.Debugf("close conntrack connection: %v", closeErr)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
c.mux.Lock()
|
||||||
|
if !c.started {
|
||||||
|
// Stop() ran while we were reconnecting.
|
||||||
|
c.mux.Unlock()
|
||||||
|
if closeErr := conn.Close(); closeErr != nil {
|
||||||
|
log.Debugf("close conntrack connection: %v", closeErr)
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
c.conn = conn
|
||||||
|
c.mux.Unlock()
|
||||||
|
|
||||||
|
log.Infof("conntrack listener reconnected successfully")
|
||||||
|
|
||||||
|
return events, errChan
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Stop stops the connection tracking. This method is idempotent.
|
// Stop stops the connection tracking. This method is idempotent.
|
||||||
func (c *ConnTrack) Stop() {
|
func (c *ConnTrack) Stop() {
|
||||||
c.mux.Lock()
|
c.mux.Lock()
|
||||||
@@ -136,23 +260,27 @@ func (c *ConnTrack) Close() error {
|
|||||||
c.mux.Lock()
|
c.mux.Lock()
|
||||||
defer c.mux.Unlock()
|
defer c.mux.Unlock()
|
||||||
|
|
||||||
if c.started {
|
if !c.started {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case c.done <- struct{}{}:
|
case c.done <- struct{}{}:
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if c.conn != nil {
|
|
||||||
err := c.conn.Close()
|
|
||||||
c.conn = nil
|
|
||||||
c.started = false
|
c.started = false
|
||||||
|
|
||||||
|
var closeErr error
|
||||||
|
if c.conn != nil {
|
||||||
|
closeErr = c.conn.Close()
|
||||||
|
c.conn = nil
|
||||||
|
}
|
||||||
|
|
||||||
c.RestoreAccounting()
|
c.RestoreAccounting()
|
||||||
|
|
||||||
if err != nil {
|
if closeErr != nil {
|
||||||
return fmt.Errorf("close conntrack: %w", err)
|
return fmt.Errorf("close conntrack: %w", closeErr)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
224
client/internal/netflow/conntrack/conntrack_test.go
Normal file
224
client/internal/netflow/conntrack/conntrack_test.go
Normal file
@@ -0,0 +1,224 @@
|
|||||||
|
//go:build linux && !android
|
||||||
|
|
||||||
|
package conntrack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
nfct "github.com/ti-mo/conntrack"
|
||||||
|
"github.com/ti-mo/netfilter"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockListener struct {
|
||||||
|
errChan chan error
|
||||||
|
closed atomic.Bool
|
||||||
|
closedCh chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMockListener() *mockListener {
|
||||||
|
return &mockListener{
|
||||||
|
errChan: make(chan error, 1),
|
||||||
|
closedCh: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockListener) Listen(evChan chan<- nfct.Event, _ uint8, _ []netfilter.NetlinkGroup) (chan error, error) {
|
||||||
|
return m.errChan, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockListener) Close() error {
|
||||||
|
if m.closed.CompareAndSwap(false, true) {
|
||||||
|
close(m.closedCh)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReconnectAfterError(t *testing.T) {
|
||||||
|
first := newMockListener()
|
||||||
|
second := newMockListener()
|
||||||
|
third := newMockListener()
|
||||||
|
listeners := []*mockListener{first, second, third}
|
||||||
|
callCount := atomic.Int32{}
|
||||||
|
|
||||||
|
ct := New(nil, nil, WithDialer(func() (listener, error) {
|
||||||
|
n := int(callCount.Add(1)) - 1
|
||||||
|
return listeners[n], nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
err := ct.Start(false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Inject an error on the first listener.
|
||||||
|
first.errChan <- assert.AnError
|
||||||
|
|
||||||
|
// Wait for reconnect to complete.
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
return callCount.Load() >= 2
|
||||||
|
}, 15*time.Second, 100*time.Millisecond, "reconnect should dial a new connection")
|
||||||
|
|
||||||
|
// The first connection must have been closed.
|
||||||
|
select {
|
||||||
|
case <-first.closedCh:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("first connection was not closed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the receiver is still running by injecting and handling a second error.
|
||||||
|
second.errChan <- assert.AnError
|
||||||
|
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
return callCount.Load() >= 3
|
||||||
|
}, 15*time.Second, 100*time.Millisecond, "second reconnect should succeed")
|
||||||
|
|
||||||
|
ct.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStopDuringReconnectBackoff(t *testing.T) {
|
||||||
|
mock := newMockListener()
|
||||||
|
|
||||||
|
ct := New(nil, nil, WithDialer(func() (listener, error) {
|
||||||
|
return mock, nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
err := ct.Start(false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Trigger an error so the receiver enters reconnect.
|
||||||
|
mock.errChan <- assert.AnError
|
||||||
|
|
||||||
|
// Wait for the error handler to close the old listener before calling Stop.
|
||||||
|
select {
|
||||||
|
case <-mock.closedCh:
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("timed out waiting for reconnect to start")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop while reconnecting.
|
||||||
|
ct.Stop()
|
||||||
|
|
||||||
|
ct.mux.Lock()
|
||||||
|
assert.False(t, ct.started, "started should be false after Stop")
|
||||||
|
assert.Nil(t, ct.conn, "conn should be nil after Stop")
|
||||||
|
ct.mux.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStopRaceWithReconnectDial(t *testing.T) {
|
||||||
|
first := newMockListener()
|
||||||
|
dialStarted := make(chan struct{})
|
||||||
|
dialProceed := make(chan struct{})
|
||||||
|
second := newMockListener()
|
||||||
|
callCount := atomic.Int32{}
|
||||||
|
|
||||||
|
ct := New(nil, nil, WithDialer(func() (listener, error) {
|
||||||
|
n := callCount.Add(1)
|
||||||
|
if n == 1 {
|
||||||
|
return first, nil
|
||||||
|
}
|
||||||
|
// Second dial: signal that we're in progress, wait for test to call Stop.
|
||||||
|
close(dialStarted)
|
||||||
|
<-dialProceed
|
||||||
|
return second, nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
err := ct.Start(false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Trigger error to enter reconnect.
|
||||||
|
first.errChan <- assert.AnError
|
||||||
|
|
||||||
|
// Wait for reconnect's second dial to begin.
|
||||||
|
select {
|
||||||
|
case <-dialStarted:
|
||||||
|
case <-time.After(15 * time.Second):
|
||||||
|
t.Fatal("timed out waiting for reconnect dial")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop while dial is in progress (conn is nil at this point).
|
||||||
|
ct.Stop()
|
||||||
|
|
||||||
|
// Let the dial complete. reconnect should detect started==false and close the new conn.
|
||||||
|
close(dialProceed)
|
||||||
|
|
||||||
|
// The second connection should be closed (not leaked).
|
||||||
|
select {
|
||||||
|
case <-second.closedCh:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("second connection was leaked after Stop")
|
||||||
|
}
|
||||||
|
|
||||||
|
ct.mux.Lock()
|
||||||
|
assert.False(t, ct.started)
|
||||||
|
assert.Nil(t, ct.conn)
|
||||||
|
ct.mux.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCloseRaceWithReconnectDial(t *testing.T) {
|
||||||
|
first := newMockListener()
|
||||||
|
dialStarted := make(chan struct{})
|
||||||
|
dialProceed := make(chan struct{})
|
||||||
|
second := newMockListener()
|
||||||
|
callCount := atomic.Int32{}
|
||||||
|
|
||||||
|
ct := New(nil, nil, WithDialer(func() (listener, error) {
|
||||||
|
n := callCount.Add(1)
|
||||||
|
if n == 1 {
|
||||||
|
return first, nil
|
||||||
|
}
|
||||||
|
close(dialStarted)
|
||||||
|
<-dialProceed
|
||||||
|
return second, nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
err := ct.Start(false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
first.errChan <- assert.AnError
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-dialStarted:
|
||||||
|
case <-time.After(15 * time.Second):
|
||||||
|
t.Fatal("timed out waiting for reconnect dial")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close while dial is in progress (conn is nil).
|
||||||
|
require.NoError(t, ct.Close())
|
||||||
|
|
||||||
|
close(dialProceed)
|
||||||
|
|
||||||
|
// The second connection should be closed (not leaked).
|
||||||
|
select {
|
||||||
|
case <-second.closedCh:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("second connection was leaked after Close")
|
||||||
|
}
|
||||||
|
|
||||||
|
ct.mux.Lock()
|
||||||
|
assert.False(t, ct.started)
|
||||||
|
assert.Nil(t, ct.conn)
|
||||||
|
ct.mux.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStartIsIdempotent(t *testing.T) {
|
||||||
|
mock := newMockListener()
|
||||||
|
callCount := atomic.Int32{}
|
||||||
|
|
||||||
|
ct := New(nil, nil, WithDialer(func() (listener, error) {
|
||||||
|
callCount.Add(1)
|
||||||
|
return mock, nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
err := ct.Start(false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Second Start should be a no-op.
|
||||||
|
err = ct.Start(false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, int32(1), callCount.Load(), "dial should only be called once")
|
||||||
|
|
||||||
|
ct.Stop()
|
||||||
|
}
|
||||||
@@ -9,17 +9,26 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
envDisableNATMapper = "NB_DISABLE_NAT_MAPPER"
|
envDisableNATMapper = "NB_DISABLE_NAT_MAPPER"
|
||||||
|
envDisablePCPHealthCheck = "NB_DISABLE_PCP_HEALTH_CHECK"
|
||||||
)
|
)
|
||||||
|
|
||||||
func isDisabledByEnv() bool {
|
func isDisabledByEnv() bool {
|
||||||
val := os.Getenv(envDisableNATMapper)
|
return parseBoolEnv(envDisableNATMapper)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isHealthCheckDisabled() bool {
|
||||||
|
return parseBoolEnv(envDisablePCPHealthCheck)
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseBoolEnv(key string) bool {
|
||||||
|
val := os.Getenv(key)
|
||||||
if val == "" {
|
if val == "" {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
disabled, err := strconv.ParseBool(val)
|
disabled, err := strconv.ParseBool(val)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to parse %s: %v", envDisableNATMapper, err)
|
log.Warnf("failed to parse %s: %v", key, err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return disabled
|
return disabled
|
||||||
|
|||||||
@@ -12,10 +12,13 @@ import (
|
|||||||
|
|
||||||
"github.com/libp2p/go-nat"
|
"github.com/libp2p/go-nat"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/portforward/pcp"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
defaultMappingTTL = 2 * time.Hour
|
defaultMappingTTL = 2 * time.Hour
|
||||||
|
healthCheckInterval = 1 * time.Minute
|
||||||
discoveryTimeout = 10 * time.Second
|
discoveryTimeout = 10 * time.Second
|
||||||
mappingDescription = "NetBird"
|
mappingDescription = "NetBird"
|
||||||
)
|
)
|
||||||
@@ -154,7 +157,7 @@ func (m *Manager) setup(ctx context.Context) (nat.NAT, *Mapping, error) {
|
|||||||
discoverCtx, discoverCancel := context.WithTimeout(ctx, discoveryTimeout)
|
discoverCtx, discoverCancel := context.WithTimeout(ctx, discoveryTimeout)
|
||||||
defer discoverCancel()
|
defer discoverCancel()
|
||||||
|
|
||||||
gateway, err := nat.DiscoverGateway(discoverCtx)
|
gateway, err := discoverGateway(discoverCtx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("discover gateway: %w", err)
|
return nil, nil, fmt.Errorf("discover gateway: %w", err)
|
||||||
}
|
}
|
||||||
@@ -189,7 +192,6 @@ func (m *Manager) createMapping(ctx context.Context, gateway nat.NAT) (*Mapping,
|
|||||||
externalIP, err := gateway.GetExternalAddress()
|
externalIP, err := gateway.GetExternalAddress()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed to get external address: %v", err)
|
log.Debugf("failed to get external address: %v", err)
|
||||||
// todo return with err?
|
|
||||||
}
|
}
|
||||||
|
|
||||||
mapping := &Mapping{
|
mapping := &Mapping{
|
||||||
@@ -208,26 +210,86 @@ func (m *Manager) createMapping(ctx context.Context, gateway nat.NAT) (*Mapping,
|
|||||||
|
|
||||||
func (m *Manager) renewLoop(ctx context.Context, gateway nat.NAT, ttl time.Duration) {
|
func (m *Manager) renewLoop(ctx context.Context, gateway nat.NAT, ttl time.Duration) {
|
||||||
if ttl == 0 {
|
if ttl == 0 {
|
||||||
// Permanent mappings don't expire, just wait for cancellation.
|
// Permanent mappings don't expire, just wait for cancellation
|
||||||
<-ctx.Done()
|
// but still run health checks for PCP gateways.
|
||||||
|
m.permanentLeaseLoop(ctx, gateway)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ticker := time.NewTicker(ttl / 2)
|
renewTicker := time.NewTicker(ttl / 2)
|
||||||
defer ticker.Stop()
|
healthTicker := time.NewTicker(healthCheckInterval)
|
||||||
|
defer renewTicker.Stop()
|
||||||
|
defer healthTicker.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-renewTicker.C:
|
||||||
if err := m.renewMapping(ctx, gateway); err != nil {
|
if err := m.renewMapping(ctx, gateway); err != nil {
|
||||||
log.Warnf("failed to renew port mapping: %v", err)
|
log.Warnf("failed to renew port mapping: %v", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
case <-healthTicker.C:
|
||||||
|
if m.checkHealthAndRecreate(ctx, gateway) {
|
||||||
|
renewTicker.Reset(ttl / 2)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) permanentLeaseLoop(ctx context.Context, gateway nat.NAT) {
|
||||||
|
healthTicker := time.NewTicker(healthCheckInterval)
|
||||||
|
defer healthTicker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-healthTicker.C:
|
||||||
|
m.checkHealthAndRecreate(ctx, gateway)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) checkHealthAndRecreate(ctx context.Context, gateway nat.NAT) bool {
|
||||||
|
if isHealthCheckDisabled() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mappingLock.Lock()
|
||||||
|
hasMapping := m.mapping != nil
|
||||||
|
m.mappingLock.Unlock()
|
||||||
|
|
||||||
|
if !hasMapping {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
pcpNAT, ok := gateway.(*pcp.NAT)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
epoch, serverRestarted, err := pcpNAT.CheckServerHealth(ctx)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("PCP health check failed: %v", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if serverRestarted {
|
||||||
|
log.Warnf("PCP server restart detected (epoch=%d), recreating port mapping", epoch)
|
||||||
|
if err := m.renewMapping(ctx, gateway); err != nil {
|
||||||
|
log.Errorf("failed to recreate port mapping after server restart: %v", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) renewMapping(ctx context.Context, gateway nat.NAT) error {
|
func (m *Manager) renewMapping(ctx context.Context, gateway nat.NAT) error {
|
||||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||||
|
|||||||
408
client/internal/portforward/pcp/client.go
Normal file
408
client/internal/portforward/pcp/client.go
Normal file
@@ -0,0 +1,408 @@
|
|||||||
|
package pcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultTimeout = 3 * time.Second
|
||||||
|
responseBufferSize = 128
|
||||||
|
|
||||||
|
// RFC 6887 Section 8.1.1 retry timing
|
||||||
|
initialRetryDelay = 3 * time.Second
|
||||||
|
maxRetryDelay = 1024 * time.Second
|
||||||
|
maxRetries = 4 // 3s + 6s + 12s + 24s = 45s total worst case
|
||||||
|
)
|
||||||
|
|
||||||
|
// Client is a PCP protocol client.
|
||||||
|
// All methods are safe for concurrent use.
|
||||||
|
type Client struct {
|
||||||
|
gateway netip.Addr
|
||||||
|
timeout time.Duration
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
// localIP caches the resolved local IP address.
|
||||||
|
localIP netip.Addr
|
||||||
|
// lastEpoch is the last observed server epoch value.
|
||||||
|
lastEpoch uint32
|
||||||
|
// epochTime tracks when lastEpoch was received for state loss detection.
|
||||||
|
epochTime time.Time
|
||||||
|
// externalIP caches the external IP from the last successful MAP response.
|
||||||
|
externalIP netip.Addr
|
||||||
|
// epochStateLost is set when epoch indicates server restart.
|
||||||
|
epochStateLost bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClient creates a new PCP client for the gateway at the given IP.
|
||||||
|
func NewClient(gateway net.IP) *Client {
|
||||||
|
addr, ok := netip.AddrFromSlice(gateway)
|
||||||
|
if !ok {
|
||||||
|
log.Debugf("invalid gateway IP: %v", gateway)
|
||||||
|
}
|
||||||
|
return &Client{
|
||||||
|
gateway: addr.Unmap(),
|
||||||
|
timeout: defaultTimeout,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClientWithTimeout creates a new PCP client with a custom timeout.
|
||||||
|
func NewClientWithTimeout(gateway net.IP, timeout time.Duration) *Client {
|
||||||
|
addr, ok := netip.AddrFromSlice(gateway)
|
||||||
|
if !ok {
|
||||||
|
log.Debugf("invalid gateway IP: %v", gateway)
|
||||||
|
}
|
||||||
|
return &Client{
|
||||||
|
gateway: addr.Unmap(),
|
||||||
|
timeout: timeout,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetLocalIP sets the local IP address to use in PCP requests.
|
||||||
|
func (c *Client) SetLocalIP(ip net.IP) {
|
||||||
|
addr, ok := netip.AddrFromSlice(ip)
|
||||||
|
if !ok {
|
||||||
|
log.Debugf("invalid local IP: %v", ip)
|
||||||
|
}
|
||||||
|
c.mu.Lock()
|
||||||
|
c.localIP = addr.Unmap()
|
||||||
|
c.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Gateway returns the gateway IP address.
|
||||||
|
func (c *Client) Gateway() net.IP {
|
||||||
|
return c.gateway.AsSlice()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Announce sends a PCP ANNOUNCE request to discover PCP support.
|
||||||
|
// Returns the server's epoch time on success.
|
||||||
|
func (c *Client) Announce(ctx context.Context) (epoch uint32, err error) {
|
||||||
|
localIP, err := c.getLocalIP()
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("get local IP: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := buildAnnounceRequest(localIP)
|
||||||
|
resp, err := c.sendRequest(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("send announce: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
parsed, err := parseResponse(resp)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("parse announce response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if parsed.ResultCode != ResultSuccess {
|
||||||
|
return 0, fmt.Errorf("PCP ANNOUNCE failed: %s", ResultCodeString(parsed.ResultCode))
|
||||||
|
}
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
if c.updateEpochLocked(parsed.Epoch) {
|
||||||
|
log.Warnf("PCP server epoch indicates state loss - mappings may need refresh")
|
||||||
|
}
|
||||||
|
c.mu.Unlock()
|
||||||
|
return parsed.Epoch, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddPortMapping requests a port mapping from the PCP server.
|
||||||
|
func (c *Client) AddPortMapping(ctx context.Context, protocol string, internalPort int, lifetime time.Duration) (*MapResponse, error) {
|
||||||
|
return c.addPortMappingWithHint(ctx, protocol, internalPort, internalPort, netip.Addr{}, lifetime)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddPortMappingWithHint requests a port mapping with suggested external port and IP.
|
||||||
|
// Use lifetime <= 0 to delete a mapping.
|
||||||
|
func (c *Client) AddPortMappingWithHint(ctx context.Context, protocol string, internalPort, suggestedExtPort int, suggestedExtIP net.IP, lifetime time.Duration) (*MapResponse, error) {
|
||||||
|
var extIP netip.Addr
|
||||||
|
if suggestedExtIP != nil {
|
||||||
|
var ok bool
|
||||||
|
extIP, ok = netip.AddrFromSlice(suggestedExtIP)
|
||||||
|
if !ok {
|
||||||
|
log.Debugf("invalid suggested external IP: %v", suggestedExtIP)
|
||||||
|
}
|
||||||
|
extIP = extIP.Unmap()
|
||||||
|
}
|
||||||
|
return c.addPortMappingWithHint(ctx, protocol, internalPort, suggestedExtPort, extIP, lifetime)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) addPortMappingWithHint(ctx context.Context, protocol string, internalPort, suggestedExtPort int, suggestedExtIP netip.Addr, lifetime time.Duration) (*MapResponse, error) {
|
||||||
|
localIP, err := c.getLocalIP()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get local IP: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
proto, err := protocolNumber(protocol)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse protocol: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var nonce [12]byte
|
||||||
|
if _, err := rand.Read(nonce[:]); err != nil {
|
||||||
|
return nil, fmt.Errorf("generate nonce: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert lifetime to seconds. Lifetime 0 means delete, so only apply
|
||||||
|
// default for positive durations that round to 0 seconds.
|
||||||
|
var lifetimeSec uint32
|
||||||
|
if lifetime > 0 {
|
||||||
|
lifetimeSec = uint32(lifetime.Seconds())
|
||||||
|
if lifetimeSec == 0 {
|
||||||
|
lifetimeSec = DefaultLifetime
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
req := buildMapRequest(localIP, nonce, proto, uint16(internalPort), uint16(suggestedExtPort), suggestedExtIP, lifetimeSec)
|
||||||
|
|
||||||
|
resp, err := c.sendRequest(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("send map request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mapResp, err := parseMapResponse(resp)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse map response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if mapResp.Nonce != nonce {
|
||||||
|
return nil, fmt.Errorf("nonce mismatch in response")
|
||||||
|
}
|
||||||
|
|
||||||
|
if mapResp.Protocol != proto {
|
||||||
|
return nil, fmt.Errorf("protocol mismatch: requested %d, got %d", proto, mapResp.Protocol)
|
||||||
|
}
|
||||||
|
if mapResp.InternalPort != uint16(internalPort) {
|
||||||
|
return nil, fmt.Errorf("internal port mismatch: requested %d, got %d", internalPort, mapResp.InternalPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
if mapResp.ResultCode != ResultSuccess {
|
||||||
|
return nil, &Error{
|
||||||
|
Code: mapResp.ResultCode,
|
||||||
|
Message: ResultCodeString(mapResp.ResultCode),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
if c.updateEpochLocked(mapResp.Epoch) {
|
||||||
|
log.Warnf("PCP server epoch indicates state loss - mappings may need refresh")
|
||||||
|
}
|
||||||
|
c.cacheExternalIPLocked(mapResp.ExternalIP)
|
||||||
|
c.mu.Unlock()
|
||||||
|
return mapResp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeletePortMapping removes a port mapping by requesting zero lifetime.
|
||||||
|
func (c *Client) DeletePortMapping(ctx context.Context, protocol string, internalPort int) error {
|
||||||
|
if _, err := c.addPortMappingWithHint(ctx, protocol, internalPort, 0, netip.Addr{}, 0); err != nil {
|
||||||
|
var pcpErr *Error
|
||||||
|
if errors.As(err, &pcpErr) && pcpErr.Code == ResultNotAuthorized {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("delete mapping: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetExternalAddress returns the external IP address.
|
||||||
|
// First checks for a cached value from previous MAP responses.
|
||||||
|
// If not cached, creates a short-lived mapping to discover the external IP.
|
||||||
|
func (c *Client) GetExternalAddress(ctx context.Context) (net.IP, error) {
|
||||||
|
c.mu.Lock()
|
||||||
|
if c.externalIP.IsValid() {
|
||||||
|
ip := c.externalIP.AsSlice()
|
||||||
|
c.mu.Unlock()
|
||||||
|
return ip, nil
|
||||||
|
}
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
// Use an ephemeral port in the dynamic range (49152-65535).
|
||||||
|
// Port 0 is not valid with UDP/TCP protocols per RFC 6887.
|
||||||
|
ephemeralPort := 49152 + int(uint16(time.Now().UnixNano()))%(65535-49152)
|
||||||
|
|
||||||
|
// Use minimal lifetime (1 second) for discovery.
|
||||||
|
resp, err := c.AddPortMapping(ctx, "udp", ephemeralPort, time.Second)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create temporary mapping: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.DeletePortMapping(ctx, "udp", ephemeralPort); err != nil {
|
||||||
|
log.Debugf("cleanup temporary PCP mapping: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp.ExternalIP.AsSlice(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LastEpoch returns the last observed server epoch value.
|
||||||
|
// A decrease in epoch indicates the server may have restarted and mappings may be lost.
|
||||||
|
func (c *Client) LastEpoch() uint32 {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
return c.lastEpoch
|
||||||
|
}
|
||||||
|
|
||||||
|
// EpochStateLost returns true if epoch state loss was detected and clears the flag.
|
||||||
|
func (c *Client) EpochStateLost() bool {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
lost := c.epochStateLost
|
||||||
|
c.epochStateLost = false
|
||||||
|
return lost
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateEpoch updates the epoch tracking and detects potential state loss.
|
||||||
|
// Returns true if state loss was detected (server likely restarted).
|
||||||
|
// Caller must hold c.mu.
|
||||||
|
func (c *Client) updateEpochLocked(newEpoch uint32) bool {
|
||||||
|
now := time.Now()
|
||||||
|
stateLost := false
|
||||||
|
|
||||||
|
// RFC 6887 Section 8.5: Detect invalid epoch indicating server state loss.
|
||||||
|
// client_delta = time since last response
|
||||||
|
// server_delta = epoch change since last response
|
||||||
|
// Invalid if: client_delta+2 < server_delta - server_delta/16
|
||||||
|
// OR: server_delta+2 < client_delta - client_delta/16
|
||||||
|
// The +2 handles quantization, /16 (6.25%) handles clock drift.
|
||||||
|
if !c.epochTime.IsZero() && c.lastEpoch > 0 {
|
||||||
|
clientDelta := uint32(now.Sub(c.epochTime).Seconds())
|
||||||
|
serverDelta := newEpoch - c.lastEpoch
|
||||||
|
|
||||||
|
// Check for epoch going backwards or jumping unexpectedly.
|
||||||
|
// Subtraction is safe: serverDelta/16 is always <= serverDelta.
|
||||||
|
if clientDelta+2 < serverDelta-(serverDelta/16) ||
|
||||||
|
serverDelta+2 < clientDelta-(clientDelta/16) {
|
||||||
|
stateLost = true
|
||||||
|
c.epochStateLost = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.lastEpoch = newEpoch
|
||||||
|
c.epochTime = now
|
||||||
|
return stateLost
|
||||||
|
}
|
||||||
|
|
||||||
|
// cacheExternalIP stores the external IP from a successful MAP response.
|
||||||
|
// Caller must hold c.mu.
|
||||||
|
func (c *Client) cacheExternalIPLocked(ip netip.Addr) {
|
||||||
|
if ip.IsValid() && !ip.IsUnspecified() {
|
||||||
|
c.externalIP = ip
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendRequest sends a PCP request with retries per RFC 6887 Section 8.1.1.
|
||||||
|
func (c *Client) sendRequest(ctx context.Context, req []byte) ([]byte, error) {
|
||||||
|
addr := &net.UDPAddr{IP: c.gateway.AsSlice(), Port: Port}
|
||||||
|
|
||||||
|
var lastErr error
|
||||||
|
delay := initialRetryDelay
|
||||||
|
|
||||||
|
for range maxRetries {
|
||||||
|
resp, err := c.sendOnce(ctx, addr, req)
|
||||||
|
if err == nil {
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
lastErr = err
|
||||||
|
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// RFC 6887 Section 8.1.1: RT = (1 + RAND) * MIN(2 * RTprev, MRT)
|
||||||
|
// RAND is random between -0.1 and +0.1
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case <-time.After(retryDelayWithJitter(delay)):
|
||||||
|
}
|
||||||
|
delay = min(delay*2, maxRetryDelay)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("PCP request failed after %d retries: %w", maxRetries, lastErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// retryDelayWithJitter applies RFC 6887 jitter: multiply by (1 + RAND) where RAND is [-0.1, +0.1].
|
||||||
|
func retryDelayWithJitter(d time.Duration) time.Duration {
|
||||||
|
var b [1]byte
|
||||||
|
_, _ = rand.Read(b[:])
|
||||||
|
// Convert byte to range [-0.1, +0.1]: (b/255 * 0.2) - 0.1
|
||||||
|
jitter := (float64(b[0])/255.0)*0.2 - 0.1
|
||||||
|
return time.Duration(float64(d) * (1 + jitter))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) sendOnce(ctx context.Context, addr *net.UDPAddr, req []byte) ([]byte, error) {
|
||||||
|
// Use ListenUDP instead of DialUDP to validate response source address per RFC 6887 §8.3.
|
||||||
|
conn, err := net.ListenUDP("udp", nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("listen: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
log.Debugf("close UDP connection: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
timeout := c.timeout
|
||||||
|
if deadline, ok := ctx.Deadline(); ok {
|
||||||
|
if remaining := time.Until(deadline); remaining < timeout {
|
||||||
|
timeout = remaining
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil {
|
||||||
|
return nil, fmt.Errorf("set deadline: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := conn.WriteToUDP(req, addr); err != nil {
|
||||||
|
return nil, fmt.Errorf("write: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := make([]byte, responseBufferSize)
|
||||||
|
n, from, err := conn.ReadFromUDP(resp)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RFC 6887 §8.3: Validate response came from expected PCP server.
|
||||||
|
if !from.IP.Equal(addr.IP) {
|
||||||
|
return nil, fmt.Errorf("response from unexpected source %s (expected %s)", from.IP, addr.IP)
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp[:n], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) getLocalIP() (netip.Addr, error) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
if !c.localIP.IsValid() {
|
||||||
|
return netip.Addr{}, fmt.Errorf("local IP not set for gateway %s", c.gateway)
|
||||||
|
}
|
||||||
|
return c.localIP, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func protocolNumber(protocol string) (uint8, error) {
|
||||||
|
switch protocol {
|
||||||
|
case "udp", "UDP":
|
||||||
|
return ProtoUDP, nil
|
||||||
|
case "tcp", "TCP":
|
||||||
|
return ProtoTCP, nil
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("unsupported protocol: %s", protocol)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error represents a PCP error response.
|
||||||
|
type Error struct {
|
||||||
|
Code uint8
|
||||||
|
Message string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Error) Error() string {
|
||||||
|
return fmt.Sprintf("PCP error: %s (%d)", e.Message, e.Code)
|
||||||
|
}
|
||||||
187
client/internal/portforward/pcp/client_test.go
Normal file
187
client/internal/portforward/pcp/client_test.go
Normal file
@@ -0,0 +1,187 @@
|
|||||||
|
package pcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAddrConversion(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
addr netip.Addr
|
||||||
|
}{
|
||||||
|
{"IPv4", netip.MustParseAddr("192.168.1.100")},
|
||||||
|
{"IPv4 loopback", netip.MustParseAddr("127.0.0.1")},
|
||||||
|
{"IPv6", netip.MustParseAddr("2001:db8::1")},
|
||||||
|
{"IPv6 loopback", netip.MustParseAddr("::1")},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
b16 := addrTo16(tt.addr)
|
||||||
|
|
||||||
|
recovered := addrFrom16(b16)
|
||||||
|
assert.Equal(t, tt.addr, recovered, "address should round-trip")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildAnnounceRequest(t *testing.T) {
|
||||||
|
clientIP := netip.MustParseAddr("192.168.1.100")
|
||||||
|
req := buildAnnounceRequest(clientIP)
|
||||||
|
|
||||||
|
require.Len(t, req, headerSize)
|
||||||
|
assert.Equal(t, byte(Version), req[0], "version")
|
||||||
|
assert.Equal(t, byte(OpAnnounce), req[1], "opcode")
|
||||||
|
|
||||||
|
// Check client IP is properly encoded as IPv4-mapped IPv6
|
||||||
|
assert.Equal(t, byte(0xff), req[18], "IPv4-mapped prefix byte 10")
|
||||||
|
assert.Equal(t, byte(0xff), req[19], "IPv4-mapped prefix byte 11")
|
||||||
|
assert.Equal(t, byte(192), req[20], "IP octet 1")
|
||||||
|
assert.Equal(t, byte(168), req[21], "IP octet 2")
|
||||||
|
assert.Equal(t, byte(1), req[22], "IP octet 3")
|
||||||
|
assert.Equal(t, byte(100), req[23], "IP octet 4")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildMapRequest(t *testing.T) {
|
||||||
|
clientIP := netip.MustParseAddr("192.168.1.100")
|
||||||
|
nonce := [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}
|
||||||
|
req := buildMapRequest(clientIP, nonce, ProtoUDP, 51820, 51820, netip.Addr{}, 3600)
|
||||||
|
|
||||||
|
require.Len(t, req, mapRequestSize)
|
||||||
|
assert.Equal(t, byte(Version), req[0], "version")
|
||||||
|
assert.Equal(t, byte(OpMap), req[1], "opcode")
|
||||||
|
|
||||||
|
// Lifetime at bytes 4-7
|
||||||
|
assert.Equal(t, uint32(3600), (uint32(req[4])<<24)|(uint32(req[5])<<16)|(uint32(req[6])<<8)|uint32(req[7]), "lifetime")
|
||||||
|
|
||||||
|
// Nonce at bytes 24-35
|
||||||
|
assert.Equal(t, nonce[:], req[24:36], "nonce")
|
||||||
|
|
||||||
|
// Protocol at byte 36
|
||||||
|
assert.Equal(t, byte(ProtoUDP), req[36], "protocol")
|
||||||
|
|
||||||
|
// Internal port at bytes 40-41
|
||||||
|
assert.Equal(t, uint16(51820), (uint16(req[40])<<8)|uint16(req[41]), "internal port")
|
||||||
|
|
||||||
|
// External port at bytes 42-43
|
||||||
|
assert.Equal(t, uint16(51820), (uint16(req[42])<<8)|uint16(req[43]), "external port")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseResponse(t *testing.T) {
|
||||||
|
// Construct a valid ANNOUNCE response
|
||||||
|
resp := make([]byte, headerSize)
|
||||||
|
resp[0] = Version
|
||||||
|
resp[1] = OpAnnounce | OpReply
|
||||||
|
// Result code = 0 (success)
|
||||||
|
// Lifetime = 0
|
||||||
|
// Epoch = 12345
|
||||||
|
resp[8] = 0
|
||||||
|
resp[9] = 0
|
||||||
|
resp[10] = 0x30
|
||||||
|
resp[11] = 0x39
|
||||||
|
|
||||||
|
parsed, err := parseResponse(resp)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, uint8(Version), parsed.Version)
|
||||||
|
assert.Equal(t, uint8(OpAnnounce|OpReply), parsed.Opcode)
|
||||||
|
assert.Equal(t, uint8(ResultSuccess), parsed.ResultCode)
|
||||||
|
assert.Equal(t, uint32(12345), parsed.Epoch)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseResponseErrors(t *testing.T) {
|
||||||
|
t.Run("too short", func(t *testing.T) {
|
||||||
|
_, err := parseResponse([]byte{1, 2, 3})
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("wrong version", func(t *testing.T) {
|
||||||
|
resp := make([]byte, headerSize)
|
||||||
|
resp[0] = 1 // Wrong version
|
||||||
|
resp[1] = OpReply
|
||||||
|
_, err := parseResponse(resp)
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing reply bit", func(t *testing.T) {
|
||||||
|
resp := make([]byte, headerSize)
|
||||||
|
resp[0] = Version
|
||||||
|
resp[1] = OpAnnounce // Missing OpReply bit
|
||||||
|
_, err := parseResponse(resp)
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResultCodeString(t *testing.T) {
|
||||||
|
assert.Equal(t, "SUCCESS", ResultCodeString(ResultSuccess))
|
||||||
|
assert.Equal(t, "NOT_AUTHORIZED", ResultCodeString(ResultNotAuthorized))
|
||||||
|
assert.Equal(t, "ADDRESS_MISMATCH", ResultCodeString(ResultAddressMismatch))
|
||||||
|
assert.Contains(t, ResultCodeString(255), "UNKNOWN")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProtocolNumber(t *testing.T) {
|
||||||
|
proto, err := protocolNumber("udp")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, uint8(ProtoUDP), proto)
|
||||||
|
|
||||||
|
proto, err = protocolNumber("tcp")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, uint8(ProtoTCP), proto)
|
||||||
|
|
||||||
|
proto, err = protocolNumber("UDP")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, uint8(ProtoUDP), proto)
|
||||||
|
|
||||||
|
_, err = protocolNumber("icmp")
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientCreation(t *testing.T) {
|
||||||
|
gateway := netip.MustParseAddr("192.168.1.1").AsSlice()
|
||||||
|
|
||||||
|
client := NewClient(gateway)
|
||||||
|
assert.Equal(t, net.IP(gateway), client.Gateway())
|
||||||
|
assert.Equal(t, defaultTimeout, client.timeout)
|
||||||
|
|
||||||
|
clientWithTimeout := NewClientWithTimeout(gateway, 5*time.Second)
|
||||||
|
assert.Equal(t, 5*time.Second, clientWithTimeout.timeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNATType(t *testing.T) {
|
||||||
|
n := NewNAT(netip.MustParseAddr("192.168.1.1").AsSlice(), netip.MustParseAddr("192.168.1.100").AsSlice())
|
||||||
|
assert.Equal(t, "PCP", n.Type())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Integration test - skipped unless PCP_TEST_GATEWAY env is set
|
||||||
|
func TestClientIntegration(t *testing.T) {
|
||||||
|
t.Skip("Integration test - run manually with PCP_TEST_GATEWAY=<gateway-ip>")
|
||||||
|
|
||||||
|
gateway := netip.MustParseAddr("10.0.1.1").AsSlice() // Change to your test gateway
|
||||||
|
localIP := netip.MustParseAddr("10.0.1.100").AsSlice() // Change to your local IP
|
||||||
|
|
||||||
|
client := NewClient(gateway)
|
||||||
|
client.SetLocalIP(localIP)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Test ANNOUNCE
|
||||||
|
epoch, err := client.Announce(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Logf("Server epoch: %d", epoch)
|
||||||
|
|
||||||
|
// Test MAP
|
||||||
|
resp, err := client.AddPortMapping(ctx, "udp", 51820, 1*time.Hour)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Logf("Mapping: internal=%d external=%d externalIP=%s",
|
||||||
|
resp.InternalPort, resp.ExternalPort, resp.ExternalIP)
|
||||||
|
|
||||||
|
// Cleanup
|
||||||
|
err = client.DeletePortMapping(ctx, "udp", 51820)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
209
client/internal/portforward/pcp/nat.go
Normal file
209
client/internal/portforward/pcp/nat.go
Normal file
@@ -0,0 +1,209 @@
|
|||||||
|
package pcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/libp2p/go-nat"
|
||||||
|
"github.com/libp2p/go-netroute"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ nat.NAT = (*NAT)(nil)
|
||||||
|
|
||||||
|
// NAT implements the go-nat NAT interface using PCP.
|
||||||
|
// Supports dual-stack (IPv4 and IPv6) when available.
|
||||||
|
// All methods are safe for concurrent use.
|
||||||
|
//
|
||||||
|
// TODO: IPv6 pinholes use the local IPv6 address. If the address changes
|
||||||
|
// (e.g., due to SLAAC rotation or network change), the pinhole becomes stale
|
||||||
|
// and needs to be recreated with the new address.
|
||||||
|
type NAT struct {
|
||||||
|
client *Client
|
||||||
|
|
||||||
|
mu sync.RWMutex
|
||||||
|
// client6 is the IPv6 PCP client, nil if IPv6 is unavailable.
|
||||||
|
client6 *Client
|
||||||
|
// localIP6 caches the local IPv6 address used for PCP requests.
|
||||||
|
localIP6 netip.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewNAT creates a new NAT instance backed by PCP.
|
||||||
|
func NewNAT(gateway, localIP net.IP) *NAT {
|
||||||
|
client := NewClient(gateway)
|
||||||
|
client.SetLocalIP(localIP)
|
||||||
|
return &NAT{
|
||||||
|
client: client,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Type returns "PCP" as the NAT type.
|
||||||
|
func (n *NAT) Type() string {
|
||||||
|
return "PCP"
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDeviceAddress returns the gateway IP address.
|
||||||
|
func (n *NAT) GetDeviceAddress() (net.IP, error) {
|
||||||
|
return n.client.Gateway(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetExternalAddress returns the external IP address.
|
||||||
|
func (n *NAT) GetExternalAddress() (net.IP, error) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
return n.client.GetExternalAddress(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetInternalAddress returns the local IP address used to communicate with the gateway.
|
||||||
|
func (n *NAT) GetInternalAddress() (net.IP, error) {
|
||||||
|
addr, err := n.client.getLocalIP()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return addr.AsSlice(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddPortMapping creates a port mapping on both IPv4 and IPv6 (if available).
|
||||||
|
func (n *NAT) AddPortMapping(ctx context.Context, protocol string, internalPort int, _ string, timeout time.Duration) (int, error) {
|
||||||
|
resp, err := n.client.AddPortMapping(ctx, protocol, internalPort, timeout)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("add mapping: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
n.mu.RLock()
|
||||||
|
client6 := n.client6
|
||||||
|
localIP6 := n.localIP6
|
||||||
|
n.mu.RUnlock()
|
||||||
|
|
||||||
|
if client6 == nil {
|
||||||
|
return int(resp.ExternalPort), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := client6.AddPortMapping(ctx, protocol, internalPort, timeout); err != nil {
|
||||||
|
log.Warnf("IPv6 PCP mapping failed (continuing with IPv4): %v", err)
|
||||||
|
return int(resp.ExternalPort), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("created IPv6 PCP pinhole: %s:%d", localIP6, internalPort)
|
||||||
|
return int(resp.ExternalPort), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeletePortMapping removes a port mapping from both IPv4 and IPv6.
|
||||||
|
func (n *NAT) DeletePortMapping(ctx context.Context, protocol string, internalPort int) error {
|
||||||
|
err := n.client.DeletePortMapping(ctx, protocol, internalPort)
|
||||||
|
|
||||||
|
n.mu.RLock()
|
||||||
|
client6 := n.client6
|
||||||
|
n.mu.RUnlock()
|
||||||
|
|
||||||
|
if client6 != nil {
|
||||||
|
if err6 := client6.DeletePortMapping(ctx, protocol, internalPort); err6 != nil {
|
||||||
|
log.Warnf("IPv6 PCP delete mapping failed: %v", err6)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("delete mapping: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckServerHealth sends an ANNOUNCE to verify the server is still responsive.
|
||||||
|
// Returns the current epoch and whether the server may have restarted (epoch state loss detected).
|
||||||
|
func (n *NAT) CheckServerHealth(ctx context.Context) (epoch uint32, serverRestarted bool, err error) {
|
||||||
|
epoch, err = n.client.Announce(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return 0, false, fmt.Errorf("announce: %w", err)
|
||||||
|
}
|
||||||
|
return epoch, n.client.EpochStateLost(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DiscoverPCP attempts to discover a PCP-capable gateway.
|
||||||
|
// Returns a NAT interface if PCP is supported, or an error otherwise.
|
||||||
|
// Discovers both IPv4 and IPv6 gateways when available.
|
||||||
|
func DiscoverPCP(ctx context.Context) (nat.NAT, error) {
|
||||||
|
gateway, localIP, err := getDefaultGateway()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get default gateway: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
client := NewClient(gateway)
|
||||||
|
client.SetLocalIP(localIP)
|
||||||
|
if _, err := client.Announce(ctx); err != nil {
|
||||||
|
return nil, fmt.Errorf("PCP announce: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := &NAT{client: client}
|
||||||
|
discoverIPv6(ctx, result)
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func discoverIPv6(ctx context.Context, result *NAT) {
|
||||||
|
gateway6, localIP6, err := getDefaultGateway6()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("IPv6 gateway discovery failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
client6 := NewClient(gateway6)
|
||||||
|
client6.SetLocalIP(localIP6)
|
||||||
|
if _, err := client6.Announce(ctx); err != nil {
|
||||||
|
log.Debugf("PCP IPv6 announce failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
addr, ok := netip.AddrFromSlice(localIP6)
|
||||||
|
if !ok {
|
||||||
|
log.Debugf("invalid IPv6 local IP: %v", localIP6)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
result.mu.Lock()
|
||||||
|
result.client6 = client6
|
||||||
|
result.localIP6 = addr
|
||||||
|
result.mu.Unlock()
|
||||||
|
log.Debugf("PCP IPv6 gateway discovered: %s (local: %s)", gateway6, localIP6)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getDefaultGateway returns the default IPv4 gateway and local IP using the system routing table.
|
||||||
|
func getDefaultGateway() (gateway net.IP, localIP net.IP, err error) {
|
||||||
|
router, err := netroute.New()
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, gateway, localIP, err = router.Route(net.IPv4zero)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if gateway == nil {
|
||||||
|
return nil, nil, nat.ErrNoNATFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return gateway, localIP, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getDefaultGateway6 returns the default IPv6 gateway IP address using the system routing table.
|
||||||
|
func getDefaultGateway6() (gateway net.IP, localIP net.IP, err error) {
|
||||||
|
router, err := netroute.New()
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, gateway, localIP, err = router.Route(net.IPv6zero)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if gateway == nil {
|
||||||
|
return nil, nil, nat.ErrNoNATFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return gateway, localIP, nil
|
||||||
|
}
|
||||||
225
client/internal/portforward/pcp/protocol.go
Normal file
225
client/internal/portforward/pcp/protocol.go
Normal file
@@ -0,0 +1,225 @@
|
|||||||
|
// Package pcp implements the Port Control Protocol (RFC 6887).
|
||||||
|
//
|
||||||
|
// # Implemented Features
|
||||||
|
//
|
||||||
|
// - ANNOUNCE opcode: Discovers PCP server support
|
||||||
|
// - MAP opcode: Creates/deletes port mappings (IPv4 NAT) and firewall pinholes (IPv6)
|
||||||
|
// - Dual-stack: Simultaneous IPv4 and IPv6 support via separate clients
|
||||||
|
// - Nonce validation: Prevents response spoofing
|
||||||
|
// - Epoch tracking: Detects server restarts per Section 8.5
|
||||||
|
// - RFC-compliant retry timing: 3s initial, exponential backoff to 1024s max (Section 8.1.1)
|
||||||
|
//
|
||||||
|
// # Not Implemented
|
||||||
|
//
|
||||||
|
// - PEER opcode: For outbound peer connections (not needed for inbound NAT traversal)
|
||||||
|
// - THIRD_PARTY option: For managing mappings on behalf of other devices
|
||||||
|
// - PREFER_FAILURE option: Requires exact external port or fail (IPv4 NAT only, not needed for IPv6 pinholing)
|
||||||
|
// - FILTER option: To restrict remote peer addresses
|
||||||
|
//
|
||||||
|
// These optional features are omitted because the primary use case is simple
|
||||||
|
// port forwarding for WireGuard, which only requires MAP with default behavior.
|
||||||
|
package pcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Version is the PCP protocol version (RFC 6887).
|
||||||
|
Version = 2
|
||||||
|
|
||||||
|
// Port is the standard PCP server port.
|
||||||
|
Port = 5351
|
||||||
|
|
||||||
|
// DefaultLifetime is the default requested mapping lifetime in seconds.
|
||||||
|
DefaultLifetime = 7200 // 2 hours
|
||||||
|
|
||||||
|
// Header sizes
|
||||||
|
headerSize = 24
|
||||||
|
mapPayloadSize = 36
|
||||||
|
mapRequestSize = headerSize + mapPayloadSize // 60 bytes
|
||||||
|
)
|
||||||
|
|
||||||
|
// Opcodes
|
||||||
|
const (
|
||||||
|
OpAnnounce = 0
|
||||||
|
OpMap = 1
|
||||||
|
OpPeer = 2
|
||||||
|
OpReply = 0x80 // OR'd with opcode in responses
|
||||||
|
)
|
||||||
|
|
||||||
|
// Protocol numbers for MAP requests
|
||||||
|
const (
|
||||||
|
ProtoUDP = 17
|
||||||
|
ProtoTCP = 6
|
||||||
|
)
|
||||||
|
|
||||||
|
// Result codes (RFC 6887 Section 7.4)
|
||||||
|
const (
|
||||||
|
ResultSuccess = 0
|
||||||
|
ResultUnsuppVersion = 1
|
||||||
|
ResultNotAuthorized = 2
|
||||||
|
ResultMalformedRequest = 3
|
||||||
|
ResultUnsuppOpcode = 4
|
||||||
|
ResultUnsuppOption = 5
|
||||||
|
ResultMalformedOption = 6
|
||||||
|
ResultNetworkFailure = 7
|
||||||
|
ResultNoResources = 8
|
||||||
|
ResultUnsuppProtocol = 9
|
||||||
|
ResultUserExQuota = 10
|
||||||
|
ResultCannotProvideExt = 11
|
||||||
|
ResultAddressMismatch = 12
|
||||||
|
ResultExcessiveRemotePeers = 13
|
||||||
|
)
|
||||||
|
|
||||||
|
// ResultCodeString returns a human-readable string for a result code.
|
||||||
|
func ResultCodeString(code uint8) string {
|
||||||
|
switch code {
|
||||||
|
case ResultSuccess:
|
||||||
|
return "SUCCESS"
|
||||||
|
case ResultUnsuppVersion:
|
||||||
|
return "UNSUPP_VERSION"
|
||||||
|
case ResultNotAuthorized:
|
||||||
|
return "NOT_AUTHORIZED"
|
||||||
|
case ResultMalformedRequest:
|
||||||
|
return "MALFORMED_REQUEST"
|
||||||
|
case ResultUnsuppOpcode:
|
||||||
|
return "UNSUPP_OPCODE"
|
||||||
|
case ResultUnsuppOption:
|
||||||
|
return "UNSUPP_OPTION"
|
||||||
|
case ResultMalformedOption:
|
||||||
|
return "MALFORMED_OPTION"
|
||||||
|
case ResultNetworkFailure:
|
||||||
|
return "NETWORK_FAILURE"
|
||||||
|
case ResultNoResources:
|
||||||
|
return "NO_RESOURCES"
|
||||||
|
case ResultUnsuppProtocol:
|
||||||
|
return "UNSUPP_PROTOCOL"
|
||||||
|
case ResultUserExQuota:
|
||||||
|
return "USER_EX_QUOTA"
|
||||||
|
case ResultCannotProvideExt:
|
||||||
|
return "CANNOT_PROVIDE_EXTERNAL"
|
||||||
|
case ResultAddressMismatch:
|
||||||
|
return "ADDRESS_MISMATCH"
|
||||||
|
case ResultExcessiveRemotePeers:
|
||||||
|
return "EXCESSIVE_REMOTE_PEERS"
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("UNKNOWN(%d)", code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Response represents a parsed PCP response header.
|
||||||
|
type Response struct {
|
||||||
|
Version uint8
|
||||||
|
Opcode uint8
|
||||||
|
ResultCode uint8
|
||||||
|
Lifetime uint32
|
||||||
|
Epoch uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
// MapResponse contains the full response to a MAP request.
|
||||||
|
type MapResponse struct {
|
||||||
|
Response
|
||||||
|
Nonce [12]byte
|
||||||
|
Protocol uint8
|
||||||
|
InternalPort uint16
|
||||||
|
ExternalPort uint16
|
||||||
|
ExternalIP netip.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
// addrTo16 converts an address to its 16-byte IPv4-mapped IPv6 representation.
|
||||||
|
func addrTo16(addr netip.Addr) [16]byte {
|
||||||
|
if addr.Is4() {
|
||||||
|
return netip.AddrFrom4(addr.As4()).As16()
|
||||||
|
}
|
||||||
|
return addr.As16()
|
||||||
|
}
|
||||||
|
|
||||||
|
// addrFrom16 extracts an address from a 16-byte representation, unmapping IPv4.
|
||||||
|
func addrFrom16(b [16]byte) netip.Addr {
|
||||||
|
return netip.AddrFrom16(b).Unmap()
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildAnnounceRequest creates a PCP ANNOUNCE request packet.
|
||||||
|
func buildAnnounceRequest(clientIP netip.Addr) []byte {
|
||||||
|
req := make([]byte, headerSize)
|
||||||
|
req[0] = Version
|
||||||
|
req[1] = OpAnnounce
|
||||||
|
mapped := addrTo16(clientIP)
|
||||||
|
copy(req[8:24], mapped[:])
|
||||||
|
return req
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildMapRequest creates a PCP MAP request packet.
|
||||||
|
func buildMapRequest(clientIP netip.Addr, nonce [12]byte, protocol uint8, internalPort, suggestedExtPort uint16, suggestedExtIP netip.Addr, lifetime uint32) []byte {
|
||||||
|
req := make([]byte, mapRequestSize)
|
||||||
|
|
||||||
|
// Header
|
||||||
|
req[0] = Version
|
||||||
|
req[1] = OpMap
|
||||||
|
binary.BigEndian.PutUint32(req[4:8], lifetime)
|
||||||
|
mapped := addrTo16(clientIP)
|
||||||
|
copy(req[8:24], mapped[:])
|
||||||
|
|
||||||
|
// MAP payload
|
||||||
|
copy(req[24:36], nonce[:])
|
||||||
|
req[36] = protocol
|
||||||
|
binary.BigEndian.PutUint16(req[40:42], internalPort)
|
||||||
|
binary.BigEndian.PutUint16(req[42:44], suggestedExtPort)
|
||||||
|
if suggestedExtIP.IsValid() {
|
||||||
|
extMapped := addrTo16(suggestedExtIP)
|
||||||
|
copy(req[44:60], extMapped[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
return req
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseResponse parses the common PCP response header.
|
||||||
|
func parseResponse(data []byte) (*Response, error) {
|
||||||
|
if len(data) < headerSize {
|
||||||
|
return nil, fmt.Errorf("response too short: %d bytes", len(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := &Response{
|
||||||
|
Version: data[0],
|
||||||
|
Opcode: data[1],
|
||||||
|
ResultCode: data[3], // Byte 2 is reserved, byte 3 is result code (RFC 6887 §7.2)
|
||||||
|
Lifetime: binary.BigEndian.Uint32(data[4:8]),
|
||||||
|
Epoch: binary.BigEndian.Uint32(data[8:12]),
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Version != Version {
|
||||||
|
return nil, fmt.Errorf("unsupported PCP version: %d", resp.Version)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Opcode&OpReply == 0 {
|
||||||
|
return nil, fmt.Errorf("response missing reply bit: opcode=0x%02x", resp.Opcode)
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseMapResponse parses a complete MAP response.
|
||||||
|
func parseMapResponse(data []byte) (*MapResponse, error) {
|
||||||
|
if len(data) < mapRequestSize {
|
||||||
|
return nil, fmt.Errorf("MAP response too short: %d bytes", len(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := parseResponse(data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse header: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mapResp := &MapResponse{
|
||||||
|
Response: *resp,
|
||||||
|
Protocol: data[36],
|
||||||
|
InternalPort: binary.BigEndian.Uint16(data[40:42]),
|
||||||
|
ExternalPort: binary.BigEndian.Uint16(data[42:44]),
|
||||||
|
ExternalIP: addrFrom16([16]byte(data[44:60])),
|
||||||
|
}
|
||||||
|
copy(mapResp.Nonce[:], data[24:36])
|
||||||
|
|
||||||
|
return mapResp, nil
|
||||||
|
}
|
||||||
63
client/internal/portforward/state.go
Normal file
63
client/internal/portforward/state.go
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
//go:build !js
|
||||||
|
|
||||||
|
package portforward
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/libp2p/go-nat"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/portforward/pcp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// discoverGateway is the function used for NAT gateway discovery.
|
||||||
|
// It can be replaced in tests to avoid real network operations.
|
||||||
|
// Tries PCP first, then falls back to NAT-PMP/UPnP.
|
||||||
|
var discoverGateway = defaultDiscoverGateway
|
||||||
|
|
||||||
|
func defaultDiscoverGateway(ctx context.Context) (nat.NAT, error) {
|
||||||
|
pcpGateway, err := pcp.DiscoverPCP(ctx)
|
||||||
|
if err == nil {
|
||||||
|
return pcpGateway, nil
|
||||||
|
}
|
||||||
|
log.Debugf("PCP discovery failed: %v, trying NAT-PMP/UPnP", err)
|
||||||
|
|
||||||
|
return nat.DiscoverGateway(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// State is persisted only for crash recovery cleanup
|
||||||
|
type State struct {
|
||||||
|
InternalPort uint16 `json:"internal_port,omitempty"`
|
||||||
|
Protocol string `json:"protocol,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *State) Name() string {
|
||||||
|
return "port_forward_state"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cleanup implements statemanager.CleanableState for crash recovery
|
||||||
|
func (s *State) Cleanup() error {
|
||||||
|
if s.InternalPort == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("cleaning up stale port mapping for port %d", s.InternalPort)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), discoveryTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
gateway, err := discoverGateway(ctx)
|
||||||
|
if err != nil {
|
||||||
|
// Discovery failure is not an error - gateway may not exist
|
||||||
|
log.Debugf("cleanup: no gateway found: %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := gateway.DeletePortMapping(ctx, s.Protocol, int(s.InternalPort)); err != nil {
|
||||||
|
return fmt.Errorf("delete port mapping: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -64,11 +64,13 @@ type ConfigInput struct {
|
|||||||
StateFilePath string
|
StateFilePath string
|
||||||
PreSharedKey *string
|
PreSharedKey *string
|
||||||
ServerSSHAllowed *bool
|
ServerSSHAllowed *bool
|
||||||
|
ServerVNCAllowed *bool
|
||||||
EnableSSHRoot *bool
|
EnableSSHRoot *bool
|
||||||
EnableSSHSFTP *bool
|
EnableSSHSFTP *bool
|
||||||
EnableSSHLocalPortForwarding *bool
|
EnableSSHLocalPortForwarding *bool
|
||||||
EnableSSHRemotePortForwarding *bool
|
EnableSSHRemotePortForwarding *bool
|
||||||
DisableSSHAuth *bool
|
DisableSSHAuth *bool
|
||||||
|
DisableVNCAuth *bool
|
||||||
SSHJWTCacheTTL *int
|
SSHJWTCacheTTL *int
|
||||||
NATExternalIPs []string
|
NATExternalIPs []string
|
||||||
CustomDNSAddress []byte
|
CustomDNSAddress []byte
|
||||||
@@ -114,11 +116,13 @@ type Config struct {
|
|||||||
RosenpassEnabled bool
|
RosenpassEnabled bool
|
||||||
RosenpassPermissive bool
|
RosenpassPermissive bool
|
||||||
ServerSSHAllowed *bool
|
ServerSSHAllowed *bool
|
||||||
|
ServerVNCAllowed *bool
|
||||||
EnableSSHRoot *bool
|
EnableSSHRoot *bool
|
||||||
EnableSSHSFTP *bool
|
EnableSSHSFTP *bool
|
||||||
EnableSSHLocalPortForwarding *bool
|
EnableSSHLocalPortForwarding *bool
|
||||||
EnableSSHRemotePortForwarding *bool
|
EnableSSHRemotePortForwarding *bool
|
||||||
DisableSSHAuth *bool
|
DisableSSHAuth *bool
|
||||||
|
DisableVNCAuth *bool
|
||||||
SSHJWTCacheTTL *int
|
SSHJWTCacheTTL *int
|
||||||
|
|
||||||
DisableClientRoutes bool
|
DisableClientRoutes bool
|
||||||
@@ -415,6 +419,21 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
|||||||
updated = true
|
updated = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if input.ServerVNCAllowed != nil {
|
||||||
|
if config.ServerVNCAllowed == nil || *input.ServerVNCAllowed != *config.ServerVNCAllowed {
|
||||||
|
if *input.ServerVNCAllowed {
|
||||||
|
log.Infof("enabling VNC server")
|
||||||
|
} else {
|
||||||
|
log.Infof("disabling VNC server")
|
||||||
|
}
|
||||||
|
config.ServerVNCAllowed = input.ServerVNCAllowed
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
} else if config.ServerVNCAllowed == nil {
|
||||||
|
config.ServerVNCAllowed = util.True()
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
if input.EnableSSHRoot != nil && input.EnableSSHRoot != config.EnableSSHRoot {
|
if input.EnableSSHRoot != nil && input.EnableSSHRoot != config.EnableSSHRoot {
|
||||||
if *input.EnableSSHRoot {
|
if *input.EnableSSHRoot {
|
||||||
log.Infof("enabling SSH root login")
|
log.Infof("enabling SSH root login")
|
||||||
@@ -465,6 +484,16 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
|||||||
updated = true
|
updated = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if input.DisableVNCAuth != nil && input.DisableVNCAuth != config.DisableVNCAuth {
|
||||||
|
if *input.DisableVNCAuth {
|
||||||
|
log.Infof("disabling VNC authentication")
|
||||||
|
} else {
|
||||||
|
log.Infof("enabling VNC authentication")
|
||||||
|
}
|
||||||
|
config.DisableVNCAuth = input.DisableVNCAuth
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
if input.SSHJWTCacheTTL != nil && input.SSHJWTCacheTTL != config.SSHJWTCacheTTL {
|
if input.SSHJWTCacheTTL != nil && input.SSHJWTCacheTTL != config.SSHJWTCacheTTL {
|
||||||
log.Infof("updating SSH JWT cache TTL to %d seconds", *input.SSHJWTCacheTTL)
|
log.Infof("updating SSH JWT cache TTL to %d seconds", *input.SSHJWTCacheTTL)
|
||||||
config.SSHJWTCacheTTL = input.SSHJWTCacheTTL
|
config.SSHJWTCacheTTL = input.SSHJWTCacheTTL
|
||||||
|
|||||||
@@ -168,6 +168,7 @@ func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) {
|
|||||||
NetworkType: route.IPv4Network,
|
NetworkType: route.IPv4Network,
|
||||||
}
|
}
|
||||||
cr = append(cr, fakeIPRoute)
|
cr = append(cr, fakeIPRoute)
|
||||||
|
m.notifier.SetFakeIPRoute(fakeIPRoute)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.notifier.SetInitialClientRoutes(cr, routesForComparison)
|
m.notifier.SetInitialClientRoutes(cr, routesForComparison)
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
type Notifier struct {
|
type Notifier struct {
|
||||||
initialRoutes []*route.Route
|
initialRoutes []*route.Route
|
||||||
currentRoutes []*route.Route
|
currentRoutes []*route.Route
|
||||||
|
fakeIPRoute *route.Route
|
||||||
|
|
||||||
listener listener.NetworkChangeListener
|
listener listener.NetworkChangeListener
|
||||||
listenerMux sync.Mutex
|
listenerMux sync.Mutex
|
||||||
@@ -31,13 +32,17 @@ func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
|
|||||||
n.listener = listener
|
n.listener = listener
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetInitialClientRoutes stores the full initial route set (including fake IP blocks)
|
// SetInitialClientRoutes stores the initial route sets for TUN configuration.
|
||||||
// and a separate comparison set (without fake IP blocks) for diff detection.
|
|
||||||
func (n *Notifier) SetInitialClientRoutes(initialRoutes []*route.Route, routesForComparison []*route.Route) {
|
func (n *Notifier) SetInitialClientRoutes(initialRoutes []*route.Route, routesForComparison []*route.Route) {
|
||||||
n.initialRoutes = filterStatic(initialRoutes)
|
n.initialRoutes = filterStatic(initialRoutes)
|
||||||
n.currentRoutes = filterStatic(routesForComparison)
|
n.currentRoutes = filterStatic(routesForComparison)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetFakeIPRoute stores the fake IP route to be included in every TUN rebuild.
|
||||||
|
func (n *Notifier) SetFakeIPRoute(r *route.Route) {
|
||||||
|
n.fakeIPRoute = r
|
||||||
|
}
|
||||||
|
|
||||||
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
|
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
|
||||||
var newRoutes []*route.Route
|
var newRoutes []*route.Route
|
||||||
for _, routes := range idMap {
|
for _, routes := range idMap {
|
||||||
@@ -69,7 +74,9 @@ func (n *Notifier) notify() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
allRoutes := slices.Clone(n.currentRoutes)
|
allRoutes := slices.Clone(n.currentRoutes)
|
||||||
allRoutes = append(allRoutes, n.extraInitialRoutes()...)
|
if n.fakeIPRoute != nil {
|
||||||
|
allRoutes = append(allRoutes, n.fakeIPRoute)
|
||||||
|
}
|
||||||
|
|
||||||
routeStrings := n.routesToStrings(allRoutes)
|
routeStrings := n.routesToStrings(allRoutes)
|
||||||
sort.Strings(routeStrings)
|
sort.Strings(routeStrings)
|
||||||
@@ -78,23 +85,6 @@ func (n *Notifier) notify() {
|
|||||||
}(n.listener)
|
}(n.listener)
|
||||||
}
|
}
|
||||||
|
|
||||||
// extraInitialRoutes returns initialRoutes whose network prefix is absent
|
|
||||||
// from currentRoutes (e.g. the fake IP block added at setup time).
|
|
||||||
func (n *Notifier) extraInitialRoutes() []*route.Route {
|
|
||||||
currentNets := make(map[netip.Prefix]struct{}, len(n.currentRoutes))
|
|
||||||
for _, r := range n.currentRoutes {
|
|
||||||
currentNets[r.Network] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
var extra []*route.Route
|
|
||||||
for _, r := range n.initialRoutes {
|
|
||||||
if _, ok := currentNets[r.Network]; !ok {
|
|
||||||
extra = append(extra, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return extra
|
|
||||||
}
|
|
||||||
|
|
||||||
func filterStatic(routes []*route.Route) []*route.Route {
|
func filterStatic(routes []*route.Route) []*route.Route {
|
||||||
out := make([]*route.Route, 0, len(routes))
|
out := make([]*route.Route, 0, len(routes))
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
|
|||||||
@@ -34,6 +34,10 @@ func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) {
|
|||||||
// iOS doesn't care about initial routes
|
// iOS doesn't care about initial routes
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) SetFakeIPRoute(*route.Route) {
|
||||||
|
// Not used on iOS
|
||||||
|
}
|
||||||
|
|
||||||
func (n *Notifier) OnNewRoutes(route.HAMap) {
|
func (n *Notifier) OnNewRoutes(route.HAMap) {
|
||||||
// Not used on iOS
|
// Not used on iOS
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,6 +23,10 @@ func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) {
|
|||||||
// Not used on non-mobile platforms
|
// Not used on non-mobile platforms
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (n *Notifier) SetFakeIPRoute(*route.Route) {
|
||||||
|
// Not used on non-mobile platforms
|
||||||
|
}
|
||||||
|
|
||||||
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
|
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
|
||||||
// Not used on non-mobile platforms
|
// Not used on non-mobile platforms
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,10 @@
|
|||||||
|
//go:build (dragonfly || freebsd || netbsd || openbsd) && !darwin
|
||||||
|
|
||||||
|
package systemops
|
||||||
|
|
||||||
|
// Non-darwin BSDs don't support the IP_BOUND_IF + scoped default model. They
|
||||||
|
// always fall through to the ref-counter exclusion-route path; these stubs
|
||||||
|
// exist only so systemops_unix.go compiles.
|
||||||
|
func (r *SysOps) setupAdvancedRouting() error { return nil }
|
||||||
|
func (r *SysOps) cleanupAdvancedRouting() error { return nil }
|
||||||
|
func (r *SysOps) flushPlatformExtras() error { return nil }
|
||||||
241
client/internal/routemanager/systemops/systemops_darwin.go
Normal file
241
client/internal/routemanager/systemops/systemops_darwin.go
Normal file
@@ -0,0 +1,241 @@
|
|||||||
|
//go:build darwin && !ios
|
||||||
|
|
||||||
|
package systemops
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/net/route"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||||
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
// scopedRouteBudget bounds retries for the scoped default route. Installing or
|
||||||
|
// deleting it matters enough that we're willing to spend longer waiting for the
|
||||||
|
// kernel reply than for per-prefix exclusion routes.
|
||||||
|
const scopedRouteBudget = 5 * time.Second
|
||||||
|
|
||||||
|
// setupAdvancedRouting installs an RTF_IFSCOPE default route per address family
|
||||||
|
// pinned to the current physical egress, so IP_BOUND_IF scoped lookups can
|
||||||
|
// resolve gateway'd destinations while the VPN's split default owns the
|
||||||
|
// unscoped table.
|
||||||
|
//
|
||||||
|
// Timing note: this runs during routeManager.Init, which happens before the
|
||||||
|
// VPN interface is created and before any peer routes propagate. The initial
|
||||||
|
// mgmt / signal / relay TCP dials always fire before this runs, so those
|
||||||
|
// sockets miss the IP_BOUND_IF binding and rely on the kernel's normal route
|
||||||
|
// lookup, which at that point correctly picks the physical default. Those
|
||||||
|
// already-established TCP flows keep their originally-selected interface for
|
||||||
|
// their lifetime on Darwin because the kernel caches the egress route
|
||||||
|
// per-socket at connect time; adding the VPN's 0/1 + 128/1 split default
|
||||||
|
// afterwards does not migrate them since the original en0 default stays in
|
||||||
|
// the table. Any subsequent reconnect via nbnet.NewDialer picks up the
|
||||||
|
// populated bound-iface cache and gets IP_BOUND_IF set cleanly.
|
||||||
|
func (r *SysOps) setupAdvancedRouting() error {
|
||||||
|
// Drop any previously-cached egress interface before reinstalling. On a
|
||||||
|
// refresh, a family that no longer resolves would otherwise keep the stale
|
||||||
|
// binding, causing new sockets to scope to an interface without a matching
|
||||||
|
// scoped default.
|
||||||
|
nbnet.ClearBoundInterfaces()
|
||||||
|
|
||||||
|
if err := r.flushScopedDefaults(); err != nil {
|
||||||
|
log.Warnf("flush residual scoped defaults: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
installed := 0
|
||||||
|
|
||||||
|
for _, unspec := range []netip.Addr{netip.IPv4Unspecified(), netip.IPv6Unspecified()} {
|
||||||
|
ok, err := r.installScopedDefaultFor(unspec)
|
||||||
|
if err != nil {
|
||||||
|
merr = multierror.Append(merr, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
|
installed++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if installed == 0 && merr != nil {
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
if merr != nil {
|
||||||
|
log.Warnf("advanced routing setup partially succeeded: %v", nberrors.FormatErrorOrNil(merr))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// installScopedDefaultFor resolves the physical default nexthop for the given
|
||||||
|
// address family, installs a scoped default via it, and caches the iface for
|
||||||
|
// subsequent IP_BOUND_IF / IPV6_BOUND_IF socket binds.
|
||||||
|
func (r *SysOps) installScopedDefaultFor(unspec netip.Addr) (bool, error) {
|
||||||
|
nexthop, err := GetNextHop(unspec)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, vars.ErrRouteNotFound) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return false, fmt.Errorf("get default nexthop for %s: %w", unspec, err)
|
||||||
|
}
|
||||||
|
if nexthop.Intf == nil {
|
||||||
|
return false, fmt.Errorf("unusable default nexthop for %s (no interface)", unspec)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.addScopedDefault(unspec, nexthop); err != nil {
|
||||||
|
return false, fmt.Errorf("add scoped default on %s: %w", nexthop.Intf.Name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
af := unix.AF_INET
|
||||||
|
if unspec.Is6() {
|
||||||
|
af = unix.AF_INET6
|
||||||
|
}
|
||||||
|
nbnet.SetBoundInterface(af, nexthop.Intf)
|
||||||
|
via := "point-to-point"
|
||||||
|
if nexthop.IP.IsValid() {
|
||||||
|
via = nexthop.IP.String()
|
||||||
|
}
|
||||||
|
log.Infof("installed scoped default route via %s on %s for %s", via, nexthop.Intf.Name, afOf(unspec))
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *SysOps) cleanupAdvancedRouting() error {
|
||||||
|
nbnet.ClearBoundInterfaces()
|
||||||
|
return r.flushScopedDefaults()
|
||||||
|
}
|
||||||
|
|
||||||
|
// flushPlatformExtras runs darwin-specific residual cleanup hooked into the
|
||||||
|
// generic FlushMarkedRoutes path, so a crashed daemon's scoped defaults get
|
||||||
|
// removed on the next boot regardless of whether a profile is brought up.
|
||||||
|
func (r *SysOps) flushPlatformExtras() error {
|
||||||
|
return r.flushScopedDefaults()
|
||||||
|
}
|
||||||
|
|
||||||
|
// flushScopedDefaults removes any scoped default routes tagged with routeProtoFlag.
|
||||||
|
// Safe to call at startup to clear residual entries from a prior session.
|
||||||
|
func (r *SysOps) flushScopedDefaults() error {
|
||||||
|
rib, err := retryFetchRIB()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("fetch routing table: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse routing table: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
removed := 0
|
||||||
|
|
||||||
|
for _, msg := range msgs {
|
||||||
|
rtMsg, ok := msg.(*route.RouteMessage)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if rtMsg.Flags&routeProtoFlag == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if rtMsg.Flags&unix.RTF_IFSCOPE == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
info, err := MsgToRoute(rtMsg)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("skip scoped flush: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !info.Dst.IsValid() || info.Dst.Bits() != 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.deleteScopedRoute(rtMsg); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("delete scoped default %s on index %d: %w",
|
||||||
|
info.Dst, rtMsg.Index, err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
removed++
|
||||||
|
log.Debugf("flushed residual scoped default %s on index %d", info.Dst, rtMsg.Index)
|
||||||
|
}
|
||||||
|
|
||||||
|
if removed > 0 {
|
||||||
|
log.Infof("flushed %d residual scoped default route(s)", removed)
|
||||||
|
}
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *SysOps) addScopedDefault(unspec netip.Addr, nexthop Nexthop) error {
|
||||||
|
return r.scopedRouteSocket(unix.RTM_ADD, unspec, nexthop)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *SysOps) deleteScopedRoute(rtMsg *route.RouteMessage) error {
|
||||||
|
// Preserve identifying flags from the stored route (including RTF_GATEWAY
|
||||||
|
// only if present); kernel-set bits like RTF_DONE don't belong on RTM_DELETE.
|
||||||
|
keep := unix.RTF_UP | unix.RTF_STATIC | unix.RTF_GATEWAY | unix.RTF_IFSCOPE | routeProtoFlag
|
||||||
|
del := &route.RouteMessage{
|
||||||
|
Type: unix.RTM_DELETE,
|
||||||
|
Flags: rtMsg.Flags & keep,
|
||||||
|
Version: unix.RTM_VERSION,
|
||||||
|
Seq: r.getSeq(),
|
||||||
|
Index: rtMsg.Index,
|
||||||
|
Addrs: rtMsg.Addrs,
|
||||||
|
}
|
||||||
|
return r.writeRouteMessage(del, scopedRouteBudget)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *SysOps) scopedRouteSocket(action int, unspec netip.Addr, nexthop Nexthop) error {
|
||||||
|
flags := unix.RTF_UP | unix.RTF_STATIC | unix.RTF_IFSCOPE | routeProtoFlag
|
||||||
|
|
||||||
|
msg := &route.RouteMessage{
|
||||||
|
Type: action,
|
||||||
|
Flags: flags,
|
||||||
|
Version: unix.RTM_VERSION,
|
||||||
|
ID: uintptr(os.Getpid()),
|
||||||
|
Seq: r.getSeq(),
|
||||||
|
Index: nexthop.Intf.Index,
|
||||||
|
}
|
||||||
|
|
||||||
|
const numAddrs = unix.RTAX_NETMASK + 1
|
||||||
|
addrs := make([]route.Addr, numAddrs)
|
||||||
|
|
||||||
|
dst, err := addrToRouteAddr(unspec)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("build destination: %w", err)
|
||||||
|
}
|
||||||
|
mask, err := prefixToRouteNetmask(netip.PrefixFrom(unspec, 0))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("build netmask: %w", err)
|
||||||
|
}
|
||||||
|
addrs[unix.RTAX_DST] = dst
|
||||||
|
addrs[unix.RTAX_NETMASK] = mask
|
||||||
|
|
||||||
|
if nexthop.IP.IsValid() {
|
||||||
|
msg.Flags |= unix.RTF_GATEWAY
|
||||||
|
gw, err := addrToRouteAddr(nexthop.IP.Unmap())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("build gateway: %w", err)
|
||||||
|
}
|
||||||
|
addrs[unix.RTAX_GATEWAY] = gw
|
||||||
|
} else {
|
||||||
|
addrs[unix.RTAX_GATEWAY] = &route.LinkAddr{
|
||||||
|
Index: nexthop.Intf.Index,
|
||||||
|
Name: nexthop.Intf.Name,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
msg.Addrs = addrs
|
||||||
|
|
||||||
|
return r.writeRouteMessage(msg, scopedRouteBudget)
|
||||||
|
}
|
||||||
|
|
||||||
|
func afOf(a netip.Addr) string {
|
||||||
|
if a.Is4() {
|
||||||
|
return "IPv4"
|
||||||
|
}
|
||||||
|
return "IPv6"
|
||||||
|
}
|
||||||
@@ -21,6 +21,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
"github.com/netbirdio/netbird/client/net/hooks"
|
"github.com/netbirdio/netbird/client/net/hooks"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -31,8 +32,6 @@ var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
|
|||||||
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
|
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
|
||||||
var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
|
var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
|
||||||
|
|
||||||
var ErrRoutingIsSeparate = errors.New("routing is separate")
|
|
||||||
|
|
||||||
func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
||||||
stateManager.RegisterState(&ShutdownState{})
|
stateManager.RegisterState(&ShutdownState{})
|
||||||
|
|
||||||
@@ -397,12 +396,16 @@ func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// IsAddrRouted checks if the candidate address would route to the vpn, in which case it returns true and the matched prefix.
|
// IsAddrRouted checks if the candidate address would route to the vpn, in which case it returns true and the matched prefix.
|
||||||
|
// When advanced routing is active the WG socket is bound to the physical interface (fwmark on linux,
|
||||||
|
// IP_UNICAST_IF on windows, IP_BOUND_IF on darwin) and bypasses the main routing table, so the check is skipped.
|
||||||
func IsAddrRouted(addr netip.Addr, vpnRoutes []netip.Prefix) (bool, netip.Prefix) {
|
func IsAddrRouted(addr netip.Addr, vpnRoutes []netip.Prefix) (bool, netip.Prefix) {
|
||||||
localRoutes, err := hasSeparateRouting()
|
if nbnet.AdvancedRouting() {
|
||||||
if err != nil {
|
return false, netip.Prefix{}
|
||||||
if !errors.Is(err, ErrRoutingIsSeparate) {
|
|
||||||
log.Errorf("Failed to get routes: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
localRoutes, err := GetRoutesFromTable()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to get routes: %v", err)
|
||||||
return false, netip.Prefix{}
|
return false, netip.Prefix{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -22,10 +22,6 @@ func GetRoutesFromTable() ([]netip.Prefix, error) {
|
|||||||
return []netip.Prefix{}, nil
|
return []netip.Prefix{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func hasSeparateRouting() ([]netip.Prefix, error) {
|
|
||||||
return []netip.Prefix{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetDetailedRoutesFromTable returns empty routes for WASM.
|
// GetDetailedRoutesFromTable returns empty routes for WASM.
|
||||||
func GetDetailedRoutesFromTable() ([]DetailedRoute, error) {
|
func GetDetailedRoutesFromTable() ([]DetailedRoute, error) {
|
||||||
return []DetailedRoute{}, nil
|
return []DetailedRoute{}, nil
|
||||||
|
|||||||
@@ -894,13 +894,6 @@ func getAddressFamily(prefix netip.Prefix) int {
|
|||||||
return netlink.FAMILY_V6
|
return netlink.FAMILY_V6
|
||||||
}
|
}
|
||||||
|
|
||||||
func hasSeparateRouting() ([]netip.Prefix, error) {
|
|
||||||
if !nbnet.AdvancedRouting() {
|
|
||||||
return GetRoutesFromTable()
|
|
||||||
}
|
|
||||||
return nil, ErrRoutingIsSeparate
|
|
||||||
}
|
|
||||||
|
|
||||||
func isOpErr(err error) bool {
|
func isOpErr(err error) bool {
|
||||||
// EAFTNOSUPPORT when ipv6 is disabled via sysctl, EOPNOTSUPP when disabled in boot options or otherwise not supported
|
// EAFTNOSUPPORT when ipv6 is disabled via sysctl, EOPNOTSUPP when disabled in boot options or otherwise not supported
|
||||||
if errors.Is(err, syscall.EAFNOSUPPORT) || errors.Is(err, syscall.EOPNOTSUPP) {
|
if errors.Is(err, syscall.EAFNOSUPPORT) || errors.Is(err, syscall.EOPNOTSUPP) {
|
||||||
|
|||||||
@@ -48,10 +48,6 @@ func EnableIPForwarding() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func hasSeparateRouting() ([]netip.Prefix, error) {
|
|
||||||
return GetRoutesFromTable()
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetIPRules returns IP rules for debugging (not supported on non-Linux platforms)
|
// GetIPRules returns IP rules for debugging (not supported on non-Linux platforms)
|
||||||
func GetIPRules() ([]IPRule, error) {
|
func GetIPRules() ([]IPRule, error) {
|
||||||
log.Infof("IP rules collection is not supported on %s", runtime.GOOS)
|
log.Infof("IP rules collection is not supported on %s", runtime.GOOS)
|
||||||
|
|||||||
@@ -25,6 +25,9 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
envRouteProtoFlag = "NB_ROUTE_PROTO_FLAG"
|
envRouteProtoFlag = "NB_ROUTE_PROTO_FLAG"
|
||||||
|
|
||||||
|
// routeBudget bounds retries for per-prefix exclusion route programming.
|
||||||
|
routeBudget = 1 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
var routeProtoFlag int
|
var routeProtoFlag int
|
||||||
@@ -41,26 +44,42 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
|
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
|
||||||
|
if advancedRouting {
|
||||||
|
return r.setupAdvancedRouting()
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("Using legacy routing setup with ref counters")
|
||||||
return r.setupRefCounter(initAddresses, stateManager)
|
return r.setupRefCounter(initAddresses, stateManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error {
|
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error {
|
||||||
|
if advancedRouting {
|
||||||
|
return r.cleanupAdvancedRouting()
|
||||||
|
}
|
||||||
|
|
||||||
return r.cleanupRefCounter(stateManager)
|
return r.cleanupRefCounter(stateManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
// FlushMarkedRoutes removes single IP exclusion routes marked with the configured RTF_PROTO flag.
|
// FlushMarkedRoutes removes single IP exclusion routes marked with the configured RTF_PROTO flag.
|
||||||
|
// On darwin it also flushes residual RTF_IFSCOPE scoped default routes so a
|
||||||
|
// crashed prior session can't leave crud in the table.
|
||||||
func (r *SysOps) FlushMarkedRoutes() error {
|
func (r *SysOps) FlushMarkedRoutes() error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
if err := r.flushPlatformExtras(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("flush platform extras: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
rib, err := retryFetchRIB()
|
rib, err := retryFetchRIB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("fetch routing table: %w", err)
|
return nberrors.FormatErrorOrNil(multierror.Append(merr, fmt.Errorf("fetch routing table: %w", err)))
|
||||||
}
|
}
|
||||||
|
|
||||||
msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
|
msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("parse routing table: %w", err)
|
return nberrors.FormatErrorOrNil(multierror.Append(merr, fmt.Errorf("parse routing table: %w", err)))
|
||||||
}
|
}
|
||||||
|
|
||||||
var merr *multierror.Error
|
|
||||||
flushedCount := 0
|
flushedCount := 0
|
||||||
|
|
||||||
for _, msg := range msgs {
|
for _, msg := range msgs {
|
||||||
@@ -117,12 +136,12 @@ func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) e
|
|||||||
return fmt.Errorf("invalid prefix: %s", prefix)
|
return fmt.Errorf("invalid prefix: %s", prefix)
|
||||||
}
|
}
|
||||||
|
|
||||||
expBackOff := backoff.NewExponentialBackOff()
|
msg, err := r.buildRouteMessage(action, prefix, nexthop)
|
||||||
expBackOff.InitialInterval = 50 * time.Millisecond
|
if err != nil {
|
||||||
expBackOff.MaxInterval = 500 * time.Millisecond
|
return fmt.Errorf("build route message: %w", err)
|
||||||
expBackOff.MaxElapsedTime = 1 * time.Second
|
}
|
||||||
|
|
||||||
if err := backoff.Retry(r.routeOp(action, prefix, nexthop), expBackOff); err != nil {
|
if err := r.writeRouteMessage(msg, routeBudget); err != nil {
|
||||||
a := "add"
|
a := "add"
|
||||||
if action == unix.RTM_DELETE {
|
if action == unix.RTM_DELETE {
|
||||||
a = "remove"
|
a = "remove"
|
||||||
@@ -132,50 +151,91 @@ func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) e
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) routeOp(action int, prefix netip.Prefix, nexthop Nexthop) func() error {
|
// writeRouteMessage sends a route message over AF_ROUTE and waits for the
|
||||||
operation := func() error {
|
// kernel's matching reply, retrying transient failures until budget elapses.
|
||||||
|
// Callers do not need to manage sockets or seq numbers themselves.
|
||||||
|
func (r *SysOps) writeRouteMessage(msg *route.RouteMessage, budget time.Duration) error {
|
||||||
|
expBackOff := backoff.NewExponentialBackOff()
|
||||||
|
expBackOff.InitialInterval = 50 * time.Millisecond
|
||||||
|
expBackOff.MaxInterval = 500 * time.Millisecond
|
||||||
|
expBackOff.MaxElapsedTime = budget
|
||||||
|
|
||||||
|
return backoff.Retry(func() error { return routeMessageRoundtrip(msg) }, expBackOff)
|
||||||
|
}
|
||||||
|
|
||||||
|
func routeMessageRoundtrip(msg *route.RouteMessage) error {
|
||||||
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
|
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("open routing socket: %w", err)
|
return fmt.Errorf("open routing socket: %w", err)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := unix.Close(fd); err != nil && !errors.Is(err, unix.EBADF) {
|
if err := unix.Close(fd); err != nil && !errors.Is(err, unix.EBADF) {
|
||||||
log.Warnf("failed to close routing socket: %v", err)
|
log.Warnf("close routing socket: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
msg, err := r.buildRouteMessage(action, prefix, nexthop)
|
tv := unix.Timeval{Sec: 1}
|
||||||
if err != nil {
|
if err := unix.SetsockoptTimeval(fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &tv); err != nil {
|
||||||
return backoff.Permanent(fmt.Errorf("build route message: %w", err))
|
return backoff.Permanent(fmt.Errorf("set recv timeout: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
msgBytes, err := msg.Marshal()
|
// AF_ROUTE is a broadcast channel: every route socket on the host sees
|
||||||
if err != nil {
|
// every RTM_* event. With concurrent route programming the default
|
||||||
return backoff.Permanent(fmt.Errorf("marshal route message: %w", err))
|
// per-socket queue overflows and our own reply gets dropped.
|
||||||
|
if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_RCVBUF, 1<<20); err != nil {
|
||||||
|
log.Debugf("set SO_RCVBUF on route socket: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err = unix.Write(fd, msgBytes); err != nil {
|
bytes, err := msg.Marshal()
|
||||||
|
if err != nil {
|
||||||
|
return backoff.Permanent(fmt.Errorf("marshal: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err = unix.Write(fd, bytes); err != nil {
|
||||||
if errors.Is(err, unix.ENOBUFS) || errors.Is(err, unix.EAGAIN) {
|
if errors.Is(err, unix.ENOBUFS) || errors.Is(err, unix.EAGAIN) {
|
||||||
return fmt.Errorf("write: %w", err)
|
return fmt.Errorf("write: %w", err)
|
||||||
}
|
}
|
||||||
return backoff.Permanent(fmt.Errorf("write: %w", err))
|
return backoff.Permanent(fmt.Errorf("write: %w", err))
|
||||||
}
|
}
|
||||||
|
return readRouteResponse(fd, msg.Type, msg.Seq)
|
||||||
|
}
|
||||||
|
|
||||||
respBuf := make([]byte, 2048)
|
// readRouteResponse reads from the AF_ROUTE socket until it sees a reply
|
||||||
n, err := unix.Read(fd, respBuf)
|
// matching our write (same type, seq, and pid). AF_ROUTE SOCK_RAW is a
|
||||||
|
// broadcast channel: interface up/down, third-party route changes and neighbor
|
||||||
|
// discovery events can all land between our write and read, so we must filter.
|
||||||
|
func readRouteResponse(fd, wantType, wantSeq int) error {
|
||||||
|
pid := int32(os.Getpid())
|
||||||
|
resp := make([]byte, 2048)
|
||||||
|
deadline := time.Now().Add(time.Second)
|
||||||
|
for {
|
||||||
|
if time.Now().After(deadline) {
|
||||||
|
// Transient: under concurrent pressure the kernel can drop our reply
|
||||||
|
// from the socket buffer. Let backoff.Retry re-send with a fresh seq.
|
||||||
|
return fmt.Errorf("read: timeout waiting for route reply type=%d seq=%d", wantType, wantSeq)
|
||||||
|
}
|
||||||
|
n, err := unix.Read(fd, resp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return backoff.Permanent(fmt.Errorf("read route response: %w", err))
|
if errors.Is(err, unix.EAGAIN) || errors.Is(err, unix.EWOULDBLOCK) {
|
||||||
|
// SO_RCVTIMEO fired while waiting; loop to re-check the absolute deadline.
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
return backoff.Permanent(fmt.Errorf("read: %w", err))
|
||||||
if n > 0 {
|
|
||||||
if err := r.parseRouteResponse(respBuf[:n]); err != nil {
|
|
||||||
return backoff.Permanent(err)
|
|
||||||
}
|
}
|
||||||
|
if n < int(unsafe.Sizeof(unix.RtMsghdr{})) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
hdr := (*unix.RtMsghdr)(unsafe.Pointer(&resp[0]))
|
||||||
|
// Darwin reflects the sender's pid on replies; matching (Type, Seq, Pid)
|
||||||
|
// uniquely identifies our own reply among broadcast traffic.
|
||||||
|
if int(hdr.Type) != wantType || int(hdr.Seq) != wantSeq || hdr.Pid != pid {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if hdr.Errno != 0 {
|
||||||
|
return backoff.Permanent(fmt.Errorf("kernel: %w", syscall.Errno(hdr.Errno)))
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return operation
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) {
|
func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) {
|
||||||
@@ -183,6 +243,7 @@ func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Next
|
|||||||
Type: action,
|
Type: action,
|
||||||
Flags: unix.RTF_UP | routeProtoFlag,
|
Flags: unix.RTF_UP | routeProtoFlag,
|
||||||
Version: unix.RTM_VERSION,
|
Version: unix.RTM_VERSION,
|
||||||
|
ID: uintptr(os.Getpid()),
|
||||||
Seq: r.getSeq(),
|
Seq: r.getSeq(),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -221,19 +282,6 @@ func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Next
|
|||||||
return msg, nil
|
return msg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) parseRouteResponse(buf []byte) error {
|
|
||||||
if len(buf) < int(unsafe.Sizeof(unix.RtMsghdr{})) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
rtMsg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
|
|
||||||
if rtMsg.Errno != 0 {
|
|
||||||
return fmt.Errorf("parse: %d", rtMsg.Errno)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// addrToRouteAddr converts a netip.Addr to the appropriate route.Addr (*route.Inet4Addr or *route.Inet6Addr).
|
// addrToRouteAddr converts a netip.Addr to the appropriate route.Addr (*route.Inet4Addr or *route.Inet6Addr).
|
||||||
func addrToRouteAddr(addr netip.Addr) (route.Addr, error) {
|
func addrToRouteAddr(addr netip.Addr) (route.Addr, error) {
|
||||||
if addr.Is4() {
|
if addr.Is4() {
|
||||||
|
|||||||
@@ -74,6 +74,14 @@ func New(filePath string) *Manager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FilePath returns the path of the underlying state file.
|
||||||
|
func (m *Manager) FilePath() string {
|
||||||
|
if m == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return m.filePath
|
||||||
|
}
|
||||||
|
|
||||||
// Start starts the state manager periodic save routine
|
// Start starts the state manager periodic save routine
|
||||||
func (m *Manager) Start() {
|
func (m *Manager) Start() {
|
||||||
if m == nil {
|
if m == nil {
|
||||||
|
|||||||
5
client/net/dialer_init_darwin.go
Normal file
5
client/net/dialer_init_darwin.go
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
package net
|
||||||
|
|
||||||
|
func (d *Dialer) init() {
|
||||||
|
d.Dialer.Control = applyBoundIfToSocket
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build !linux && !windows
|
//go:build !linux && !windows && !darwin
|
||||||
|
|
||||||
package net
|
package net
|
||||||
|
|
||||||
|
|||||||
@@ -1,24 +0,0 @@
|
|||||||
//go:build android
|
|
||||||
|
|
||||||
package net
|
|
||||||
|
|
||||||
// Init initializes the network environment for Android
|
|
||||||
func Init() {
|
|
||||||
// No initialization needed on Android
|
|
||||||
}
|
|
||||||
|
|
||||||
// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes.
|
|
||||||
// Always returns true on Android since we cannot handle routes dynamically.
|
|
||||||
func AdvancedRouting() bool {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetVPNInterfaceName is a no-op on Android
|
|
||||||
func SetVPNInterfaceName(name string) {
|
|
||||||
// No-op on Android - not needed for Android VPN service
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetVPNInterfaceName returns empty string on Android
|
|
||||||
func GetVPNInterfaceName() string {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build windows
|
//go:build (darwin && !ios) || windows
|
||||||
|
|
||||||
package net
|
package net
|
||||||
|
|
||||||
@@ -24,17 +24,22 @@ func Init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func checkAdvancedRoutingSupport() bool {
|
func checkAdvancedRoutingSupport() bool {
|
||||||
var err error
|
legacyRouting := false
|
||||||
var legacyRouting bool
|
|
||||||
if val := os.Getenv(envUseLegacyRouting); val != "" {
|
if val := os.Getenv(envUseLegacyRouting); val != "" {
|
||||||
legacyRouting, err = strconv.ParseBool(val)
|
parsed, err := strconv.ParseBool(val)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to parse %s: %v", envUseLegacyRouting, err)
|
log.Warnf("ignoring unparsable %s=%q: %v", envUseLegacyRouting, val, err)
|
||||||
|
} else {
|
||||||
|
legacyRouting = parsed
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if legacyRouting || netstack.IsEnabled() {
|
if legacyRouting {
|
||||||
log.Info("advanced routing has been requested to be disabled")
|
log.Infof("advanced routing disabled: legacy routing requested via %s", envUseLegacyRouting)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if netstack.IsEnabled() {
|
||||||
|
log.Info("advanced routing disabled: netstack mode is enabled")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build !linux && !windows && !android
|
//go:build !linux && !windows && !darwin
|
||||||
|
|
||||||
package net
|
package net
|
||||||
|
|
||||||
|
|||||||
25
client/net/env_mobile.go
Normal file
25
client/net/env_mobile.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
//go:build ios || android
|
||||||
|
|
||||||
|
package net
|
||||||
|
|
||||||
|
// Init initializes the network environment for mobile platforms.
|
||||||
|
func Init() {
|
||||||
|
// no-op on mobile: routing scope is owned by the VPN extension.
|
||||||
|
}
|
||||||
|
|
||||||
|
// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes.
|
||||||
|
// Always returns true on mobile since routes cannot be handled dynamically and the VPN extension
|
||||||
|
// owns the routing scope.
|
||||||
|
func AdvancedRouting() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetVPNInterfaceName is a no-op on mobile.
|
||||||
|
func SetVPNInterfaceName(string) {
|
||||||
|
// no-op on mobile: the VPN extension manages the interface.
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetVPNInterfaceName returns an empty string on mobile.
|
||||||
|
func GetVPNInterfaceName() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
5
client/net/listener_init_darwin.go
Normal file
5
client/net/listener_init_darwin.go
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
package net
|
||||||
|
|
||||||
|
func (l *ListenerConfig) init() {
|
||||||
|
l.ListenConfig.Control = applyBoundIfToSocket
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build !linux && !windows
|
//go:build !linux && !windows && !darwin
|
||||||
|
|
||||||
package net
|
package net
|
||||||
|
|
||||||
|
|||||||
160
client/net/net_darwin.go
Normal file
160
client/net/net_darwin.go
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
package net
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
// On darwin IPV6_BOUND_IF also scopes v4-mapped egress from dual-stack
|
||||||
|
// (IPV6_V6ONLY=0) AF_INET6 sockets, so a single setsockopt on "udp6"/"tcp6"
|
||||||
|
// covers both families. Setting IP_BOUND_IF on an AF_INET6 socket returns
|
||||||
|
// EINVAL regardless of V6ONLY because the IPPROTO_IP ctloutput path is
|
||||||
|
// dispatched by socket domain (AF_INET only) not by inp_vflag.
|
||||||
|
|
||||||
|
// boundIface holds the physical interface chosen at routing setup time. Sockets
|
||||||
|
// created via nbnet.NewDialer / nbnet.NewListener bind to it via IP_BOUND_IF
|
||||||
|
// (IPv4) or IPV6_BOUND_IF (IPv6 / dual-stack) so their scoped route lookup
|
||||||
|
// hits the RTF_IFSCOPE default installed by the routemanager, rather than
|
||||||
|
// following the VPN's split default.
|
||||||
|
var (
|
||||||
|
boundIfaceMu sync.RWMutex
|
||||||
|
boundIface4 *net.Interface
|
||||||
|
boundIface6 *net.Interface
|
||||||
|
)
|
||||||
|
|
||||||
|
// SetBoundInterface records the egress interface for an address family. Called
|
||||||
|
// by the routemanager after a scoped default route has been installed.
|
||||||
|
// af must be unix.AF_INET or unix.AF_INET6; other values are ignored.
|
||||||
|
// nil iface is rejected — use ClearBoundInterfaces to clear all slots.
|
||||||
|
func SetBoundInterface(af int, iface *net.Interface) {
|
||||||
|
if iface == nil {
|
||||||
|
log.Warnf("SetBoundInterface: nil iface for AF %d, ignored", af)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
boundIfaceMu.Lock()
|
||||||
|
defer boundIfaceMu.Unlock()
|
||||||
|
switch af {
|
||||||
|
case unix.AF_INET:
|
||||||
|
boundIface4 = iface
|
||||||
|
case unix.AF_INET6:
|
||||||
|
boundIface6 = iface
|
||||||
|
default:
|
||||||
|
log.Warnf("SetBoundInterface: unsupported address family %d", af)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearBoundInterfaces resets the cached egress interfaces. Called by the
|
||||||
|
// routemanager during cleanup.
|
||||||
|
func ClearBoundInterfaces() {
|
||||||
|
boundIfaceMu.Lock()
|
||||||
|
defer boundIfaceMu.Unlock()
|
||||||
|
boundIface4 = nil
|
||||||
|
boundIface6 = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// boundInterfaceFor returns the cached egress interface for a socket's address
|
||||||
|
// family, falling back to the other family if the preferred slot is empty.
|
||||||
|
// The kernel stores both IP_BOUND_IF and IPV6_BOUND_IF in inp_boundifp, so
|
||||||
|
// either setsockopt scopes the socket; preferring same-family still matters
|
||||||
|
// when v4 and v6 defaults egress different NICs.
|
||||||
|
func boundInterfaceFor(network, address string) *net.Interface {
|
||||||
|
if iface := zoneInterface(address); iface != nil {
|
||||||
|
return iface
|
||||||
|
}
|
||||||
|
|
||||||
|
boundIfaceMu.RLock()
|
||||||
|
defer boundIfaceMu.RUnlock()
|
||||||
|
|
||||||
|
primary, secondary := boundIface4, boundIface6
|
||||||
|
if isV6Network(network) {
|
||||||
|
primary, secondary = boundIface6, boundIface4
|
||||||
|
}
|
||||||
|
if primary != nil {
|
||||||
|
return primary
|
||||||
|
}
|
||||||
|
return secondary
|
||||||
|
}
|
||||||
|
|
||||||
|
func isV6Network(network string) bool {
|
||||||
|
return strings.HasSuffix(network, "6")
|
||||||
|
}
|
||||||
|
|
||||||
|
// zoneInterface extracts an explicit interface from an IPv6 link-local zone (e.g. fe80::1%en0).
|
||||||
|
func zoneInterface(address string) *net.Interface {
|
||||||
|
if address == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
addr, err := netip.ParseAddrPort(address)
|
||||||
|
if err != nil {
|
||||||
|
a, err := netip.ParseAddr(address)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
addr = netip.AddrPortFrom(a, 0)
|
||||||
|
}
|
||||||
|
zone := addr.Addr().Zone()
|
||||||
|
if zone == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if iface, err := net.InterfaceByName(zone); err == nil {
|
||||||
|
return iface
|
||||||
|
}
|
||||||
|
if idx, err := strconv.Atoi(zone); err == nil {
|
||||||
|
if iface, err := net.InterfaceByIndex(idx); err == nil {
|
||||||
|
return iface
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setIPv4BoundIf(fd uintptr, iface *net.Interface) error {
|
||||||
|
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_BOUND_IF, iface.Index); err != nil {
|
||||||
|
return fmt.Errorf("set IP_BOUND_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setIPv6BoundIf(fd uintptr, iface *net.Interface) error {
|
||||||
|
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, iface.Index); err != nil {
|
||||||
|
return fmt.Errorf("set IPV6_BOUND_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyBoundIfToSocket binds the socket to the cached physical egress interface
|
||||||
|
// so scoped route lookup avoids the VPN utun and egresses the underlay directly.
|
||||||
|
func applyBoundIfToSocket(network, address string, c syscall.RawConn) error {
|
||||||
|
if !AdvancedRouting() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
iface := boundInterfaceFor(network, address)
|
||||||
|
if iface == nil {
|
||||||
|
log.Debugf("no bound iface cached for %s to %s, skipping BOUND_IF", network, address)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
isV6 := isV6Network(network)
|
||||||
|
var controlErr error
|
||||||
|
if err := c.Control(func(fd uintptr) {
|
||||||
|
if isV6 {
|
||||||
|
controlErr = setIPv6BoundIf(fd, iface)
|
||||||
|
} else {
|
||||||
|
controlErr = setIPv4BoundIf(fd, iface)
|
||||||
|
}
|
||||||
|
if controlErr == nil {
|
||||||
|
log.Debugf("set BOUND_IF=%d on %s for %s to %s", iface.Index, iface.Name, network, address)
|
||||||
|
}
|
||||||
|
}); err != nil {
|
||||||
|
return fmt.Errorf("control: %w", err)
|
||||||
|
}
|
||||||
|
return controlErr
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -209,6 +209,9 @@ message LoginRequest {
|
|||||||
optional bool enableSSHRemotePortForwarding = 37;
|
optional bool enableSSHRemotePortForwarding = 37;
|
||||||
optional bool disableSSHAuth = 38;
|
optional bool disableSSHAuth = 38;
|
||||||
optional int32 sshJWTCacheTTL = 39;
|
optional int32 sshJWTCacheTTL = 39;
|
||||||
|
|
||||||
|
optional bool serverVNCAllowed = 41;
|
||||||
|
optional bool disableVNCAuth = 42;
|
||||||
}
|
}
|
||||||
|
|
||||||
message LoginResponse {
|
message LoginResponse {
|
||||||
@@ -316,6 +319,10 @@ message GetConfigResponse {
|
|||||||
bool disableSSHAuth = 25;
|
bool disableSSHAuth = 25;
|
||||||
|
|
||||||
int32 sshJWTCacheTTL = 26;
|
int32 sshJWTCacheTTL = 26;
|
||||||
|
|
||||||
|
bool serverVNCAllowed = 28;
|
||||||
|
|
||||||
|
bool disableVNCAuth = 29;
|
||||||
}
|
}
|
||||||
|
|
||||||
// PeerState contains the latest state of a peer
|
// PeerState contains the latest state of a peer
|
||||||
@@ -394,6 +401,11 @@ message SSHServerState {
|
|||||||
repeated SSHSessionInfo sessions = 2;
|
repeated SSHSessionInfo sessions = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// VNCServerState contains the latest state of the VNC server
|
||||||
|
message VNCServerState {
|
||||||
|
bool enabled = 1;
|
||||||
|
}
|
||||||
|
|
||||||
// FullStatus contains the full state held by the Status instance
|
// FullStatus contains the full state held by the Status instance
|
||||||
message FullStatus {
|
message FullStatus {
|
||||||
ManagementState managementState = 1;
|
ManagementState managementState = 1;
|
||||||
@@ -408,6 +420,7 @@ message FullStatus {
|
|||||||
|
|
||||||
bool lazyConnectionEnabled = 9;
|
bool lazyConnectionEnabled = 9;
|
||||||
SSHServerState sshServerState = 10;
|
SSHServerState sshServerState = 10;
|
||||||
|
VNCServerState vncServerState = 11;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Networks
|
// Networks
|
||||||
@@ -677,6 +690,9 @@ message SetConfigRequest {
|
|||||||
optional bool enableSSHRemotePortForwarding = 32;
|
optional bool enableSSHRemotePortForwarding = 32;
|
||||||
optional bool disableSSHAuth = 33;
|
optional bool disableSSHAuth = 33;
|
||||||
optional int32 sshJWTCacheTTL = 34;
|
optional int32 sshJWTCacheTTL = 34;
|
||||||
|
|
||||||
|
optional bool serverVNCAllowed = 36;
|
||||||
|
optional bool disableVNCAuth = 37;
|
||||||
}
|
}
|
||||||
|
|
||||||
message SetConfigResponse{}
|
message SetConfigResponse{}
|
||||||
@@ -727,6 +743,7 @@ message GetFeaturesRequest{}
|
|||||||
message GetFeaturesResponse{
|
message GetFeaturesResponse{
|
||||||
bool disable_profiles = 1;
|
bool disable_profiles = 1;
|
||||||
bool disable_update_settings = 2;
|
bool disable_update_settings = 2;
|
||||||
|
bool disable_networks = 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
message TriggerUpdateRequest {}
|
message TriggerUpdateRequest {}
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
|
gstatus "google.golang.org/grpc/status"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
@@ -27,6 +29,10 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro
|
|||||||
s.mutex.Lock()
|
s.mutex.Lock()
|
||||||
defer s.mutex.Unlock()
|
defer s.mutex.Unlock()
|
||||||
|
|
||||||
|
if s.networksDisabled {
|
||||||
|
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
|
||||||
|
}
|
||||||
|
|
||||||
if s.connectClient == nil {
|
if s.connectClient == nil {
|
||||||
return nil, fmt.Errorf("not connected")
|
return nil, fmt.Errorf("not connected")
|
||||||
}
|
}
|
||||||
@@ -118,6 +124,10 @@ func (s *Server) SelectNetworks(_ context.Context, req *proto.SelectNetworksRequ
|
|||||||
s.mutex.Lock()
|
s.mutex.Lock()
|
||||||
defer s.mutex.Unlock()
|
defer s.mutex.Unlock()
|
||||||
|
|
||||||
|
if s.networksDisabled {
|
||||||
|
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
|
||||||
|
}
|
||||||
|
|
||||||
if s.connectClient == nil {
|
if s.connectClient == nil {
|
||||||
return nil, fmt.Errorf("not connected")
|
return nil, fmt.Errorf("not connected")
|
||||||
}
|
}
|
||||||
@@ -164,6 +174,10 @@ func (s *Server) DeselectNetworks(_ context.Context, req *proto.SelectNetworksRe
|
|||||||
s.mutex.Lock()
|
s.mutex.Lock()
|
||||||
defer s.mutex.Unlock()
|
defer s.mutex.Unlock()
|
||||||
|
|
||||||
|
if s.networksDisabled {
|
||||||
|
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
|
||||||
|
}
|
||||||
|
|
||||||
if s.connectClient == nil {
|
if s.connectClient == nil {
|
||||||
return nil, fmt.Errorf("not connected")
|
return nil, fmt.Errorf("not connected")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -53,6 +53,7 @@ const (
|
|||||||
errRestoreResidualState = "failed to restore residual state: %v"
|
errRestoreResidualState = "failed to restore residual state: %v"
|
||||||
errProfilesDisabled = "profiles are disabled, you cannot use this feature without profiles enabled"
|
errProfilesDisabled = "profiles are disabled, you cannot use this feature without profiles enabled"
|
||||||
errUpdateSettingsDisabled = "update settings are disabled, you cannot use this feature without update settings enabled"
|
errUpdateSettingsDisabled = "update settings are disabled, you cannot use this feature without update settings enabled"
|
||||||
|
errNetworksDisabled = "network selection is disabled by the administrator"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrServiceNotUp = errors.New("service is not up")
|
var ErrServiceNotUp = errors.New("service is not up")
|
||||||
@@ -88,6 +89,7 @@ type Server struct {
|
|||||||
profileManager *profilemanager.ServiceManager
|
profileManager *profilemanager.ServiceManager
|
||||||
profilesDisabled bool
|
profilesDisabled bool
|
||||||
updateSettingsDisabled bool
|
updateSettingsDisabled bool
|
||||||
|
networksDisabled bool
|
||||||
|
|
||||||
sleepHandler *sleephandler.SleepHandler
|
sleepHandler *sleephandler.SleepHandler
|
||||||
|
|
||||||
@@ -104,7 +106,7 @@ type oauthAuthFlow struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// New server instance constructor.
|
// New server instance constructor.
|
||||||
func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool) *Server {
|
func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool, networksDisabled bool) *Server {
|
||||||
s := &Server{
|
s := &Server{
|
||||||
rootCtx: ctx,
|
rootCtx: ctx,
|
||||||
logFile: logFile,
|
logFile: logFile,
|
||||||
@@ -113,6 +115,7 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
|
|||||||
profileManager: profilemanager.NewServiceManager(configFile),
|
profileManager: profilemanager.NewServiceManager(configFile),
|
||||||
profilesDisabled: profilesDisabled,
|
profilesDisabled: profilesDisabled,
|
||||||
updateSettingsDisabled: updateSettingsDisabled,
|
updateSettingsDisabled: updateSettingsDisabled,
|
||||||
|
networksDisabled: networksDisabled,
|
||||||
jwtCache: newJWTCache(),
|
jwtCache: newJWTCache(),
|
||||||
}
|
}
|
||||||
agent := &serverAgent{s}
|
agent := &serverAgent{s}
|
||||||
@@ -366,6 +369,7 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
|
|||||||
config.RosenpassPermissive = msg.RosenpassPermissive
|
config.RosenpassPermissive = msg.RosenpassPermissive
|
||||||
config.DisableAutoConnect = msg.DisableAutoConnect
|
config.DisableAutoConnect = msg.DisableAutoConnect
|
||||||
config.ServerSSHAllowed = msg.ServerSSHAllowed
|
config.ServerSSHAllowed = msg.ServerSSHAllowed
|
||||||
|
config.ServerVNCAllowed = msg.ServerVNCAllowed
|
||||||
config.NetworkMonitor = msg.NetworkMonitor
|
config.NetworkMonitor = msg.NetworkMonitor
|
||||||
config.DisableClientRoutes = msg.DisableClientRoutes
|
config.DisableClientRoutes = msg.DisableClientRoutes
|
||||||
config.DisableServerRoutes = msg.DisableServerRoutes
|
config.DisableServerRoutes = msg.DisableServerRoutes
|
||||||
@@ -382,6 +386,9 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
|
|||||||
if msg.DisableSSHAuth != nil {
|
if msg.DisableSSHAuth != nil {
|
||||||
config.DisableSSHAuth = msg.DisableSSHAuth
|
config.DisableSSHAuth = msg.DisableSSHAuth
|
||||||
}
|
}
|
||||||
|
if msg.DisableVNCAuth != nil {
|
||||||
|
config.DisableVNCAuth = msg.DisableVNCAuth
|
||||||
|
}
|
||||||
if msg.SshJWTCacheTTL != nil {
|
if msg.SshJWTCacheTTL != nil {
|
||||||
ttl := int(*msg.SshJWTCacheTTL)
|
ttl := int(*msg.SshJWTCacheTTL)
|
||||||
config.SSHJWTCacheTTL = &ttl
|
config.SSHJWTCacheTTL = &ttl
|
||||||
@@ -1120,6 +1127,7 @@ func (s *Server) Status(
|
|||||||
pbFullStatus := fullStatus.ToProto()
|
pbFullStatus := fullStatus.ToProto()
|
||||||
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
|
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
|
||||||
pbFullStatus.SshServerState = s.getSSHServerState()
|
pbFullStatus.SshServerState = s.getSSHServerState()
|
||||||
|
pbFullStatus.VncServerState = s.getVNCServerState()
|
||||||
statusResponse.FullStatus = pbFullStatus
|
statusResponse.FullStatus = pbFullStatus
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1159,6 +1167,26 @@ func (s *Server) getSSHServerState() *proto.SSHServerState {
|
|||||||
return sshServerState
|
return sshServerState
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getVNCServerState retrieves the current VNC server state.
|
||||||
|
func (s *Server) getVNCServerState() *proto.VNCServerState {
|
||||||
|
s.mutex.Lock()
|
||||||
|
connectClient := s.connectClient
|
||||||
|
s.mutex.Unlock()
|
||||||
|
|
||||||
|
if connectClient == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
engine := connectClient.Engine()
|
||||||
|
if engine == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &proto.VNCServerState{
|
||||||
|
Enabled: engine.GetVNCServerStatus(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
|
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
|
||||||
func (s *Server) GetPeerSSHHostKey(
|
func (s *Server) GetPeerSSHHostKey(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
@@ -1500,6 +1528,11 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
|
|||||||
disableSSHAuth = *cfg.DisableSSHAuth
|
disableSSHAuth = *cfg.DisableSSHAuth
|
||||||
}
|
}
|
||||||
|
|
||||||
|
disableVNCAuth := false
|
||||||
|
if cfg.DisableVNCAuth != nil {
|
||||||
|
disableVNCAuth = *cfg.DisableVNCAuth
|
||||||
|
}
|
||||||
|
|
||||||
sshJWTCacheTTL := int32(0)
|
sshJWTCacheTTL := int32(0)
|
||||||
if cfg.SSHJWTCacheTTL != nil {
|
if cfg.SSHJWTCacheTTL != nil {
|
||||||
sshJWTCacheTTL = int32(*cfg.SSHJWTCacheTTL)
|
sshJWTCacheTTL = int32(*cfg.SSHJWTCacheTTL)
|
||||||
@@ -1514,6 +1547,7 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
|
|||||||
Mtu: int64(cfg.MTU),
|
Mtu: int64(cfg.MTU),
|
||||||
DisableAutoConnect: cfg.DisableAutoConnect,
|
DisableAutoConnect: cfg.DisableAutoConnect,
|
||||||
ServerSSHAllowed: *cfg.ServerSSHAllowed,
|
ServerSSHAllowed: *cfg.ServerSSHAllowed,
|
||||||
|
ServerVNCAllowed: cfg.ServerVNCAllowed != nil && *cfg.ServerVNCAllowed,
|
||||||
RosenpassEnabled: cfg.RosenpassEnabled,
|
RosenpassEnabled: cfg.RosenpassEnabled,
|
||||||
RosenpassPermissive: cfg.RosenpassPermissive,
|
RosenpassPermissive: cfg.RosenpassPermissive,
|
||||||
LazyConnectionEnabled: cfg.LazyConnectionEnabled,
|
LazyConnectionEnabled: cfg.LazyConnectionEnabled,
|
||||||
@@ -1529,6 +1563,7 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
|
|||||||
EnableSSHLocalPortForwarding: enableSSHLocalPortForwarding,
|
EnableSSHLocalPortForwarding: enableSSHLocalPortForwarding,
|
||||||
EnableSSHRemotePortForwarding: enableSSHRemotePortForwarding,
|
EnableSSHRemotePortForwarding: enableSSHRemotePortForwarding,
|
||||||
DisableSSHAuth: disableSSHAuth,
|
DisableSSHAuth: disableSSHAuth,
|
||||||
|
DisableVNCAuth: disableVNCAuth,
|
||||||
SshJWTCacheTTL: sshJWTCacheTTL,
|
SshJWTCacheTTL: sshJWTCacheTTL,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
@@ -1628,6 +1663,7 @@ func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest)
|
|||||||
features := &proto.GetFeaturesResponse{
|
features := &proto.GetFeaturesResponse{
|
||||||
DisableProfiles: s.checkProfilesDisabled(),
|
DisableProfiles: s.checkProfilesDisabled(),
|
||||||
DisableUpdateSettings: s.checkUpdateSettingsDisabled(),
|
DisableUpdateSettings: s.checkUpdateSettingsDisabled(),
|
||||||
|
DisableNetworks: s.networksDisabled,
|
||||||
}
|
}
|
||||||
|
|
||||||
return features, nil
|
return features, nil
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ import (
|
|||||||
daemonProto "github.com/netbirdio/netbird/client/proto"
|
daemonProto "github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
@@ -103,7 +104,7 @@ func TestConnectWithRetryRuns(t *testing.T) {
|
|||||||
t.Fatalf("failed to set active profile state: %v", err)
|
t.Fatalf("failed to set active profile state: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
s := New(ctx, "debug", "", false, false)
|
s := New(ctx, "debug", "", false, false, false)
|
||||||
|
|
||||||
s.config = config
|
s.config = config
|
||||||
|
|
||||||
@@ -164,7 +165,7 @@ func TestServer_Up(t *testing.T) {
|
|||||||
t.Fatalf("failed to set active profile state: %v", err)
|
t.Fatalf("failed to set active profile state: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
s := New(ctx, "console", "", false, false)
|
s := New(ctx, "console", "", false, false, false)
|
||||||
err = s.Start()
|
err = s.Start()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -234,7 +235,7 @@ func TestServer_SubcribeEvents(t *testing.T) {
|
|||||||
t.Fatalf("failed to set active profile state: %v", err)
|
t.Fatalf("failed to set active profile state: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
s := New(ctx, "console", "", false, false)
|
s := New(ctx, "console", "", false, false, false)
|
||||||
|
|
||||||
err = s.Start()
|
err = s.Start()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -309,7 +310,12 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
|||||||
|
|
||||||
jobManager := job.NewJobManager(nil, store, peersManager)
|
jobManager := job.NewJobManager(nil, store, peersManager)
|
||||||
|
|
||||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore)
|
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore)
|
||||||
|
|
||||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -320,7 +326,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
|||||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||||
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
|
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||||
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
||||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -53,11 +53,13 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
s := New(ctx, "console", "", false, false)
|
s := New(ctx, "console", "", false, false, false)
|
||||||
|
|
||||||
rosenpassEnabled := true
|
rosenpassEnabled := true
|
||||||
rosenpassPermissive := true
|
rosenpassPermissive := true
|
||||||
serverSSHAllowed := true
|
serverSSHAllowed := true
|
||||||
|
serverVNCAllowed := true
|
||||||
|
disableVNCAuth := true
|
||||||
interfaceName := "utun100"
|
interfaceName := "utun100"
|
||||||
wireguardPort := int64(51820)
|
wireguardPort := int64(51820)
|
||||||
preSharedKey := "test-psk"
|
preSharedKey := "test-psk"
|
||||||
@@ -82,6 +84,8 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
|||||||
RosenpassEnabled: &rosenpassEnabled,
|
RosenpassEnabled: &rosenpassEnabled,
|
||||||
RosenpassPermissive: &rosenpassPermissive,
|
RosenpassPermissive: &rosenpassPermissive,
|
||||||
ServerSSHAllowed: &serverSSHAllowed,
|
ServerSSHAllowed: &serverSSHAllowed,
|
||||||
|
ServerVNCAllowed: &serverVNCAllowed,
|
||||||
|
DisableVNCAuth: &disableVNCAuth,
|
||||||
InterfaceName: &interfaceName,
|
InterfaceName: &interfaceName,
|
||||||
WireguardPort: &wireguardPort,
|
WireguardPort: &wireguardPort,
|
||||||
OptionalPreSharedKey: &preSharedKey,
|
OptionalPreSharedKey: &preSharedKey,
|
||||||
@@ -125,6 +129,10 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
|||||||
require.Equal(t, rosenpassPermissive, cfg.RosenpassPermissive)
|
require.Equal(t, rosenpassPermissive, cfg.RosenpassPermissive)
|
||||||
require.NotNil(t, cfg.ServerSSHAllowed)
|
require.NotNil(t, cfg.ServerSSHAllowed)
|
||||||
require.Equal(t, serverSSHAllowed, *cfg.ServerSSHAllowed)
|
require.Equal(t, serverSSHAllowed, *cfg.ServerSSHAllowed)
|
||||||
|
require.NotNil(t, cfg.ServerVNCAllowed)
|
||||||
|
require.Equal(t, serverVNCAllowed, *cfg.ServerVNCAllowed)
|
||||||
|
require.NotNil(t, cfg.DisableVNCAuth)
|
||||||
|
require.Equal(t, disableVNCAuth, *cfg.DisableVNCAuth)
|
||||||
require.Equal(t, interfaceName, cfg.WgIface)
|
require.Equal(t, interfaceName, cfg.WgIface)
|
||||||
require.Equal(t, int(wireguardPort), cfg.WgPort)
|
require.Equal(t, int(wireguardPort), cfg.WgPort)
|
||||||
require.Equal(t, preSharedKey, cfg.PreSharedKey)
|
require.Equal(t, preSharedKey, cfg.PreSharedKey)
|
||||||
@@ -176,6 +184,8 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
|
|||||||
"RosenpassEnabled": true,
|
"RosenpassEnabled": true,
|
||||||
"RosenpassPermissive": true,
|
"RosenpassPermissive": true,
|
||||||
"ServerSSHAllowed": true,
|
"ServerSSHAllowed": true,
|
||||||
|
"ServerVNCAllowed": true,
|
||||||
|
"DisableVNCAuth": true,
|
||||||
"InterfaceName": true,
|
"InterfaceName": true,
|
||||||
"WireguardPort": true,
|
"WireguardPort": true,
|
||||||
"OptionalPreSharedKey": true,
|
"OptionalPreSharedKey": true,
|
||||||
@@ -236,6 +246,8 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) {
|
|||||||
"enable-rosenpass": "RosenpassEnabled",
|
"enable-rosenpass": "RosenpassEnabled",
|
||||||
"rosenpass-permissive": "RosenpassPermissive",
|
"rosenpass-permissive": "RosenpassPermissive",
|
||||||
"allow-server-ssh": "ServerSSHAllowed",
|
"allow-server-ssh": "ServerSSHAllowed",
|
||||||
|
"allow-server-vnc": "ServerVNCAllowed",
|
||||||
|
"disable-vnc-auth": "DisableVNCAuth",
|
||||||
"interface-name": "InterfaceName",
|
"interface-name": "InterfaceName",
|
||||||
"wireguard-port": "WireguardPort",
|
"wireguard-port": "WireguardPort",
|
||||||
"preshared-key": "OptionalPreSharedKey",
|
"preshared-key": "OptionalPreSharedKey",
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbnet "github.com/netbirdio/netbird/client/net"
|
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -138,11 +137,9 @@ func restoreResidualState(ctx context.Context, statePath string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// clean up any remaining routes independently of the state file
|
// clean up any remaining routes independently of the state file
|
||||||
if !nbnet.AdvancedRouting() {
|
|
||||||
if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil {
|
if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err))
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -187,9 +187,8 @@ func (m *Manager) buildPeerConfig(allHostPatterns []string) (string, error) {
|
|||||||
return "", fmt.Errorf("get NetBird executable path: %w", err)
|
return "", fmt.Errorf("get NetBird executable path: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
hostLine := strings.Join(deduplicatedPatterns, " ")
|
hostList := strings.Join(deduplicatedPatterns, ",")
|
||||||
config := fmt.Sprintf("Host %s\n", hostLine)
|
config := fmt.Sprintf("Match host \"%s\" exec \"%s ssh detect %%h %%p\"\n", hostList, execPath)
|
||||||
config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p\"\n", execPath)
|
|
||||||
config += " PreferredAuthentications password,publickey,keyboard-interactive\n"
|
config += " PreferredAuthentications password,publickey,keyboard-interactive\n"
|
||||||
config += " PasswordAuthentication yes\n"
|
config += " PasswordAuthentication yes\n"
|
||||||
config += " PubkeyAuthentication yes\n"
|
config += " PubkeyAuthentication yes\n"
|
||||||
|
|||||||
@@ -116,6 +116,37 @@ func TestManager_PeerLimit(t *testing.T) {
|
|||||||
assert.True(t, os.IsNotExist(err), "SSH config should not be created with too many peers")
|
assert.True(t, os.IsNotExist(err), "SSH config should not be created with too many peers")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestManager_MatchHostFormat(t *testing.T) {
|
||||||
|
tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }()
|
||||||
|
|
||||||
|
manager := &Manager{
|
||||||
|
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
|
||||||
|
sshConfigFile: "99-netbird.conf",
|
||||||
|
}
|
||||||
|
|
||||||
|
peers := []PeerSSHInfo{
|
||||||
|
{Hostname: "peer1", IP: "100.125.1.1", FQDN: "peer1.nb.internal"},
|
||||||
|
{Hostname: "peer2", IP: "100.125.1.2", FQDN: "peer2.nb.internal"},
|
||||||
|
}
|
||||||
|
|
||||||
|
err = manager.SetupSSHClientConfig(peers)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile)
|
||||||
|
content, err := os.ReadFile(configPath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
configStr := string(content)
|
||||||
|
|
||||||
|
// Must use "Match host" with comma-separated patterns, not a bare "Host" directive.
|
||||||
|
// A bare "Host" followed by "Match exec" is incorrect per ssh_config(5): the Host block
|
||||||
|
// ends at the next Match keyword, making it a no-op and leaving the Match exec unscoped.
|
||||||
|
assert.NotContains(t, configStr, "\nHost ", "should not use bare Host directive")
|
||||||
|
assert.Contains(t, configStr, "Match host \"100.125.1.1,peer1.nb.internal,peer1,100.125.1.2,peer2.nb.internal,peer2\"",
|
||||||
|
"should use Match host with comma-separated patterns")
|
||||||
|
}
|
||||||
|
|
||||||
func TestManager_ForcedSSHConfig(t *testing.T) {
|
func TestManager_ForcedSSHConfig(t *testing.T) {
|
||||||
// Set force environment variable
|
// Set force environment variable
|
||||||
t.Setenv(EnvForceSSHConfig, "true")
|
t.Setenv(EnvForceSSHConfig, "true")
|
||||||
|
|||||||
@@ -200,8 +200,8 @@ func newLsaString(s string) lsaString {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateS4UUserToken creates a Windows token using S4U authentication
|
// generateS4UUserToken creates a Windows token using S4U authentication.
|
||||||
// This is the exact approach OpenSSH for Windows uses for public key authentication
|
// This is the same approach OpenSSH for Windows uses for public key authentication.
|
||||||
func generateS4UUserToken(logger *log.Entry, username, domain string) (windows.Handle, error) {
|
func generateS4UUserToken(logger *log.Entry, username, domain string) (windows.Handle, error) {
|
||||||
userCpn := buildUserCpn(username, domain)
|
userCpn := buildUserCpn(username, domain)
|
||||||
|
|
||||||
|
|||||||
@@ -507,27 +507,7 @@ func (s *Server) checkTokenAge(token *gojwt.Token, jwtConfig *JWTConfig) error {
|
|||||||
maxTokenAge = DefaultJWTMaxTokenAge
|
maxTokenAge = DefaultJWTMaxTokenAge
|
||||||
}
|
}
|
||||||
|
|
||||||
claims, ok := token.Claims.(gojwt.MapClaims)
|
return jwt.CheckTokenAge(token, time.Duration(maxTokenAge)*time.Second)
|
||||||
if !ok {
|
|
||||||
userID := extractUserID(token)
|
|
||||||
return fmt.Errorf("token has invalid claims format (user=%s)", userID)
|
|
||||||
}
|
|
||||||
|
|
||||||
iat, ok := claims["iat"].(float64)
|
|
||||||
if !ok {
|
|
||||||
userID := extractUserID(token)
|
|
||||||
return fmt.Errorf("token missing iat claim (user=%s)", userID)
|
|
||||||
}
|
|
||||||
|
|
||||||
issuedAt := time.Unix(int64(iat), 0)
|
|
||||||
tokenAge := time.Since(issuedAt)
|
|
||||||
maxAge := time.Duration(maxTokenAge) * time.Second
|
|
||||||
if tokenAge > maxAge {
|
|
||||||
userID := getUserIDFromClaims(claims)
|
|
||||||
return fmt.Errorf("token expired for user=%s: age=%v, max=%v", userID, tokenAge, maxAge)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) extractAndValidateUser(token *gojwt.Token) (*auth.UserAuth, error) {
|
func (s *Server) extractAndValidateUser(token *gojwt.Token) (*auth.UserAuth, error) {
|
||||||
@@ -558,27 +538,7 @@ func (s *Server) hasSSHAccess(userAuth *auth.UserAuth) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func extractUserID(token *gojwt.Token) string {
|
func extractUserID(token *gojwt.Token) string {
|
||||||
if token == nil {
|
return jwt.UserIDFromToken(token)
|
||||||
return "unknown"
|
|
||||||
}
|
|
||||||
claims, ok := token.Claims.(gojwt.MapClaims)
|
|
||||||
if !ok {
|
|
||||||
return "unknown"
|
|
||||||
}
|
|
||||||
return getUserIDFromClaims(claims)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getUserIDFromClaims(claims gojwt.MapClaims) string {
|
|
||||||
if sub, ok := claims["sub"].(string); ok && sub != "" {
|
|
||||||
return sub
|
|
||||||
}
|
|
||||||
if userID, ok := claims["user_id"].(string); ok && userID != "" {
|
|
||||||
return userID
|
|
||||||
}
|
|
||||||
if email, ok := claims["email"].(string); ok && email != "" {
|
|
||||||
return email
|
|
||||||
}
|
|
||||||
return "unknown"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]interface{}, error) {
|
func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]interface{}, error) {
|
||||||
|
|||||||
@@ -130,6 +130,10 @@ type SSHServerStateOutput struct {
|
|||||||
Sessions []SSHSessionOutput `json:"sessions" yaml:"sessions"`
|
Sessions []SSHSessionOutput `json:"sessions" yaml:"sessions"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type VNCServerStateOutput struct {
|
||||||
|
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||||
|
}
|
||||||
|
|
||||||
type OutputOverview struct {
|
type OutputOverview struct {
|
||||||
Peers PeersStateOutput `json:"peers" yaml:"peers"`
|
Peers PeersStateOutput `json:"peers" yaml:"peers"`
|
||||||
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
|
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
|
||||||
@@ -151,6 +155,7 @@ type OutputOverview struct {
|
|||||||
LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"`
|
LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"`
|
||||||
ProfileName string `json:"profileName" yaml:"profileName"`
|
ProfileName string `json:"profileName" yaml:"profileName"`
|
||||||
SSHServerState SSHServerStateOutput `json:"sshServer" yaml:"sshServer"`
|
SSHServerState SSHServerStateOutput `json:"sshServer" yaml:"sshServer"`
|
||||||
|
VNCServerState VNCServerStateOutput `json:"vncServer" yaml:"vncServer"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConvertToStatusOutputOverview converts protobuf status to the output overview.
|
// ConvertToStatusOutputOverview converts protobuf status to the output overview.
|
||||||
@@ -171,6 +176,9 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, opts ConvertO
|
|||||||
|
|
||||||
relayOverview := mapRelays(pbFullStatus.GetRelays())
|
relayOverview := mapRelays(pbFullStatus.GetRelays())
|
||||||
sshServerOverview := mapSSHServer(pbFullStatus.GetSshServerState())
|
sshServerOverview := mapSSHServer(pbFullStatus.GetSshServerState())
|
||||||
|
vncServerOverview := VNCServerStateOutput{
|
||||||
|
Enabled: pbFullStatus.GetVncServerState().GetEnabled(),
|
||||||
|
}
|
||||||
peersOverview := mapPeers(pbFullStatus.GetPeers(), opts.StatusFilter, opts.PrefixNamesFilter, opts.PrefixNamesFilterMap, opts.IPsFilter, opts.ConnectionTypeFilter)
|
peersOverview := mapPeers(pbFullStatus.GetPeers(), opts.StatusFilter, opts.PrefixNamesFilter, opts.PrefixNamesFilterMap, opts.IPsFilter, opts.ConnectionTypeFilter)
|
||||||
|
|
||||||
overview := OutputOverview{
|
overview := OutputOverview{
|
||||||
@@ -194,6 +202,7 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, opts ConvertO
|
|||||||
LazyConnectionEnabled: pbFullStatus.GetLazyConnectionEnabled(),
|
LazyConnectionEnabled: pbFullStatus.GetLazyConnectionEnabled(),
|
||||||
ProfileName: opts.ProfileName,
|
ProfileName: opts.ProfileName,
|
||||||
SSHServerState: sshServerOverview,
|
SSHServerState: sshServerOverview,
|
||||||
|
VNCServerState: vncServerOverview,
|
||||||
}
|
}
|
||||||
|
|
||||||
if opts.Anonymize {
|
if opts.Anonymize {
|
||||||
@@ -524,6 +533,11 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
vncServerStatus := "Disabled"
|
||||||
|
if o.VNCServerState.Enabled {
|
||||||
|
vncServerStatus = "Enabled"
|
||||||
|
}
|
||||||
|
|
||||||
peersCountString := fmt.Sprintf("%d/%d Connected", o.Peers.Connected, o.Peers.Total)
|
peersCountString := fmt.Sprintf("%d/%d Connected", o.Peers.Connected, o.Peers.Total)
|
||||||
|
|
||||||
var forwardingRulesString string
|
var forwardingRulesString string
|
||||||
@@ -553,6 +567,7 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
|||||||
"Quantum resistance: %s\n"+
|
"Quantum resistance: %s\n"+
|
||||||
"Lazy connection: %s\n"+
|
"Lazy connection: %s\n"+
|
||||||
"SSH Server: %s\n"+
|
"SSH Server: %s\n"+
|
||||||
|
"VNC Server: %s\n"+
|
||||||
"Networks: %s\n"+
|
"Networks: %s\n"+
|
||||||
"%s"+
|
"%s"+
|
||||||
"Peers count: %s\n",
|
"Peers count: %s\n",
|
||||||
@@ -570,6 +585,7 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
|||||||
rosenpassEnabledStatus,
|
rosenpassEnabledStatus,
|
||||||
lazyConnectionEnabledStatus,
|
lazyConnectionEnabledStatus,
|
||||||
sshServerStatus,
|
sshServerStatus,
|
||||||
|
vncServerStatus,
|
||||||
networks,
|
networks,
|
||||||
forwardingRulesString,
|
forwardingRulesString,
|
||||||
peersCountString,
|
peersCountString,
|
||||||
|
|||||||
@@ -398,6 +398,9 @@ func TestParsingToJSON(t *testing.T) {
|
|||||||
"sshServer":{
|
"sshServer":{
|
||||||
"enabled":false,
|
"enabled":false,
|
||||||
"sessions":[]
|
"sessions":[]
|
||||||
|
},
|
||||||
|
"vncServer":{
|
||||||
|
"enabled":false
|
||||||
}
|
}
|
||||||
}`
|
}`
|
||||||
// @formatter:on
|
// @formatter:on
|
||||||
@@ -505,6 +508,8 @@ profileName: ""
|
|||||||
sshServer:
|
sshServer:
|
||||||
enabled: false
|
enabled: false
|
||||||
sessions: []
|
sessions: []
|
||||||
|
vncServer:
|
||||||
|
enabled: false
|
||||||
`
|
`
|
||||||
|
|
||||||
assert.Equal(t, expectedYAML, yaml)
|
assert.Equal(t, expectedYAML, yaml)
|
||||||
@@ -572,6 +577,7 @@ Interface type: Kernel
|
|||||||
Quantum resistance: false
|
Quantum resistance: false
|
||||||
Lazy connection: false
|
Lazy connection: false
|
||||||
SSH Server: Disabled
|
SSH Server: Disabled
|
||||||
|
VNC Server: Disabled
|
||||||
Networks: 10.10.0.0/24
|
Networks: 10.10.0.0/24
|
||||||
Peers count: 2/2 Connected
|
Peers count: 2/2 Connected
|
||||||
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
|
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
|
||||||
@@ -596,6 +602,7 @@ Interface type: Kernel
|
|||||||
Quantum resistance: false
|
Quantum resistance: false
|
||||||
Lazy connection: false
|
Lazy connection: false
|
||||||
SSH Server: Disabled
|
SSH Server: Disabled
|
||||||
|
VNC Server: Disabled
|
||||||
Networks: 10.10.0.0/24
|
Networks: 10.10.0.0/24
|
||||||
Peers count: 2/2 Connected
|
Peers count: 2/2 Connected
|
||||||
`
|
`
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package system
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -63,6 +62,7 @@ type Info struct {
|
|||||||
RosenpassEnabled bool
|
RosenpassEnabled bool
|
||||||
RosenpassPermissive bool
|
RosenpassPermissive bool
|
||||||
ServerSSHAllowed bool
|
ServerSSHAllowed bool
|
||||||
|
ServerVNCAllowed bool
|
||||||
|
|
||||||
DisableClientRoutes bool
|
DisableClientRoutes bool
|
||||||
DisableServerRoutes bool
|
DisableServerRoutes bool
|
||||||
@@ -78,21 +78,27 @@ type Info struct {
|
|||||||
EnableSSHLocalPortForwarding bool
|
EnableSSHLocalPortForwarding bool
|
||||||
EnableSSHRemotePortForwarding bool
|
EnableSSHRemotePortForwarding bool
|
||||||
DisableSSHAuth bool
|
DisableSSHAuth bool
|
||||||
|
DisableVNCAuth bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *Info) SetFlags(
|
func (i *Info) SetFlags(
|
||||||
rosenpassEnabled, rosenpassPermissive bool,
|
rosenpassEnabled, rosenpassPermissive bool,
|
||||||
serverSSHAllowed *bool,
|
serverSSHAllowed *bool,
|
||||||
|
serverVNCAllowed *bool,
|
||||||
disableClientRoutes, disableServerRoutes,
|
disableClientRoutes, disableServerRoutes,
|
||||||
disableDNS, disableFirewall, blockLANAccess, blockInbound, lazyConnectionEnabled bool,
|
disableDNS, disableFirewall, blockLANAccess, blockInbound, lazyConnectionEnabled bool,
|
||||||
enableSSHRoot, enableSSHSFTP, enableSSHLocalPortForwarding, enableSSHRemotePortForwarding *bool,
|
enableSSHRoot, enableSSHSFTP, enableSSHLocalPortForwarding, enableSSHRemotePortForwarding *bool,
|
||||||
disableSSHAuth *bool,
|
disableSSHAuth *bool,
|
||||||
|
disableVNCAuth *bool,
|
||||||
) {
|
) {
|
||||||
i.RosenpassEnabled = rosenpassEnabled
|
i.RosenpassEnabled = rosenpassEnabled
|
||||||
i.RosenpassPermissive = rosenpassPermissive
|
i.RosenpassPermissive = rosenpassPermissive
|
||||||
if serverSSHAllowed != nil {
|
if serverSSHAllowed != nil {
|
||||||
i.ServerSSHAllowed = *serverSSHAllowed
|
i.ServerSSHAllowed = *serverSSHAllowed
|
||||||
}
|
}
|
||||||
|
if serverVNCAllowed != nil {
|
||||||
|
i.ServerVNCAllowed = *serverVNCAllowed
|
||||||
|
}
|
||||||
|
|
||||||
i.DisableClientRoutes = disableClientRoutes
|
i.DisableClientRoutes = disableClientRoutes
|
||||||
i.DisableServerRoutes = disableServerRoutes
|
i.DisableServerRoutes = disableServerRoutes
|
||||||
@@ -118,6 +124,9 @@ func (i *Info) SetFlags(
|
|||||||
if disableSSHAuth != nil {
|
if disableSSHAuth != nil {
|
||||||
i.DisableSSHAuth = *disableSSHAuth
|
i.DisableSSHAuth = *disableSSHAuth
|
||||||
}
|
}
|
||||||
|
if disableVNCAuth != nil {
|
||||||
|
i.DisableVNCAuth = *disableVNCAuth
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
|
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
|
||||||
@@ -145,59 +154,6 @@ func extractDeviceName(ctx context.Context, defaultName string) string {
|
|||||||
return v
|
return v
|
||||||
}
|
}
|
||||||
|
|
||||||
func networkAddresses() ([]NetworkAddress, error) {
|
|
||||||
interfaces, err := net.Interfaces()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var netAddresses []NetworkAddress
|
|
||||||
for _, iface := range interfaces {
|
|
||||||
if iface.Flags&net.FlagUp == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if iface.HardwareAddr.String() == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
addrs, err := iface.Addrs()
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, address := range addrs {
|
|
||||||
ipNet, ok := address.(*net.IPNet)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if ipNet.IP.IsLoopback() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
netAddr := NetworkAddress{
|
|
||||||
NetIP: netip.MustParsePrefix(ipNet.String()),
|
|
||||||
Mac: iface.HardwareAddr.String(),
|
|
||||||
}
|
|
||||||
|
|
||||||
if isDuplicated(netAddresses, netAddr) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
netAddresses = append(netAddresses, netAddr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return netAddresses, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
|
|
||||||
for _, duplicated := range addresses {
|
|
||||||
if duplicated.NetIP == addr.NetIP {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetInfoWithChecks retrieves and parses the system information with applied checks.
|
// GetInfoWithChecks retrieves and parses the system information with applied checks.
|
||||||
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) {
|
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) {
|
||||||
log.Debugf("gathering system information with checks: %d", len(checks))
|
log.Debugf("gathering system information with checks: %d", len(checks))
|
||||||
|
|||||||
@@ -2,12 +2,16 @@ package system
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/version"
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
// UpdateStaticInfoAsync is a no-op on Android as there is no static info to update
|
// UpdateStaticInfoAsync is a no-op on iOS as there is no static info to update
|
||||||
func UpdateStaticInfoAsync() {
|
func UpdateStaticInfoAsync() {
|
||||||
// do nothing
|
// do nothing
|
||||||
}
|
}
|
||||||
@@ -15,11 +19,24 @@ func UpdateStaticInfoAsync() {
|
|||||||
// GetInfo retrieves and parses the system information
|
// GetInfo retrieves and parses the system information
|
||||||
func GetInfo(ctx context.Context) *Info {
|
func GetInfo(ctx context.Context) *Info {
|
||||||
|
|
||||||
// Convert fixed-size byte arrays to Go strings
|
|
||||||
sysName := extractOsName(ctx, "sysName")
|
sysName := extractOsName(ctx, "sysName")
|
||||||
swVersion := extractOsVersion(ctx, "swVersion")
|
swVersion := extractOsVersion(ctx, "swVersion")
|
||||||
|
|
||||||
gio := &Info{Kernel: sysName, OSVersion: swVersion, Platform: "unknown", OS: sysName, GoOS: runtime.GOOS, CPUs: runtime.NumCPU(), KernelVersion: swVersion}
|
addrs, err := networkAddresses()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to discover network addresses: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
gio := &Info{
|
||||||
|
Kernel: sysName,
|
||||||
|
OSVersion: swVersion,
|
||||||
|
Platform: "unknown",
|
||||||
|
OS: sysName,
|
||||||
|
GoOS: runtime.GOOS,
|
||||||
|
CPUs: runtime.NumCPU(),
|
||||||
|
KernelVersion: swVersion,
|
||||||
|
NetworkAddresses: addrs,
|
||||||
|
}
|
||||||
gio.Hostname = extractDeviceName(ctx, "hostname")
|
gio.Hostname = extractDeviceName(ctx, "hostname")
|
||||||
gio.NetbirdVersion = version.NetbirdVersion()
|
gio.NetbirdVersion = version.NetbirdVersion()
|
||||||
gio.UIVersion = extractUserAgent(ctx)
|
gio.UIVersion = extractUserAgent(ctx)
|
||||||
@@ -27,6 +44,66 @@ func GetInfo(ctx context.Context) *Info {
|
|||||||
return gio
|
return gio
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// networkAddresses returns the list of network addresses on iOS.
|
||||||
|
// On iOS, hardware (MAC) addresses are not available due to Apple's privacy
|
||||||
|
// restrictions (iOS returns a fixed 02:00:00:00:00:00 placeholder), so we
|
||||||
|
// leave Mac empty to match Android's behavior. We also skip the HardwareAddr
|
||||||
|
// check that other platforms use and filter out link-local addresses as they
|
||||||
|
// are not useful for posture checks.
|
||||||
|
func networkAddresses() ([]NetworkAddress, error) {
|
||||||
|
interfaces, err := net.Interfaces()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var netAddresses []NetworkAddress
|
||||||
|
for _, iface := range interfaces {
|
||||||
|
if iface.Flags&net.FlagUp == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
addrs, err := iface.Addrs()
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, address := range addrs {
|
||||||
|
netAddr, ok := toNetworkAddress(address)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if isDuplicated(netAddresses, netAddr) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
netAddresses = append(netAddresses, netAddr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return netAddresses, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func toNetworkAddress(address net.Addr) (NetworkAddress, bool) {
|
||||||
|
ipNet, ok := address.(*net.IPNet)
|
||||||
|
if !ok {
|
||||||
|
return NetworkAddress{}, false
|
||||||
|
}
|
||||||
|
if ipNet.IP.IsLoopback() || ipNet.IP.IsLinkLocalUnicast() || ipNet.IP.IsMulticast() {
|
||||||
|
return NetworkAddress{}, false
|
||||||
|
}
|
||||||
|
prefix, err := netip.ParsePrefix(ipNet.String())
|
||||||
|
if err != nil {
|
||||||
|
return NetworkAddress{}, false
|
||||||
|
}
|
||||||
|
return NetworkAddress{NetIP: prefix, Mac: ""}, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
|
||||||
|
for _, duplicated := range addresses {
|
||||||
|
if duplicated.NetIP == addr.NetIP {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
|
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
|
||||||
func checkFileAndProcess(paths []string) ([]File, error) {
|
func checkFileAndProcess(paths []string) ([]File, error) {
|
||||||
return []File{}, nil
|
return []File{}, nil
|
||||||
|
|||||||
66
client/system/network_addr.go
Normal file
66
client/system/network_addr.go
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
//go:build !ios
|
||||||
|
|
||||||
|
package system
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
)
|
||||||
|
|
||||||
|
func networkAddresses() ([]NetworkAddress, error) {
|
||||||
|
interfaces, err := net.Interfaces()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var netAddresses []NetworkAddress
|
||||||
|
for _, iface := range interfaces {
|
||||||
|
if iface.Flags&net.FlagUp == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if iface.HardwareAddr.String() == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
addrs, err := iface.Addrs()
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
mac := iface.HardwareAddr.String()
|
||||||
|
for _, address := range addrs {
|
||||||
|
netAddr, ok := toNetworkAddress(address, mac)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if isDuplicated(netAddresses, netAddr) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
netAddresses = append(netAddresses, netAddr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return netAddresses, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func toNetworkAddress(address net.Addr, mac string) (NetworkAddress, bool) {
|
||||||
|
ipNet, ok := address.(*net.IPNet)
|
||||||
|
if !ok {
|
||||||
|
return NetworkAddress{}, false
|
||||||
|
}
|
||||||
|
if ipNet.IP.IsLoopback() {
|
||||||
|
return NetworkAddress{}, false
|
||||||
|
}
|
||||||
|
prefix, err := netip.ParsePrefix(ipNet.String())
|
||||||
|
if err != nil {
|
||||||
|
return NetworkAddress{}, false
|
||||||
|
}
|
||||||
|
return NetworkAddress{NetIP: prefix, Mac: mac}, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
|
||||||
|
for _, duplicated := range addresses {
|
||||||
|
if duplicated.NetIP == addr.NetIP {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
@@ -314,6 +314,7 @@ type serviceClient struct {
|
|||||||
lastNotifiedVersion string
|
lastNotifiedVersion string
|
||||||
settingsEnabled bool
|
settingsEnabled bool
|
||||||
profilesEnabled bool
|
profilesEnabled bool
|
||||||
|
networksEnabled bool
|
||||||
showNetworks bool
|
showNetworks bool
|
||||||
wNetworks fyne.Window
|
wNetworks fyne.Window
|
||||||
wProfiles fyne.Window
|
wProfiles fyne.Window
|
||||||
@@ -368,6 +369,7 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
|
|||||||
|
|
||||||
showAdvancedSettings: args.showSettings,
|
showAdvancedSettings: args.showSettings,
|
||||||
showNetworks: args.showNetworks,
|
showNetworks: args.showNetworks,
|
||||||
|
networksEnabled: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
s.eventHandler = newEventHandler(s)
|
s.eventHandler = newEventHandler(s)
|
||||||
@@ -920,8 +922,10 @@ func (s *serviceClient) updateStatus() error {
|
|||||||
s.mStatus.SetIcon(s.icConnectedDot)
|
s.mStatus.SetIcon(s.icConnectedDot)
|
||||||
s.mUp.Disable()
|
s.mUp.Disable()
|
||||||
s.mDown.Enable()
|
s.mDown.Enable()
|
||||||
|
if s.networksEnabled {
|
||||||
s.mNetworks.Enable()
|
s.mNetworks.Enable()
|
||||||
s.mExitNode.Enable()
|
s.mExitNode.Enable()
|
||||||
|
}
|
||||||
s.startExitNodeRefresh()
|
s.startExitNodeRefresh()
|
||||||
systrayIconState = true
|
systrayIconState = true
|
||||||
case status.Status == string(internal.StatusConnecting):
|
case status.Status == string(internal.StatusConnecting):
|
||||||
@@ -1093,14 +1097,14 @@ func (s *serviceClient) onTrayReady() {
|
|||||||
s.getSrvConfig()
|
s.getSrvConfig()
|
||||||
time.Sleep(100 * time.Millisecond) // To prevent race condition caused by systray not being fully initialized and ignoring setIcon
|
time.Sleep(100 * time.Millisecond) // To prevent race condition caused by systray not being fully initialized and ignoring setIcon
|
||||||
for {
|
for {
|
||||||
|
// Check features before status so menus respect disable flags before being enabled
|
||||||
|
s.checkAndUpdateFeatures()
|
||||||
|
|
||||||
err := s.updateStatus()
|
err := s.updateStatus()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("error while updating status: %v", err)
|
log.Errorf("error while updating status: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check features periodically to handle daemon restarts
|
|
||||||
s.checkAndUpdateFeatures()
|
|
||||||
|
|
||||||
time.Sleep(2 * time.Second)
|
time.Sleep(2 * time.Second)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -1299,6 +1303,16 @@ func (s *serviceClient) checkAndUpdateFeatures() {
|
|||||||
s.mProfile.setEnabled(profilesEnabled)
|
s.mProfile.setEnabled(profilesEnabled)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Update networks and exit node menus based on current features
|
||||||
|
s.networksEnabled = features == nil || !features.DisableNetworks
|
||||||
|
if s.networksEnabled && s.connected {
|
||||||
|
s.mNetworks.Enable()
|
||||||
|
s.mExitNode.Enable()
|
||||||
|
} else {
|
||||||
|
s.mNetworks.Disable()
|
||||||
|
s.mExitNode.Disable()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// getFeatures from the daemon to determine which features are enabled/disabled.
|
// getFeatures from the daemon to determine which features are enabled/disabled.
|
||||||
|
|||||||
474
client/vnc/server/agent_windows.go
Normal file
474
client/vnc/server/agent_windows.go
Normal file
@@ -0,0 +1,474 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
crand "crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
agentPort = "15900"
|
||||||
|
|
||||||
|
// agentTokenLen is the length of the random authentication token
|
||||||
|
// used to verify that connections to the agent come from the service.
|
||||||
|
agentTokenLen = 32
|
||||||
|
|
||||||
|
stillActive = 259
|
||||||
|
|
||||||
|
tokenPrimary = 1
|
||||||
|
securityImpersonation = 2
|
||||||
|
tokenSessionID = 12
|
||||||
|
|
||||||
|
createUnicodeEnvironment = 0x00000400
|
||||||
|
createNoWindow = 0x08000000
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
kernel32 = windows.NewLazySystemDLL("kernel32.dll")
|
||||||
|
advapi32 = windows.NewLazySystemDLL("advapi32.dll")
|
||||||
|
userenv = windows.NewLazySystemDLL("userenv.dll")
|
||||||
|
|
||||||
|
procWTSGetActiveConsoleSessionId = kernel32.NewProc("WTSGetActiveConsoleSessionId")
|
||||||
|
procSetTokenInformation = advapi32.NewProc("SetTokenInformation")
|
||||||
|
procCreateEnvironmentBlock = userenv.NewProc("CreateEnvironmentBlock")
|
||||||
|
procDestroyEnvironmentBlock = userenv.NewProc("DestroyEnvironmentBlock")
|
||||||
|
|
||||||
|
wtsapi32 = windows.NewLazySystemDLL("wtsapi32.dll")
|
||||||
|
procWTSEnumerateSessionsW = wtsapi32.NewProc("WTSEnumerateSessionsW")
|
||||||
|
procWTSFreeMemory = wtsapi32.NewProc("WTSFreeMemory")
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetCurrentSessionID returns the session ID of the current process.
|
||||||
|
func GetCurrentSessionID() uint32 {
|
||||||
|
var token windows.Token
|
||||||
|
if err := windows.OpenProcessToken(windows.CurrentProcess(),
|
||||||
|
windows.TOKEN_QUERY, &token); err != nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
defer token.Close()
|
||||||
|
var id uint32
|
||||||
|
var ret uint32
|
||||||
|
_ = windows.GetTokenInformation(token, windows.TokenSessionId,
|
||||||
|
(*byte)(unsafe.Pointer(&id)), 4, &ret)
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
func getConsoleSessionID() uint32 {
|
||||||
|
r, _, _ := procWTSGetActiveConsoleSessionId.Call()
|
||||||
|
return uint32(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
wtsActive = 0
|
||||||
|
wtsConnected = 1
|
||||||
|
wtsDisconnected = 4
|
||||||
|
)
|
||||||
|
|
||||||
|
type wtsSessionInfo struct {
|
||||||
|
SessionID uint32
|
||||||
|
WinStationName [66]byte // actually *uint16, but we just need the struct size
|
||||||
|
State uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
// getActiveSessionID returns the session ID of the best session to attach to.
|
||||||
|
// Prefers an active (logged-in, interactive) session over the console session.
|
||||||
|
// This avoids kicking out an RDP user when the console is at the login screen.
|
||||||
|
func getActiveSessionID() uint32 {
|
||||||
|
var sessionInfo uintptr
|
||||||
|
var count uint32
|
||||||
|
|
||||||
|
r, _, _ := procWTSEnumerateSessionsW.Call(
|
||||||
|
0, // WTS_CURRENT_SERVER_HANDLE
|
||||||
|
0, // reserved
|
||||||
|
1, // version
|
||||||
|
uintptr(unsafe.Pointer(&sessionInfo)),
|
||||||
|
uintptr(unsafe.Pointer(&count)),
|
||||||
|
)
|
||||||
|
if r == 0 || count == 0 {
|
||||||
|
return getConsoleSessionID()
|
||||||
|
}
|
||||||
|
defer procWTSFreeMemory.Call(sessionInfo)
|
||||||
|
|
||||||
|
type wtsSession struct {
|
||||||
|
SessionID uint32
|
||||||
|
Station *uint16
|
||||||
|
State uint32
|
||||||
|
}
|
||||||
|
sessions := unsafe.Slice((*wtsSession)(unsafe.Pointer(sessionInfo)), count)
|
||||||
|
|
||||||
|
// Find the first active session (not session 0, which is the services session).
|
||||||
|
var bestID uint32
|
||||||
|
found := false
|
||||||
|
for _, s := range sessions {
|
||||||
|
if s.SessionID == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if s.State == wtsActive {
|
||||||
|
bestID = s.SessionID
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !found {
|
||||||
|
return getConsoleSessionID()
|
||||||
|
}
|
||||||
|
return bestID
|
||||||
|
}
|
||||||
|
|
||||||
|
// getSystemTokenForSession duplicates the current SYSTEM token and sets its
|
||||||
|
// session ID so the spawned process runs in the target session. Using a SYSTEM
|
||||||
|
// token gives access to both Default and Winlogon desktops plus UIPI bypass.
|
||||||
|
func getSystemTokenForSession(sessionID uint32) (windows.Token, error) {
|
||||||
|
var cur windows.Token
|
||||||
|
if err := windows.OpenProcessToken(windows.CurrentProcess(),
|
||||||
|
windows.MAXIMUM_ALLOWED, &cur); err != nil {
|
||||||
|
return 0, fmt.Errorf("OpenProcessToken: %w", err)
|
||||||
|
}
|
||||||
|
defer cur.Close()
|
||||||
|
|
||||||
|
var dup windows.Token
|
||||||
|
if err := windows.DuplicateTokenEx(cur, windows.MAXIMUM_ALLOWED, nil,
|
||||||
|
securityImpersonation, tokenPrimary, &dup); err != nil {
|
||||||
|
return 0, fmt.Errorf("DuplicateTokenEx: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sid := sessionID
|
||||||
|
r, _, err := procSetTokenInformation.Call(
|
||||||
|
uintptr(dup),
|
||||||
|
uintptr(tokenSessionID),
|
||||||
|
uintptr(unsafe.Pointer(&sid)),
|
||||||
|
unsafe.Sizeof(sid),
|
||||||
|
)
|
||||||
|
if r == 0 {
|
||||||
|
dup.Close()
|
||||||
|
return 0, fmt.Errorf("SetTokenInformation(SessionId=%d): %w", sessionID, err)
|
||||||
|
}
|
||||||
|
return dup, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const agentTokenEnvVar = "NB_VNC_AGENT_TOKEN"
|
||||||
|
|
||||||
|
// injectEnvVar appends a KEY=VALUE entry to a Unicode environment block.
|
||||||
|
// The block is a sequence of null-terminated UTF-16 strings, terminated by
|
||||||
|
// an extra null. Returns a new block pointer with the entry added.
|
||||||
|
func injectEnvVar(envBlock uintptr, key, value string) uintptr {
|
||||||
|
entry := key + "=" + value
|
||||||
|
|
||||||
|
// Walk the existing block to find its total length.
|
||||||
|
ptr := (*uint16)(unsafe.Pointer(envBlock))
|
||||||
|
var totalChars int
|
||||||
|
for {
|
||||||
|
ch := *(*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + uintptr(totalChars)*2))
|
||||||
|
if ch == 0 {
|
||||||
|
// Check for double-null terminator.
|
||||||
|
next := *(*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + uintptr(totalChars+1)*2))
|
||||||
|
totalChars++
|
||||||
|
if next == 0 {
|
||||||
|
// End of block (don't count the final null yet, we'll rebuild).
|
||||||
|
break
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
totalChars++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
entryUTF16, _ := windows.UTF16FromString(entry)
|
||||||
|
// New block: existing entries + new entry (null-terminated) + final null.
|
||||||
|
newLen := totalChars + len(entryUTF16) + 1
|
||||||
|
newBlock := make([]uint16, newLen)
|
||||||
|
// Copy existing entries (up to but not including the final null).
|
||||||
|
for i := range totalChars {
|
||||||
|
newBlock[i] = *(*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + uintptr(i)*2))
|
||||||
|
}
|
||||||
|
copy(newBlock[totalChars:], entryUTF16)
|
||||||
|
newBlock[newLen-1] = 0 // final null terminator
|
||||||
|
|
||||||
|
return uintptr(unsafe.Pointer(&newBlock[0]))
|
||||||
|
}
|
||||||
|
|
||||||
|
func spawnAgentInSession(sessionID uint32, port string, authToken string) (windows.Handle, error) {
|
||||||
|
token, err := getSystemTokenForSession(sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("get SYSTEM token for session %d: %w", sessionID, err)
|
||||||
|
}
|
||||||
|
defer token.Close()
|
||||||
|
|
||||||
|
var envBlock uintptr
|
||||||
|
r, _, _ := procCreateEnvironmentBlock.Call(
|
||||||
|
uintptr(unsafe.Pointer(&envBlock)),
|
||||||
|
uintptr(token),
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
if r != 0 {
|
||||||
|
defer procDestroyEnvironmentBlock.Call(envBlock)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Inject the auth token into the environment block so it doesn't appear
|
||||||
|
// in the process command line (visible via tasklist/wmic).
|
||||||
|
if r != 0 {
|
||||||
|
envBlock = injectEnvVar(envBlock, agentTokenEnvVar, authToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
exePath, err := os.Executable()
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("get executable path: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmdLine := fmt.Sprintf(`"%s" vnc-agent --port %s`, exePath, port)
|
||||||
|
cmdLineW, err := windows.UTF16PtrFromString(cmdLine)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("UTF16 cmdline: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create an inheritable pipe for the agent's stderr so we can relog
|
||||||
|
// its output in the service process.
|
||||||
|
var sa windows.SecurityAttributes
|
||||||
|
sa.Length = uint32(unsafe.Sizeof(sa))
|
||||||
|
sa.InheritHandle = 1
|
||||||
|
|
||||||
|
var stderrRead, stderrWrite windows.Handle
|
||||||
|
if err := windows.CreatePipe(&stderrRead, &stderrWrite, &sa, 0); err != nil {
|
||||||
|
return 0, fmt.Errorf("create stderr pipe: %w", err)
|
||||||
|
}
|
||||||
|
// The read end must NOT be inherited by the child.
|
||||||
|
windows.SetHandleInformation(stderrRead, windows.HANDLE_FLAG_INHERIT, 0)
|
||||||
|
|
||||||
|
desktop, _ := windows.UTF16PtrFromString(`WinSta0\Default`)
|
||||||
|
si := windows.StartupInfo{
|
||||||
|
Cb: uint32(unsafe.Sizeof(windows.StartupInfo{})),
|
||||||
|
Desktop: desktop,
|
||||||
|
Flags: windows.STARTF_USESHOWWINDOW | windows.STARTF_USESTDHANDLES,
|
||||||
|
ShowWindow: 0,
|
||||||
|
StdErr: stderrWrite,
|
||||||
|
StdOutput: stderrWrite,
|
||||||
|
}
|
||||||
|
var pi windows.ProcessInformation
|
||||||
|
|
||||||
|
var envPtr *uint16
|
||||||
|
if envBlock != 0 {
|
||||||
|
envPtr = (*uint16)(unsafe.Pointer(envBlock))
|
||||||
|
}
|
||||||
|
|
||||||
|
err = windows.CreateProcessAsUser(
|
||||||
|
token, nil, cmdLineW,
|
||||||
|
nil, nil, true, // inheritHandles=true for the pipe
|
||||||
|
createUnicodeEnvironment|createNoWindow,
|
||||||
|
envPtr, nil, &si, &pi,
|
||||||
|
)
|
||||||
|
// Close the write end in the parent so reads will get EOF when the child exits.
|
||||||
|
windows.CloseHandle(stderrWrite)
|
||||||
|
if err != nil {
|
||||||
|
windows.CloseHandle(stderrRead)
|
||||||
|
return 0, fmt.Errorf("CreateProcessAsUser: %w", err)
|
||||||
|
}
|
||||||
|
windows.CloseHandle(pi.Thread)
|
||||||
|
|
||||||
|
// Relog agent output in the service with a [vnc-agent] prefix.
|
||||||
|
go relogAgentOutput(stderrRead)
|
||||||
|
|
||||||
|
log.Infof("spawned agent PID=%d in session %d on port %s", pi.ProcessId, sessionID, port)
|
||||||
|
return pi.Process, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sessionManager monitors the active console session and ensures a VNC agent
|
||||||
|
// process is running in it. When the session changes (e.g., user switch, RDP
|
||||||
|
// connect/disconnect), it kills the old agent and spawns a new one.
|
||||||
|
type sessionManager struct {
|
||||||
|
port string
|
||||||
|
mu sync.Mutex
|
||||||
|
agentProc windows.Handle
|
||||||
|
sessionID uint32
|
||||||
|
authToken string
|
||||||
|
done chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSessionManager(port string) *sessionManager {
|
||||||
|
return &sessionManager{port: port, sessionID: ^uint32(0), done: make(chan struct{})}
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateAuthToken creates a new random hex token for agent authentication.
|
||||||
|
func generateAuthToken() string {
|
||||||
|
b := make([]byte, agentTokenLen)
|
||||||
|
if _, err := crand.Read(b); err != nil {
|
||||||
|
log.Warnf("generate agent auth token: %v", err)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return hex.EncodeToString(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthToken returns the current agent authentication token.
|
||||||
|
func (m *sessionManager) AuthToken() string {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
return m.authToken
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop signals the session manager to exit its polling loop.
|
||||||
|
func (m *sessionManager) Stop() {
|
||||||
|
select {
|
||||||
|
case <-m.done:
|
||||||
|
default:
|
||||||
|
close(m.done)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *sessionManager) run() {
|
||||||
|
ticker := time.NewTicker(2 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
sid := getActiveSessionID()
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
if sid != m.sessionID {
|
||||||
|
log.Infof("active session changed: %d -> %d", m.sessionID, sid)
|
||||||
|
m.killAgent()
|
||||||
|
m.sessionID = sid
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.agentProc != 0 {
|
||||||
|
var code uint32
|
||||||
|
_ = windows.GetExitCodeProcess(m.agentProc, &code)
|
||||||
|
if code != stillActive {
|
||||||
|
log.Infof("agent exited (code=%d), respawning", code)
|
||||||
|
windows.CloseHandle(m.agentProc)
|
||||||
|
m.agentProc = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.agentProc == 0 && sid != 0xFFFFFFFF {
|
||||||
|
m.authToken = generateAuthToken()
|
||||||
|
h, err := spawnAgentInSession(sid, m.port, m.authToken)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("spawn agent in session %d: %v", sid, err)
|
||||||
|
m.authToken = ""
|
||||||
|
} else {
|
||||||
|
m.agentProc = h
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-m.done:
|
||||||
|
m.mu.Lock()
|
||||||
|
m.killAgent()
|
||||||
|
m.mu.Unlock()
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *sessionManager) killAgent() {
|
||||||
|
if m.agentProc != 0 {
|
||||||
|
_ = windows.TerminateProcess(m.agentProc, 0)
|
||||||
|
windows.CloseHandle(m.agentProc)
|
||||||
|
m.agentProc = 0
|
||||||
|
log.Info("killed old agent")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// relogAgentOutput reads JSON log lines from the agent's stderr pipe and
|
||||||
|
// relogs them at the correct level with the service's formatter.
|
||||||
|
func relogAgentOutput(pipe windows.Handle) {
|
||||||
|
defer windows.CloseHandle(pipe)
|
||||||
|
f := os.NewFile(uintptr(pipe), "vnc-agent-stderr")
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
entry := log.WithField("component", "vnc-agent")
|
||||||
|
dec := json.NewDecoder(f)
|
||||||
|
for dec.More() {
|
||||||
|
var m map[string]any
|
||||||
|
if err := dec.Decode(&m); err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
msg, _ := m["msg"].(string)
|
||||||
|
if msg == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forward extra fields from the agent (skip standard logrus fields).
|
||||||
|
// Remap "caller" to "source" so it doesn't conflict with logrus internals
|
||||||
|
// but still shows the original file/line from the agent process.
|
||||||
|
fields := make(log.Fields)
|
||||||
|
for k, v := range m {
|
||||||
|
switch k {
|
||||||
|
case "msg", "level", "time", "func":
|
||||||
|
continue
|
||||||
|
case "caller":
|
||||||
|
fields["source"] = v
|
||||||
|
default:
|
||||||
|
fields[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
e := entry.WithFields(fields)
|
||||||
|
|
||||||
|
switch m["level"] {
|
||||||
|
case "error":
|
||||||
|
e.Error(msg)
|
||||||
|
case "warning":
|
||||||
|
e.Warn(msg)
|
||||||
|
case "debug":
|
||||||
|
e.Debug(msg)
|
||||||
|
case "trace":
|
||||||
|
e.Trace(msg)
|
||||||
|
default:
|
||||||
|
e.Info(msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// proxyToAgent connects to the agent, sends the auth token, then proxies
|
||||||
|
// the VNC client connection bidirectionally.
|
||||||
|
func proxyToAgent(client net.Conn, port string, authToken string) {
|
||||||
|
defer client.Close()
|
||||||
|
|
||||||
|
addr := "127.0.0.1:" + port
|
||||||
|
var agentConn net.Conn
|
||||||
|
var err error
|
||||||
|
for range 50 {
|
||||||
|
agentConn, err = net.DialTimeout("tcp", addr, time.Second)
|
||||||
|
if err == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("proxy cannot reach agent at %s: %v", addr, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer agentConn.Close()
|
||||||
|
|
||||||
|
// Send the auth token so the agent can verify this connection
|
||||||
|
// comes from the trusted service process.
|
||||||
|
tokenBytes, _ := hex.DecodeString(authToken)
|
||||||
|
if _, err := agentConn.Write(tokenBytes); err != nil {
|
||||||
|
log.Warnf("send auth token to agent: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("proxy connected to agent, starting bidirectional copy")
|
||||||
|
|
||||||
|
done := make(chan struct{}, 2)
|
||||||
|
cp := func(label string, dst, src net.Conn) {
|
||||||
|
n, err := io.Copy(dst, src)
|
||||||
|
log.Debugf("proxy %s: %d bytes, err=%v", label, n, err)
|
||||||
|
done <- struct{}{}
|
||||||
|
}
|
||||||
|
go cp("client→agent", agentConn, client)
|
||||||
|
go cp("agent→client", client, agentConn)
|
||||||
|
<-done
|
||||||
|
}
|
||||||
486
client/vnc/server/capture_darwin.go
Normal file
486
client/vnc/server/capture_darwin.go
Normal file
@@ -0,0 +1,486 @@
|
|||||||
|
//go:build darwin && !ios
|
||||||
|
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"hash/maphash"
|
||||||
|
"image"
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/ebitengine/purego"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
var darwinCaptureOnce sync.Once
|
||||||
|
|
||||||
|
var (
|
||||||
|
cgMainDisplayID func() uint32
|
||||||
|
cgDisplayPixelsWide func(uint32) uintptr
|
||||||
|
cgDisplayPixelsHigh func(uint32) uintptr
|
||||||
|
cgDisplayCreateImage func(uint32) uintptr
|
||||||
|
cgImageGetWidth func(uintptr) uintptr
|
||||||
|
cgImageGetHeight func(uintptr) uintptr
|
||||||
|
cgImageGetBytesPerRow func(uintptr) uintptr
|
||||||
|
cgImageGetBitsPerPixel func(uintptr) uintptr
|
||||||
|
cgImageGetDataProvider func(uintptr) uintptr
|
||||||
|
cgDataProviderCopyData func(uintptr) uintptr
|
||||||
|
cgImageRelease func(uintptr)
|
||||||
|
cfDataGetLength func(uintptr) int64
|
||||||
|
cfDataGetBytePtr func(uintptr) uintptr
|
||||||
|
cfRelease func(uintptr)
|
||||||
|
cgPreflightScreenCaptureAccess func() bool
|
||||||
|
cgRequestScreenCaptureAccess func() bool
|
||||||
|
darwinCaptureReady bool
|
||||||
|
)
|
||||||
|
|
||||||
|
func initDarwinCapture() {
|
||||||
|
darwinCaptureOnce.Do(func() {
|
||||||
|
cg, err := purego.Dlopen("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("load CoreGraphics: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cf, err := purego.Dlopen("/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("load CoreFoundation: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
purego.RegisterLibFunc(&cgMainDisplayID, cg, "CGMainDisplayID")
|
||||||
|
purego.RegisterLibFunc(&cgDisplayPixelsWide, cg, "CGDisplayPixelsWide")
|
||||||
|
purego.RegisterLibFunc(&cgDisplayPixelsHigh, cg, "CGDisplayPixelsHigh")
|
||||||
|
purego.RegisterLibFunc(&cgDisplayCreateImage, cg, "CGDisplayCreateImage")
|
||||||
|
purego.RegisterLibFunc(&cgImageGetWidth, cg, "CGImageGetWidth")
|
||||||
|
purego.RegisterLibFunc(&cgImageGetHeight, cg, "CGImageGetHeight")
|
||||||
|
purego.RegisterLibFunc(&cgImageGetBytesPerRow, cg, "CGImageGetBytesPerRow")
|
||||||
|
purego.RegisterLibFunc(&cgImageGetBitsPerPixel, cg, "CGImageGetBitsPerPixel")
|
||||||
|
purego.RegisterLibFunc(&cgImageGetDataProvider, cg, "CGImageGetDataProvider")
|
||||||
|
purego.RegisterLibFunc(&cgDataProviderCopyData, cg, "CGDataProviderCopyData")
|
||||||
|
purego.RegisterLibFunc(&cgImageRelease, cg, "CGImageRelease")
|
||||||
|
purego.RegisterLibFunc(&cfDataGetLength, cf, "CFDataGetLength")
|
||||||
|
purego.RegisterLibFunc(&cfDataGetBytePtr, cf, "CFDataGetBytePtr")
|
||||||
|
purego.RegisterLibFunc(&cfRelease, cf, "CFRelease")
|
||||||
|
|
||||||
|
// Screen capture permission APIs (macOS 11+). Might not exist on older versions.
|
||||||
|
if sym, err := purego.Dlsym(cg, "CGPreflightScreenCaptureAccess"); err == nil {
|
||||||
|
purego.RegisterFunc(&cgPreflightScreenCaptureAccess, sym)
|
||||||
|
}
|
||||||
|
if sym, err := purego.Dlsym(cg, "CGRequestScreenCaptureAccess"); err == nil {
|
||||||
|
purego.RegisterFunc(&cgRequestScreenCaptureAccess, sym)
|
||||||
|
}
|
||||||
|
|
||||||
|
darwinCaptureReady = true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// errFrameUnchanged signals that the raw capture bytes matched the previous
|
||||||
|
// frame, so the caller can skip the expensive BGRA to RGBA conversion.
|
||||||
|
var errFrameUnchanged = errors.New("frame unchanged")
|
||||||
|
|
||||||
|
// CGCapturer captures the macOS main display using Core Graphics.
|
||||||
|
type CGCapturer struct {
|
||||||
|
displayID uint32
|
||||||
|
w, h int
|
||||||
|
// downscale is 1 for pixel-perfect, 2 for Retina 2:1 box-filter downscale.
|
||||||
|
downscale int
|
||||||
|
hashSeed maphash.Seed
|
||||||
|
lastHash uint64
|
||||||
|
hasHash bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCGCapturer creates a screen capturer for the main display.
|
||||||
|
func NewCGCapturer() (*CGCapturer, error) {
|
||||||
|
initDarwinCapture()
|
||||||
|
if !darwinCaptureReady {
|
||||||
|
return nil, fmt.Errorf("CoreGraphics not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request Screen Recording permission (shows system dialog on macOS 11+).
|
||||||
|
if cgPreflightScreenCaptureAccess != nil && !cgPreflightScreenCaptureAccess() {
|
||||||
|
if cgRequestScreenCaptureAccess != nil {
|
||||||
|
cgRequestScreenCaptureAccess()
|
||||||
|
}
|
||||||
|
openPrivacyPane("Privacy_ScreenCapture")
|
||||||
|
log.Warn("Screen Recording permission not granted. " +
|
||||||
|
"Opened System Settings > Privacy & Security > Screen Recording; enable netbird and restart.")
|
||||||
|
}
|
||||||
|
|
||||||
|
displayID := cgMainDisplayID()
|
||||||
|
c := &CGCapturer{displayID: displayID, downscale: 1, hashSeed: maphash.MakeSeed()}
|
||||||
|
|
||||||
|
// Probe actual pixel dimensions via a test capture. CGDisplayPixelsWide/High
|
||||||
|
// returns logical points on Retina, but CGDisplayCreateImage produces native
|
||||||
|
// pixels (often 2x), so probing the image is the only reliable source.
|
||||||
|
img, err := c.Capture()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("probe capture: %w", err)
|
||||||
|
}
|
||||||
|
nativeW := img.Rect.Dx()
|
||||||
|
nativeH := img.Rect.Dy()
|
||||||
|
c.hasHash = false
|
||||||
|
if nativeW == 0 || nativeH == 0 {
|
||||||
|
return nil, errors.New("display dimensions are zero")
|
||||||
|
}
|
||||||
|
|
||||||
|
logicalW := int(cgDisplayPixelsWide(displayID))
|
||||||
|
logicalH := int(cgDisplayPixelsHigh(displayID))
|
||||||
|
|
||||||
|
// Enable 2:1 downscale on Retina unless explicitly disabled. Cuts pixel
|
||||||
|
// count 4x, shrinking convert, diff, and wire data proportionally.
|
||||||
|
if !retinaDownscaleDisabled() && nativeW >= 2*logicalW && nativeH >= 2*logicalH && nativeW%2 == 0 && nativeH%2 == 0 {
|
||||||
|
c.downscale = 2
|
||||||
|
}
|
||||||
|
c.w = nativeW / c.downscale
|
||||||
|
c.h = nativeH / c.downscale
|
||||||
|
|
||||||
|
log.Infof("macOS capturer ready: %dx%d (native %dx%d, logical %dx%d, downscale=%d, display=%d)",
|
||||||
|
c.w, c.h, nativeW, nativeH, logicalW, logicalH, c.downscale, displayID)
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func retinaDownscaleDisabled() bool {
|
||||||
|
v := os.Getenv(EnvVNCDisableDownscale)
|
||||||
|
if v == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
disabled, err := strconv.ParseBool(v)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("parse %s: %v", EnvVNCDisableDownscale, err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return disabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// Width returns the screen width.
|
||||||
|
func (c *CGCapturer) Width() int { return c.w }
|
||||||
|
|
||||||
|
// Height returns the screen height.
|
||||||
|
func (c *CGCapturer) Height() int { return c.h }
|
||||||
|
|
||||||
|
// Capture returns the current screen as an RGBA image.
|
||||||
|
func (c *CGCapturer) Capture() (*image.RGBA, error) {
|
||||||
|
cgImage := cgDisplayCreateImage(c.displayID)
|
||||||
|
if cgImage == 0 {
|
||||||
|
return nil, fmt.Errorf("CGDisplayCreateImage returned nil (screen recording permission?)")
|
||||||
|
}
|
||||||
|
defer cgImageRelease(cgImage)
|
||||||
|
|
||||||
|
w := int(cgImageGetWidth(cgImage))
|
||||||
|
h := int(cgImageGetHeight(cgImage))
|
||||||
|
bytesPerRow := int(cgImageGetBytesPerRow(cgImage))
|
||||||
|
bpp := int(cgImageGetBitsPerPixel(cgImage))
|
||||||
|
|
||||||
|
provider := cgImageGetDataProvider(cgImage)
|
||||||
|
if provider == 0 {
|
||||||
|
return nil, fmt.Errorf("CGImageGetDataProvider returned nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
cfData := cgDataProviderCopyData(provider)
|
||||||
|
if cfData == 0 {
|
||||||
|
return nil, fmt.Errorf("CGDataProviderCopyData returned nil")
|
||||||
|
}
|
||||||
|
defer cfRelease(cfData)
|
||||||
|
|
||||||
|
dataLen := int(cfDataGetLength(cfData))
|
||||||
|
dataPtr := cfDataGetBytePtr(cfData)
|
||||||
|
if dataPtr == 0 || dataLen == 0 {
|
||||||
|
return nil, fmt.Errorf("empty image data")
|
||||||
|
}
|
||||||
|
|
||||||
|
src := unsafe.Slice((*byte)(unsafe.Pointer(dataPtr)), dataLen)
|
||||||
|
|
||||||
|
hash := maphash.Bytes(c.hashSeed, src)
|
||||||
|
if c.hasHash && hash == c.lastHash {
|
||||||
|
return nil, errFrameUnchanged
|
||||||
|
}
|
||||||
|
c.lastHash = hash
|
||||||
|
c.hasHash = true
|
||||||
|
|
||||||
|
ds := c.downscale
|
||||||
|
if ds < 1 {
|
||||||
|
ds = 1
|
||||||
|
}
|
||||||
|
outW := w / ds
|
||||||
|
outH := h / ds
|
||||||
|
img := image.NewRGBA(image.Rect(0, 0, outW, outH))
|
||||||
|
|
||||||
|
bytesPerPixel := bpp / 8
|
||||||
|
if bytesPerPixel == 4 && ds == 1 {
|
||||||
|
convertBGRAToRGBA(img.Pix, img.Stride, src, bytesPerRow, w, h)
|
||||||
|
} else if bytesPerPixel == 4 && ds == 2 {
|
||||||
|
convertBGRAToRGBADownscale2(img.Pix, img.Stride, src, bytesPerRow, outW, outH)
|
||||||
|
} else {
|
||||||
|
for row := 0; row < outH; row++ {
|
||||||
|
srcOff := row * ds * bytesPerRow
|
||||||
|
dstOff := row * img.Stride
|
||||||
|
for col := 0; col < outW; col++ {
|
||||||
|
si := srcOff + col*ds*bytesPerPixel
|
||||||
|
di := dstOff + col*4
|
||||||
|
img.Pix[di+0] = src[si+2]
|
||||||
|
img.Pix[di+1] = src[si+1]
|
||||||
|
img.Pix[di+2] = src[si+0]
|
||||||
|
img.Pix[di+3] = 0xff
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return img, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertBGRAToRGBADownscale2 averages every 2x2 BGRA block into one RGBA
|
||||||
|
// output pixel, parallelised across GOMAXPROCS cores. outW and outH are the
|
||||||
|
// destination dimensions (source is 2*outW by 2*outH).
|
||||||
|
func convertBGRAToRGBADownscale2(dst []byte, dstStride int, src []byte, srcStride, outW, outH int) {
|
||||||
|
workers := runtime.GOMAXPROCS(0)
|
||||||
|
if workers > outH {
|
||||||
|
workers = outH
|
||||||
|
}
|
||||||
|
if workers < 1 || outH < 32 {
|
||||||
|
workers = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
convertRows := func(y0, y1 int) {
|
||||||
|
for row := y0; row < y1; row++ {
|
||||||
|
srcRow0 := 2 * row * srcStride
|
||||||
|
srcRow1 := srcRow0 + srcStride
|
||||||
|
dstOff := row * dstStride
|
||||||
|
for col := 0; col < outW; col++ {
|
||||||
|
s0 := srcRow0 + col*8
|
||||||
|
s1 := srcRow1 + col*8
|
||||||
|
b := (uint32(src[s0]) + uint32(src[s0+4]) + uint32(src[s1]) + uint32(src[s1+4])) >> 2
|
||||||
|
g := (uint32(src[s0+1]) + uint32(src[s0+5]) + uint32(src[s1+1]) + uint32(src[s1+5])) >> 2
|
||||||
|
r := (uint32(src[s0+2]) + uint32(src[s0+6]) + uint32(src[s1+2]) + uint32(src[s1+6])) >> 2
|
||||||
|
di := dstOff + col*4
|
||||||
|
dst[di+0] = byte(r)
|
||||||
|
dst[di+1] = byte(g)
|
||||||
|
dst[di+2] = byte(b)
|
||||||
|
dst[di+3] = 0xff
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if workers == 1 {
|
||||||
|
convertRows(0, outH)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
chunk := (outH + workers - 1) / workers
|
||||||
|
for i := 0; i < workers; i++ {
|
||||||
|
y0 := i * chunk
|
||||||
|
y1 := y0 + chunk
|
||||||
|
if y1 > outH {
|
||||||
|
y1 = outH
|
||||||
|
}
|
||||||
|
if y0 >= y1 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
wg.Add(1)
|
||||||
|
go func(y0, y1 int) {
|
||||||
|
defer wg.Done()
|
||||||
|
convertRows(y0, y1)
|
||||||
|
}(y0, y1)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertBGRAToRGBA swaps R/B channels using uint32 word operations, and
|
||||||
|
// parallelises across GOMAXPROCS cores for large images.
|
||||||
|
func convertBGRAToRGBA(dst []byte, dstStride int, src []byte, srcStride, w, h int) {
|
||||||
|
workers := runtime.GOMAXPROCS(0)
|
||||||
|
if workers > h {
|
||||||
|
workers = h
|
||||||
|
}
|
||||||
|
if workers < 1 || h < 64 {
|
||||||
|
workers = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
convertRows := func(y0, y1 int) {
|
||||||
|
rowBytes := w * 4
|
||||||
|
for row := y0; row < y1; row++ {
|
||||||
|
dstRow := dst[row*dstStride : row*dstStride+rowBytes]
|
||||||
|
srcRow := src[row*srcStride : row*srcStride+rowBytes]
|
||||||
|
dstU := unsafe.Slice((*uint32)(unsafe.Pointer(&dstRow[0])), w)
|
||||||
|
srcU := unsafe.Slice((*uint32)(unsafe.Pointer(&srcRow[0])), w)
|
||||||
|
for i, p := range srcU {
|
||||||
|
dstU[i] = (p & 0xff00ff00) | ((p & 0x000000ff) << 16) | ((p & 0x00ff0000) >> 16) | 0xff000000
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if workers == 1 {
|
||||||
|
convertRows(0, h)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
chunk := (h + workers - 1) / workers
|
||||||
|
for i := 0; i < workers; i++ {
|
||||||
|
y0 := i * chunk
|
||||||
|
y1 := y0 + chunk
|
||||||
|
if y1 > h {
|
||||||
|
y1 = h
|
||||||
|
}
|
||||||
|
if y0 >= y1 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
wg.Add(1)
|
||||||
|
go func(y0, y1 int) {
|
||||||
|
defer wg.Done()
|
||||||
|
convertRows(y0, y1)
|
||||||
|
}(y0, y1)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// MacPoller wraps CGCapturer in a continuous capture loop.
|
||||||
|
type MacPoller struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
frame *image.RGBA
|
||||||
|
w, h int
|
||||||
|
done chan struct{}
|
||||||
|
// wake shortens the init-retry backoff when a client is trying to connect,
|
||||||
|
// so granting Screen Recording mid-session takes effect immediately.
|
||||||
|
wake chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMacPoller creates a capturer that continuously grabs the macOS display.
|
||||||
|
func NewMacPoller() *MacPoller {
|
||||||
|
p := &MacPoller{
|
||||||
|
done: make(chan struct{}),
|
||||||
|
wake: make(chan struct{}, 1),
|
||||||
|
}
|
||||||
|
go p.loop()
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wake pokes the init-retry loop so it doesn't wait out the full backoff
|
||||||
|
// before trying again. Safe to call from any goroutine; extra calls while a
|
||||||
|
// wake is pending are dropped.
|
||||||
|
func (p *MacPoller) Wake() {
|
||||||
|
select {
|
||||||
|
case p.wake <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close stops the capture loop.
|
||||||
|
func (p *MacPoller) Close() {
|
||||||
|
select {
|
||||||
|
case <-p.done:
|
||||||
|
default:
|
||||||
|
close(p.done)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Width returns the screen width.
|
||||||
|
func (p *MacPoller) Width() int {
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
return p.w
|
||||||
|
}
|
||||||
|
|
||||||
|
// Height returns the screen height.
|
||||||
|
func (p *MacPoller) Height() int {
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
return p.h
|
||||||
|
}
|
||||||
|
|
||||||
|
// Capture returns the most recent frame.
|
||||||
|
func (p *MacPoller) Capture() (*image.RGBA, error) {
|
||||||
|
p.mu.Lock()
|
||||||
|
img := p.frame
|
||||||
|
p.mu.Unlock()
|
||||||
|
if img != nil {
|
||||||
|
return img, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("no frame available yet")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *MacPoller) loop() {
|
||||||
|
var capturer *CGCapturer
|
||||||
|
var initFails int
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-p.done:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
if capturer == nil {
|
||||||
|
var err error
|
||||||
|
capturer, err = NewCGCapturer()
|
||||||
|
if err != nil {
|
||||||
|
initFails++
|
||||||
|
// Retry forever with backoff: the user may grant Screen
|
||||||
|
// Recording after the server started, and we need to pick it
|
||||||
|
// up whenever that happens.
|
||||||
|
delay := 2 * time.Second
|
||||||
|
if initFails > 15 {
|
||||||
|
delay = 30 * time.Second
|
||||||
|
} else if initFails > 5 {
|
||||||
|
delay = 10 * time.Second
|
||||||
|
}
|
||||||
|
if initFails == 1 || initFails%10 == 0 {
|
||||||
|
log.Warnf("macOS capturer: %v (attempt %d, retrying every %s)", err, initFails, delay)
|
||||||
|
} else {
|
||||||
|
log.Debugf("macOS capturer: %v (attempt %d)", err, initFails)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-p.done:
|
||||||
|
return
|
||||||
|
case <-p.wake:
|
||||||
|
// Client is trying to connect, retry now.
|
||||||
|
case <-time.After(delay):
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
initFails = 0
|
||||||
|
p.mu.Lock()
|
||||||
|
p.w, p.h = capturer.Width(), capturer.Height()
|
||||||
|
p.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
img, err := capturer.Capture()
|
||||||
|
if errors.Is(err, errFrameUnchanged) {
|
||||||
|
select {
|
||||||
|
case <-p.done:
|
||||||
|
return
|
||||||
|
case <-time.After(33 * time.Millisecond):
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("macOS capture: %v", err)
|
||||||
|
capturer = nil
|
||||||
|
select {
|
||||||
|
case <-p.done:
|
||||||
|
return
|
||||||
|
case <-time.After(500 * time.Millisecond):
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
p.mu.Lock()
|
||||||
|
p.frame = img
|
||||||
|
p.mu.Unlock()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-p.done:
|
||||||
|
return
|
||||||
|
case <-time.After(33 * time.Millisecond): // ~30 fps
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ ScreenCapturer = (*MacPoller)(nil)
|
||||||
99
client/vnc/server/capture_dxgi_windows.go
Normal file
99
client/vnc/server/capture_dxgi_windows.go
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"image"
|
||||||
|
|
||||||
|
"github.com/kirides/go-d3d/d3d11"
|
||||||
|
"github.com/kirides/go-d3d/outputduplication"
|
||||||
|
)
|
||||||
|
|
||||||
|
// dxgiCapturer captures the desktop using DXGI Desktop Duplication.
|
||||||
|
// Provides GPU-accelerated capture with native dirty rect tracking.
|
||||||
|
// Only works from the interactive user session, not Session 0.
|
||||||
|
//
|
||||||
|
// Uses a double-buffer: DXGI writes into img, then we copy to the current
|
||||||
|
// output buffer and hand it out. Alternating between two output buffers
|
||||||
|
// avoids allocating a new image.RGBA per frame (~8MB at 1080p, 30fps).
|
||||||
|
type dxgiCapturer struct {
|
||||||
|
dup *outputduplication.OutputDuplicator
|
||||||
|
device *d3d11.ID3D11Device
|
||||||
|
ctx *d3d11.ID3D11DeviceContext
|
||||||
|
img *image.RGBA
|
||||||
|
out [2]*image.RGBA
|
||||||
|
outIdx int
|
||||||
|
width int
|
||||||
|
height int
|
||||||
|
}
|
||||||
|
|
||||||
|
func newDXGICapturer() (*dxgiCapturer, error) {
|
||||||
|
device, deviceCtx, err := d3d11.NewD3D11Device()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create D3D11 device: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
dup, err := outputduplication.NewIDXGIOutputDuplication(device, deviceCtx, 0)
|
||||||
|
if err != nil {
|
||||||
|
device.Release()
|
||||||
|
deviceCtx.Release()
|
||||||
|
return nil, fmt.Errorf("create output duplication: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
w, h := screenSize()
|
||||||
|
if w == 0 || h == 0 {
|
||||||
|
dup.Release()
|
||||||
|
device.Release()
|
||||||
|
deviceCtx.Release()
|
||||||
|
return nil, fmt.Errorf("screen dimensions are zero")
|
||||||
|
}
|
||||||
|
|
||||||
|
rect := image.Rect(0, 0, w, h)
|
||||||
|
c := &dxgiCapturer{
|
||||||
|
dup: dup,
|
||||||
|
device: device,
|
||||||
|
ctx: deviceCtx,
|
||||||
|
img: image.NewRGBA(rect),
|
||||||
|
out: [2]*image.RGBA{image.NewRGBA(rect), image.NewRGBA(rect)},
|
||||||
|
width: w,
|
||||||
|
height: h,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Grab the initial frame with a longer timeout to ensure we have
|
||||||
|
// a valid image before returning.
|
||||||
|
_ = dup.GetImage(c.img, 2000)
|
||||||
|
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *dxgiCapturer) capture() (*image.RGBA, error) {
|
||||||
|
err := c.dup.GetImage(c.img, 100)
|
||||||
|
if err != nil && !errors.Is(err, outputduplication.ErrNoImageYet) {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy into the next output buffer. The DesktopCapturer hands out the
|
||||||
|
// returned pointer to VNC sessions that read pixels concurrently, so we
|
||||||
|
// alternate between two pre-allocated buffers instead of allocating per frame.
|
||||||
|
out := c.out[c.outIdx]
|
||||||
|
c.outIdx ^= 1
|
||||||
|
copy(out.Pix, c.img.Pix)
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *dxgiCapturer) close() {
|
||||||
|
if c.dup != nil {
|
||||||
|
c.dup.Release()
|
||||||
|
c.dup = nil
|
||||||
|
}
|
||||||
|
if c.ctx != nil {
|
||||||
|
c.ctx.Release()
|
||||||
|
c.ctx = nil
|
||||||
|
}
|
||||||
|
if c.device != nil {
|
||||||
|
c.device.Release()
|
||||||
|
c.device = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
461
client/vnc/server/capture_windows.go
Normal file
461
client/vnc/server/capture_windows.go
Normal file
@@ -0,0 +1,461 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"image"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
gdi32 = windows.NewLazySystemDLL("gdi32.dll")
|
||||||
|
user32 = windows.NewLazySystemDLL("user32.dll")
|
||||||
|
|
||||||
|
procGetDC = user32.NewProc("GetDC")
|
||||||
|
procReleaseDC = user32.NewProc("ReleaseDC")
|
||||||
|
procCreateCompatDC = gdi32.NewProc("CreateCompatibleDC")
|
||||||
|
procCreateDIBSection = gdi32.NewProc("CreateDIBSection")
|
||||||
|
procSelectObject = gdi32.NewProc("SelectObject")
|
||||||
|
procDeleteObject = gdi32.NewProc("DeleteObject")
|
||||||
|
procDeleteDC = gdi32.NewProc("DeleteDC")
|
||||||
|
procBitBlt = gdi32.NewProc("BitBlt")
|
||||||
|
procGetSystemMetrics = user32.NewProc("GetSystemMetrics")
|
||||||
|
|
||||||
|
// Desktop switching for service/Session 0 capture.
|
||||||
|
procOpenInputDesktop = user32.NewProc("OpenInputDesktop")
|
||||||
|
procSetThreadDesktop = user32.NewProc("SetThreadDesktop")
|
||||||
|
procCloseDesktop = user32.NewProc("CloseDesktop")
|
||||||
|
procOpenWindowStation = user32.NewProc("OpenWindowStationW")
|
||||||
|
procSetProcessWindowStation = user32.NewProc("SetProcessWindowStation")
|
||||||
|
procCloseWindowStation = user32.NewProc("CloseWindowStation")
|
||||||
|
procGetUserObjectInformationW = user32.NewProc("GetUserObjectInformationW")
|
||||||
|
)
|
||||||
|
|
||||||
|
const uoiName = 2
|
||||||
|
|
||||||
|
const (
|
||||||
|
smCxScreen = 0
|
||||||
|
smCyScreen = 1
|
||||||
|
srccopy = 0x00CC0020
|
||||||
|
dibRgbColors = 0
|
||||||
|
)
|
||||||
|
|
||||||
|
type bitmapInfoHeader struct {
|
||||||
|
Size uint32
|
||||||
|
Width int32
|
||||||
|
Height int32
|
||||||
|
Planes uint16
|
||||||
|
BitCount uint16
|
||||||
|
Compression uint32
|
||||||
|
SizeImage uint32
|
||||||
|
XPelsPerMeter int32
|
||||||
|
YPelsPerMeter int32
|
||||||
|
ClrUsed uint32
|
||||||
|
ClrImportant uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
type bitmapInfo struct {
|
||||||
|
Header bitmapInfoHeader
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupInteractiveWindowStation associates the current process with WinSta0,
|
||||||
|
// the interactive window station. This is required for a SYSTEM service in
|
||||||
|
// Session 0 to call OpenInputDesktop for screen capture and input injection.
|
||||||
|
func setupInteractiveWindowStation() error {
|
||||||
|
name, err := windows.UTF16PtrFromString("WinSta0")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("UTF16 WinSta0: %w", err)
|
||||||
|
}
|
||||||
|
hWinSta, _, err := procOpenWindowStation.Call(
|
||||||
|
uintptr(unsafe.Pointer(name)),
|
||||||
|
0,
|
||||||
|
uintptr(windows.MAXIMUM_ALLOWED),
|
||||||
|
)
|
||||||
|
if hWinSta == 0 {
|
||||||
|
return fmt.Errorf("OpenWindowStation(WinSta0): %w", err)
|
||||||
|
}
|
||||||
|
r, _, err := procSetProcessWindowStation.Call(hWinSta)
|
||||||
|
if r == 0 {
|
||||||
|
procCloseWindowStation.Call(hWinSta)
|
||||||
|
return fmt.Errorf("SetProcessWindowStation: %w", err)
|
||||||
|
}
|
||||||
|
log.Info("process window station set to WinSta0 (interactive)")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func screenSize() (int, int) {
|
||||||
|
w, _, _ := procGetSystemMetrics.Call(uintptr(smCxScreen))
|
||||||
|
h, _, _ := procGetSystemMetrics.Call(uintptr(smCyScreen))
|
||||||
|
return int(w), int(h)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getDesktopName(hDesk uintptr) string {
|
||||||
|
var buf [256]uint16
|
||||||
|
var needed uint32
|
||||||
|
procGetUserObjectInformationW.Call(hDesk, uoiName,
|
||||||
|
uintptr(unsafe.Pointer(&buf[0])), 512,
|
||||||
|
uintptr(unsafe.Pointer(&needed)))
|
||||||
|
return windows.UTF16ToString(buf[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// switchToInputDesktop opens the desktop currently receiving user input
|
||||||
|
// and sets it as the calling OS thread's desktop. Must be called from a
|
||||||
|
// goroutine locked to its OS thread via runtime.LockOSThread().
|
||||||
|
func switchToInputDesktop() (bool, string) {
|
||||||
|
hDesk, _, _ := procOpenInputDesktop.Call(0, 0, uintptr(windows.MAXIMUM_ALLOWED))
|
||||||
|
if hDesk == 0 {
|
||||||
|
return false, ""
|
||||||
|
}
|
||||||
|
name := getDesktopName(hDesk)
|
||||||
|
ret, _, _ := procSetThreadDesktop.Call(hDesk)
|
||||||
|
procCloseDesktop.Call(hDesk)
|
||||||
|
return ret != 0, name
|
||||||
|
}
|
||||||
|
|
||||||
|
// gdiCapturer captures the desktop screen using GDI BitBlt.
|
||||||
|
// GDI objects (DC, DIBSection) are allocated once and reused across frames.
|
||||||
|
type gdiCapturer struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
width int
|
||||||
|
height int
|
||||||
|
|
||||||
|
// Pre-allocated GDI resources, reused across captures.
|
||||||
|
memDC uintptr
|
||||||
|
bmp uintptr
|
||||||
|
bits uintptr
|
||||||
|
}
|
||||||
|
|
||||||
|
func newGDICapturer() (*gdiCapturer, error) {
|
||||||
|
w, h := screenSize()
|
||||||
|
if w == 0 || h == 0 {
|
||||||
|
return nil, fmt.Errorf("screen dimensions are zero")
|
||||||
|
}
|
||||||
|
c := &gdiCapturer{width: w, height: h}
|
||||||
|
if err := c.allocGDI(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// allocGDI pre-allocates the compatible DC and DIB section for reuse.
|
||||||
|
func (c *gdiCapturer) allocGDI() error {
|
||||||
|
screenDC, _, _ := procGetDC.Call(0)
|
||||||
|
if screenDC == 0 {
|
||||||
|
return fmt.Errorf("GetDC returned 0")
|
||||||
|
}
|
||||||
|
defer procReleaseDC.Call(0, screenDC)
|
||||||
|
|
||||||
|
memDC, _, _ := procCreateCompatDC.Call(screenDC)
|
||||||
|
if memDC == 0 {
|
||||||
|
return fmt.Errorf("CreateCompatibleDC returned 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
bi := bitmapInfo{
|
||||||
|
Header: bitmapInfoHeader{
|
||||||
|
Size: uint32(unsafe.Sizeof(bitmapInfoHeader{})),
|
||||||
|
Width: int32(c.width),
|
||||||
|
Height: -int32(c.height), // negative = top-down DIB
|
||||||
|
Planes: 1,
|
||||||
|
BitCount: 32,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var bits uintptr
|
||||||
|
bmp, _, _ := procCreateDIBSection.Call(
|
||||||
|
screenDC,
|
||||||
|
uintptr(unsafe.Pointer(&bi)),
|
||||||
|
dibRgbColors,
|
||||||
|
uintptr(unsafe.Pointer(&bits)),
|
||||||
|
0, 0,
|
||||||
|
)
|
||||||
|
if bmp == 0 || bits == 0 {
|
||||||
|
procDeleteDC.Call(memDC)
|
||||||
|
return fmt.Errorf("CreateDIBSection returned 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
procSelectObject.Call(memDC, bmp)
|
||||||
|
|
||||||
|
c.memDC = memDC
|
||||||
|
c.bmp = bmp
|
||||||
|
c.bits = bits
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *gdiCapturer) close() { c.freeGDI() }
|
||||||
|
|
||||||
|
// freeGDI releases pre-allocated GDI resources.
|
||||||
|
func (c *gdiCapturer) freeGDI() {
|
||||||
|
if c.bmp != 0 {
|
||||||
|
procDeleteObject.Call(c.bmp)
|
||||||
|
c.bmp = 0
|
||||||
|
}
|
||||||
|
if c.memDC != 0 {
|
||||||
|
procDeleteDC.Call(c.memDC)
|
||||||
|
c.memDC = 0
|
||||||
|
}
|
||||||
|
c.bits = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *gdiCapturer) capture() (*image.RGBA, error) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
if c.memDC == 0 {
|
||||||
|
return nil, fmt.Errorf("GDI resources not allocated")
|
||||||
|
}
|
||||||
|
|
||||||
|
screenDC, _, _ := procGetDC.Call(0)
|
||||||
|
if screenDC == 0 {
|
||||||
|
return nil, fmt.Errorf("GetDC returned 0")
|
||||||
|
}
|
||||||
|
defer procReleaseDC.Call(0, screenDC)
|
||||||
|
|
||||||
|
ret, _, _ := procBitBlt.Call(c.memDC, 0, 0, uintptr(c.width), uintptr(c.height),
|
||||||
|
screenDC, 0, 0, srccopy)
|
||||||
|
if ret == 0 {
|
||||||
|
return nil, fmt.Errorf("BitBlt returned 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
n := c.width * c.height * 4
|
||||||
|
raw := unsafe.Slice((*byte)(unsafe.Pointer(c.bits)), n)
|
||||||
|
|
||||||
|
// GDI gives BGRA, the RFB encoder expects RGBA (img.Pix layout).
|
||||||
|
// Swap R and B in bulk using uint32 operations (one load + mask + shift
|
||||||
|
// per pixel instead of three separate byte assignments).
|
||||||
|
img := image.NewRGBA(image.Rect(0, 0, c.width, c.height))
|
||||||
|
pix := img.Pix
|
||||||
|
copy(pix, raw)
|
||||||
|
swizzleBGRAtoRGBA(pix)
|
||||||
|
return img, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DesktopCapturer captures the interactive desktop, handling desktop transitions
|
||||||
|
// (login screen, UAC prompts). A dedicated OS-locked goroutine continuously
|
||||||
|
// captures frames, which are retrieved by the VNC session on demand.
|
||||||
|
// Capture pauses automatically when no clients are connected.
|
||||||
|
type DesktopCapturer struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
frame *image.RGBA
|
||||||
|
w, h int
|
||||||
|
|
||||||
|
// clients tracks the number of active VNC sessions. When zero, the
|
||||||
|
// capture loop idles instead of grabbing frames.
|
||||||
|
clients atomic.Int32
|
||||||
|
|
||||||
|
// wake is signaled when a client connects and the loop should resume.
|
||||||
|
wake chan struct{}
|
||||||
|
// done is closed when Close is called, terminating the capture loop.
|
||||||
|
done chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDesktopCapturer creates a capturer that continuously grabs the active desktop.
|
||||||
|
func NewDesktopCapturer() *DesktopCapturer {
|
||||||
|
c := &DesktopCapturer{
|
||||||
|
wake: make(chan struct{}, 1),
|
||||||
|
done: make(chan struct{}),
|
||||||
|
}
|
||||||
|
go c.loop()
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClientConnect increments the active client count, resuming capture if needed.
|
||||||
|
func (c *DesktopCapturer) ClientConnect() {
|
||||||
|
c.clients.Add(1)
|
||||||
|
select {
|
||||||
|
case c.wake <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClientDisconnect decrements the active client count.
|
||||||
|
func (c *DesktopCapturer) ClientDisconnect() {
|
||||||
|
c.clients.Add(-1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close stops the capture loop and releases resources.
|
||||||
|
func (c *DesktopCapturer) Close() {
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
default:
|
||||||
|
close(c.done)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Width returns the current screen width.
|
||||||
|
func (c *DesktopCapturer) Width() int {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
return c.w
|
||||||
|
}
|
||||||
|
|
||||||
|
// Height returns the current screen height.
|
||||||
|
func (c *DesktopCapturer) Height() int {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
return c.h
|
||||||
|
}
|
||||||
|
|
||||||
|
// Capture returns the most recent desktop frame.
|
||||||
|
func (c *DesktopCapturer) Capture() (*image.RGBA, error) {
|
||||||
|
c.mu.Lock()
|
||||||
|
img := c.frame
|
||||||
|
c.mu.Unlock()
|
||||||
|
if img != nil {
|
||||||
|
return img, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("no frame available yet")
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitForClient blocks until a client connects or the capturer is closed.
|
||||||
|
func (c *DesktopCapturer) waitForClient() bool {
|
||||||
|
if c.clients.Load() > 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-c.wake:
|
||||||
|
return true
|
||||||
|
case <-c.done:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *DesktopCapturer) loop() {
|
||||||
|
runtime.LockOSThread()
|
||||||
|
|
||||||
|
// When running as a Windows service (Session 0), we need to attach to the
|
||||||
|
// interactive window station before OpenInputDesktop will succeed.
|
||||||
|
if err := setupInteractiveWindowStation(); err != nil {
|
||||||
|
log.Warnf("attach to interactive window station: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
frameTicker := time.NewTicker(33 * time.Millisecond) // ~30 fps
|
||||||
|
defer frameTicker.Stop()
|
||||||
|
|
||||||
|
retryTimer := time.NewTimer(0)
|
||||||
|
retryTimer.Stop()
|
||||||
|
defer retryTimer.Stop()
|
||||||
|
|
||||||
|
type frameCapturer interface {
|
||||||
|
capture() (*image.RGBA, error)
|
||||||
|
close()
|
||||||
|
}
|
||||||
|
|
||||||
|
var cap frameCapturer
|
||||||
|
var desktopFails int
|
||||||
|
var lastDesktop string
|
||||||
|
|
||||||
|
createCapturer := func() (frameCapturer, error) {
|
||||||
|
dc, err := newDXGICapturer()
|
||||||
|
if err == nil {
|
||||||
|
log.Info("using DXGI Desktop Duplication for capture")
|
||||||
|
return dc, nil
|
||||||
|
}
|
||||||
|
log.Debugf("DXGI unavailable (%v), falling back to GDI", err)
|
||||||
|
gc, err := newGDICapturer()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
log.Info("using GDI BitBlt for capture")
|
||||||
|
return gc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
if !c.waitForClient() {
|
||||||
|
if cap != nil {
|
||||||
|
cap.close()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// No clients: release the capturer and wait.
|
||||||
|
if c.clients.Load() <= 0 {
|
||||||
|
if cap != nil {
|
||||||
|
cap.close()
|
||||||
|
cap = nil
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, desk := switchToInputDesktop()
|
||||||
|
if !ok {
|
||||||
|
desktopFails++
|
||||||
|
if desktopFails == 1 || desktopFails%100 == 0 {
|
||||||
|
log.Warnf("switchToInputDesktop failed (count=%d), no interactive desktop session?", desktopFails)
|
||||||
|
}
|
||||||
|
retryTimer.Reset(100 * time.Millisecond)
|
||||||
|
select {
|
||||||
|
case <-retryTimer.C:
|
||||||
|
case <-c.done:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if desktopFails > 0 {
|
||||||
|
log.Infof("switchToInputDesktop recovered after %d failures, desktop=%q", desktopFails, desk)
|
||||||
|
desktopFails = 0
|
||||||
|
}
|
||||||
|
if desk != lastDesktop {
|
||||||
|
log.Infof("desktop changed: %q -> %q", lastDesktop, desk)
|
||||||
|
lastDesktop = desk
|
||||||
|
if cap != nil {
|
||||||
|
cap.close()
|
||||||
|
}
|
||||||
|
cap = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if cap == nil {
|
||||||
|
fc, err := createCapturer()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("create capturer: %v", err)
|
||||||
|
retryTimer.Reset(500 * time.Millisecond)
|
||||||
|
select {
|
||||||
|
case <-retryTimer.C:
|
||||||
|
case <-c.done:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cap = fc
|
||||||
|
w, h := screenSize()
|
||||||
|
c.mu.Lock()
|
||||||
|
c.w, c.h = w, h
|
||||||
|
c.mu.Unlock()
|
||||||
|
log.Infof("screen capturer ready: %dx%d", w, h)
|
||||||
|
}
|
||||||
|
|
||||||
|
img, err := cap.capture()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("capture: %v", err)
|
||||||
|
cap.close()
|
||||||
|
cap = nil
|
||||||
|
retryTimer.Reset(100 * time.Millisecond)
|
||||||
|
select {
|
||||||
|
case <-retryTimer.C:
|
||||||
|
case <-c.done:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
c.frame = img
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-frameTicker.C:
|
||||||
|
case <-c.done:
|
||||||
|
if cap != nil {
|
||||||
|
cap.close()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
385
client/vnc/server/capture_x11.go
Normal file
385
client/vnc/server/capture_x11.go
Normal file
@@ -0,0 +1,385 @@
|
|||||||
|
//go:build (linux && !android) || freebsd
|
||||||
|
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"image"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/jezek/xgb"
|
||||||
|
"github.com/jezek/xgb/xproto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// X11Capturer captures the screen from an X11 display using the MIT-SHM extension.
|
||||||
|
type X11Capturer struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
conn *xgb.Conn
|
||||||
|
screen *xproto.ScreenInfo
|
||||||
|
w, h int
|
||||||
|
shmID int
|
||||||
|
shmAddr []byte
|
||||||
|
shmSeg uint32 // shm.Seg
|
||||||
|
useSHM bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// detectX11Display finds the active X11 display and sets DISPLAY/XAUTHORITY
|
||||||
|
// environment variables if needed. This is required when running as a system
|
||||||
|
// service where these vars aren't set.
|
||||||
|
func detectX11Display() {
|
||||||
|
if os.Getenv("DISPLAY") != "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try /proc first (Linux), then ps fallback (FreeBSD and others).
|
||||||
|
if detectX11FromProc() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if detectX11FromSockets() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// detectX11FromProc scans /proc/*/cmdline for Xorg (Linux).
|
||||||
|
func detectX11FromProc() bool {
|
||||||
|
entries, err := os.ReadDir("/proc")
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, e := range entries {
|
||||||
|
if !e.IsDir() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cmdline, err := os.ReadFile("/proc/" + e.Name() + "/cmdline")
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if display, auth := parseXorgArgs(splitCmdline(cmdline)); display != "" {
|
||||||
|
setDisplayEnv(display, auth)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// detectX11FromSockets checks /tmp/.X11-unix/ for X sockets and uses ps
|
||||||
|
// to find the auth file. Works on FreeBSD and other systems without /proc.
|
||||||
|
func detectX11FromSockets() bool {
|
||||||
|
entries, err := os.ReadDir("/tmp/.X11-unix")
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the lowest display number.
|
||||||
|
for _, e := range entries {
|
||||||
|
name := e.Name()
|
||||||
|
if len(name) < 2 || name[0] != 'X' {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
display := ":" + name[1:]
|
||||||
|
os.Setenv("DISPLAY", display)
|
||||||
|
log.Infof("auto-detected DISPLAY=%s (from socket)", display)
|
||||||
|
|
||||||
|
// Try to find -auth from ps output.
|
||||||
|
if auth := findXorgAuthFromPS(); auth != "" {
|
||||||
|
os.Setenv("XAUTHORITY", auth)
|
||||||
|
log.Infof("auto-detected XAUTHORITY=%s (from ps)", auth)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// findXorgAuthFromPS runs ps to find Xorg and extract its -auth argument.
|
||||||
|
func findXorgAuthFromPS() string {
|
||||||
|
out, err := exec.Command("ps", "auxww").Output()
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
for _, line := range strings.Split(string(out), "\n") {
|
||||||
|
if !strings.Contains(line, "Xorg") && !strings.Contains(line, "/X ") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fields := strings.Fields(line)
|
||||||
|
for i, f := range fields {
|
||||||
|
if f == "-auth" && i+1 < len(fields) {
|
||||||
|
return fields[i+1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseXorgArgs(args []string) (display, auth string) {
|
||||||
|
if len(args) == 0 {
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
base := args[0]
|
||||||
|
if !(base == "Xorg" || base == "X" || len(base) > 0 && base[len(base)-1] == 'X' ||
|
||||||
|
strings.Contains(base, "/Xorg") || strings.Contains(base, "/X")) {
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
for i, arg := range args[1:] {
|
||||||
|
if len(arg) > 0 && arg[0] == ':' {
|
||||||
|
display = arg
|
||||||
|
}
|
||||||
|
if arg == "-auth" && i+2 < len(args) {
|
||||||
|
auth = args[i+2]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return display, auth
|
||||||
|
}
|
||||||
|
|
||||||
|
func setDisplayEnv(display, auth string) {
|
||||||
|
os.Setenv("DISPLAY", display)
|
||||||
|
log.Infof("auto-detected DISPLAY=%s", display)
|
||||||
|
if auth != "" {
|
||||||
|
os.Setenv("XAUTHORITY", auth)
|
||||||
|
log.Infof("auto-detected XAUTHORITY=%s", auth)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func splitCmdline(data []byte) []string {
|
||||||
|
var args []string
|
||||||
|
for _, b := range splitNull(data) {
|
||||||
|
if len(b) > 0 {
|
||||||
|
args = append(args, string(b))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
|
||||||
|
func splitNull(data []byte) [][]byte {
|
||||||
|
var parts [][]byte
|
||||||
|
start := 0
|
||||||
|
for i, b := range data {
|
||||||
|
if b == 0 {
|
||||||
|
parts = append(parts, data[start:i])
|
||||||
|
start = i + 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if start < len(data) {
|
||||||
|
parts = append(parts, data[start:])
|
||||||
|
}
|
||||||
|
return parts
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewX11Capturer connects to the X11 display and sets up shared memory capture.
|
||||||
|
func NewX11Capturer(display string) (*X11Capturer, error) {
|
||||||
|
detectX11Display()
|
||||||
|
|
||||||
|
if display == "" {
|
||||||
|
display = os.Getenv("DISPLAY")
|
||||||
|
}
|
||||||
|
if display == "" {
|
||||||
|
return nil, fmt.Errorf("DISPLAY not set and no Xorg process found")
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := xgb.NewConnDisplay(display)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("connect to X11 display %s: %w", display, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
setup := xproto.Setup(conn)
|
||||||
|
if len(setup.Roots) == 0 {
|
||||||
|
conn.Close()
|
||||||
|
return nil, fmt.Errorf("no X11 screens")
|
||||||
|
}
|
||||||
|
screen := setup.Roots[0]
|
||||||
|
|
||||||
|
c := &X11Capturer{
|
||||||
|
conn: conn,
|
||||||
|
screen: &screen,
|
||||||
|
w: int(screen.WidthInPixels),
|
||||||
|
h: int(screen.HeightInPixels),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.initSHM(); err != nil {
|
||||||
|
log.Debugf("X11 SHM not available, using slow GetImage: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("X11 capturer ready: %dx%d (display=%s, shm=%v)", c.w, c.h, display, c.useSHM)
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// initSHM is implemented in capture_x11_shm_linux.go (requires SysV SHM).
|
||||||
|
// On platforms without SysV SHM (FreeBSD), a stub returns an error and
|
||||||
|
// the capturer falls back to GetImage.
|
||||||
|
|
||||||
|
// Width returns the screen width.
|
||||||
|
func (c *X11Capturer) Width() int { return c.w }
|
||||||
|
|
||||||
|
// Height returns the screen height.
|
||||||
|
func (c *X11Capturer) Height() int { return c.h }
|
||||||
|
|
||||||
|
// Capture returns the current screen as an RGBA image.
|
||||||
|
func (c *X11Capturer) Capture() (*image.RGBA, error) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
if c.useSHM {
|
||||||
|
return c.captureSHM()
|
||||||
|
}
|
||||||
|
return c.captureGetImage()
|
||||||
|
}
|
||||||
|
|
||||||
|
// captureSHM is implemented in capture_x11_shm_linux.go.
|
||||||
|
|
||||||
|
func (c *X11Capturer) captureGetImage() (*image.RGBA, error) {
|
||||||
|
cookie := xproto.GetImage(c.conn, xproto.ImageFormatZPixmap,
|
||||||
|
xproto.Drawable(c.screen.Root),
|
||||||
|
0, 0, uint16(c.w), uint16(c.h), 0xFFFFFFFF)
|
||||||
|
|
||||||
|
reply, err := cookie.Reply()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("GetImage: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
img := image.NewRGBA(image.Rect(0, 0, c.w, c.h))
|
||||||
|
data := reply.Data
|
||||||
|
n := c.w * c.h * 4
|
||||||
|
if len(data) < n {
|
||||||
|
return nil, fmt.Errorf("GetImage returned %d bytes, expected %d", len(data), n)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < n; i += 4 {
|
||||||
|
img.Pix[i+0] = data[i+2] // R
|
||||||
|
img.Pix[i+1] = data[i+1] // G
|
||||||
|
img.Pix[i+2] = data[i+0] // B
|
||||||
|
img.Pix[i+3] = 0xff
|
||||||
|
}
|
||||||
|
return img, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close releases X11 resources.
|
||||||
|
func (c *X11Capturer) Close() {
|
||||||
|
c.closeSHM()
|
||||||
|
c.conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// closeSHM is implemented in capture_x11_shm_linux.go.
|
||||||
|
|
||||||
|
// X11Poller wraps X11Capturer in a continuous capture loop, matching the
|
||||||
|
// DesktopCapturer pattern from Windows.
|
||||||
|
type X11Poller struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
frame *image.RGBA
|
||||||
|
w, h int
|
||||||
|
display string
|
||||||
|
done chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewX11Poller creates a capturer that continuously grabs the X11 display.
|
||||||
|
func NewX11Poller(display string) *X11Poller {
|
||||||
|
p := &X11Poller{
|
||||||
|
display: display,
|
||||||
|
done: make(chan struct{}),
|
||||||
|
}
|
||||||
|
go p.loop()
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close stops the capture loop.
|
||||||
|
func (p *X11Poller) Close() {
|
||||||
|
select {
|
||||||
|
case <-p.done:
|
||||||
|
default:
|
||||||
|
close(p.done)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Width returns the screen width.
|
||||||
|
func (p *X11Poller) Width() int {
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
return p.w
|
||||||
|
}
|
||||||
|
|
||||||
|
// Height returns the screen height.
|
||||||
|
func (p *X11Poller) Height() int {
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
return p.h
|
||||||
|
}
|
||||||
|
|
||||||
|
// Capture returns the most recent frame.
|
||||||
|
func (p *X11Poller) Capture() (*image.RGBA, error) {
|
||||||
|
p.mu.Lock()
|
||||||
|
img := p.frame
|
||||||
|
p.mu.Unlock()
|
||||||
|
if img != nil {
|
||||||
|
return img, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("no frame available yet")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *X11Poller) loop() {
|
||||||
|
var capturer *X11Capturer
|
||||||
|
var initFails int
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if capturer != nil {
|
||||||
|
capturer.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-p.done:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
if capturer == nil {
|
||||||
|
var err error
|
||||||
|
capturer, err = NewX11Capturer(p.display)
|
||||||
|
if err != nil {
|
||||||
|
initFails++
|
||||||
|
if initFails <= maxCapturerRetries {
|
||||||
|
log.Debugf("X11 capturer: %v (attempt %d/%d)", err, initFails, maxCapturerRetries)
|
||||||
|
select {
|
||||||
|
case <-p.done:
|
||||||
|
return
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
log.Warnf("X11 capturer unavailable after %d attempts, stopping poller", maxCapturerRetries)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
initFails = 0
|
||||||
|
p.mu.Lock()
|
||||||
|
p.w, p.h = capturer.Width(), capturer.Height()
|
||||||
|
p.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
img, err := capturer.Capture()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("X11 capture: %v", err)
|
||||||
|
capturer.Close()
|
||||||
|
capturer = nil
|
||||||
|
select {
|
||||||
|
case <-p.done:
|
||||||
|
return
|
||||||
|
case <-time.After(500 * time.Millisecond):
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
p.mu.Lock()
|
||||||
|
p.frame = img
|
||||||
|
p.mu.Unlock()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-p.done:
|
||||||
|
return
|
||||||
|
case <-time.After(33 * time.Millisecond): // ~30 fps
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
78
client/vnc/server/capture_x11_shm_linux.go
Normal file
78
client/vnc/server/capture_x11_shm_linux.go
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
//go:build linux && !android
|
||||||
|
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"image"
|
||||||
|
|
||||||
|
"github.com/jezek/xgb/shm"
|
||||||
|
"github.com/jezek/xgb/xproto"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (c *X11Capturer) initSHM() error {
|
||||||
|
if err := shm.Init(c.conn); err != nil {
|
||||||
|
return fmt.Errorf("init SHM extension: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
size := c.w * c.h * 4
|
||||||
|
id, err := unix.SysvShmGet(unix.IPC_PRIVATE, size, unix.IPC_CREAT|0600)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("shmget: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
addr, err := unix.SysvShmAttach(id, 0, 0)
|
||||||
|
if err != nil {
|
||||||
|
unix.SysvShmCtl(id, unix.IPC_RMID, nil)
|
||||||
|
return fmt.Errorf("shmat: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
unix.SysvShmCtl(id, unix.IPC_RMID, nil)
|
||||||
|
|
||||||
|
seg, err := shm.NewSegId(c.conn)
|
||||||
|
if err != nil {
|
||||||
|
unix.SysvShmDetach(addr)
|
||||||
|
return fmt.Errorf("new SHM seg: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := shm.AttachChecked(c.conn, seg, uint32(id), false).Check(); err != nil {
|
||||||
|
unix.SysvShmDetach(addr)
|
||||||
|
return fmt.Errorf("SHM attach to X: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.shmID = id
|
||||||
|
c.shmAddr = addr
|
||||||
|
c.shmSeg = uint32(seg)
|
||||||
|
c.useSHM = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *X11Capturer) captureSHM() (*image.RGBA, error) {
|
||||||
|
cookie := shm.GetImage(c.conn, xproto.Drawable(c.screen.Root),
|
||||||
|
0, 0, uint16(c.w), uint16(c.h), 0xFFFFFFFF,
|
||||||
|
xproto.ImageFormatZPixmap, shm.Seg(c.shmSeg), 0)
|
||||||
|
|
||||||
|
_, err := cookie.Reply()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("SHM GetImage: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
img := image.NewRGBA(image.Rect(0, 0, c.w, c.h))
|
||||||
|
n := c.w * c.h * 4
|
||||||
|
|
||||||
|
for i := 0; i < n; i += 4 {
|
||||||
|
img.Pix[i+0] = c.shmAddr[i+2] // R
|
||||||
|
img.Pix[i+1] = c.shmAddr[i+1] // G
|
||||||
|
img.Pix[i+2] = c.shmAddr[i+0] // B
|
||||||
|
img.Pix[i+3] = 0xff
|
||||||
|
}
|
||||||
|
return img, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *X11Capturer) closeSHM() {
|
||||||
|
if c.useSHM {
|
||||||
|
shm.Detach(c.conn, shm.Seg(c.shmSeg))
|
||||||
|
unix.SysvShmDetach(c.shmAddr)
|
||||||
|
}
|
||||||
|
}
|
||||||
18
client/vnc/server/capture_x11_shm_stub.go
Normal file
18
client/vnc/server/capture_x11_shm_stub.go
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
//go:build freebsd
|
||||||
|
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"image"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (c *X11Capturer) initSHM() error {
|
||||||
|
return fmt.Errorf("SysV SHM not available on this platform")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *X11Capturer) captureSHM() (*image.RGBA, error) {
|
||||||
|
return nil, fmt.Errorf("SHM capture not available on this platform")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *X11Capturer) closeSHM() {}
|
||||||
151
client/vnc/server/crypto.go
Normal file
151
client/vnc/server/crypto.go
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/aes"
|
||||||
|
"crypto/cipher"
|
||||||
|
"crypto/ecdh"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"crypto/sha256"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/hkdf"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
aesKeySize = 32 // AES-256
|
||||||
|
gcmNonceSize = 12
|
||||||
|
)
|
||||||
|
|
||||||
|
// recCrypto holds per-session encryption state.
|
||||||
|
type recCrypto struct {
|
||||||
|
gcm cipher.AEAD
|
||||||
|
frameCounter uint64
|
||||||
|
// ephemeralPub is stored in the recording header so the admin can derive the same key.
|
||||||
|
ephemeralPub []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// newRecCrypto sets up encryption for a new recording session.
|
||||||
|
// adminPubKeyB64 is the base64-encoded X25519 public key from management settings.
|
||||||
|
func newRecCrypto(adminPubKeyB64 string) (*recCrypto, error) {
|
||||||
|
adminPubBytes, err := base64.StdEncoding.DecodeString(adminPubKeyB64)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("decode admin public key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
adminPub, err := ecdh.X25519().NewPublicKey(adminPubBytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse admin X25519 public key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate ephemeral keypair
|
||||||
|
ephemeral, err := ecdh.X25519().GenerateKey(rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("generate ephemeral key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ECDH shared secret
|
||||||
|
shared, err := ephemeral.ECDH(adminPub)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("ECDH: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Derive AES-256 key via HKDF
|
||||||
|
aesKey, err := deriveKey(shared, ephemeral.PublicKey().Bytes())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("derive key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
block, err := aes.NewCipher(aesKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create AES cipher: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
gcm, err := cipher.NewGCM(block)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create GCM: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &recCrypto{
|
||||||
|
gcm: gcm,
|
||||||
|
ephemeralPub: ephemeral.PublicKey().Bytes(),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// encrypt encrypts plaintext using a counter-based nonce. Each call increments the counter.
|
||||||
|
func (c *recCrypto) encrypt(plaintext []byte) []byte {
|
||||||
|
nonce := make([]byte, gcmNonceSize)
|
||||||
|
binary.LittleEndian.PutUint64(nonce, c.frameCounter)
|
||||||
|
c.frameCounter++
|
||||||
|
return c.gcm.Seal(nil, nonce, plaintext, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecryptRecording creates a decryptor from the admin's private key and the ephemeral public key from the header.
|
||||||
|
func DecryptRecording(adminPrivKeyB64 string, ephemeralPubB64 string) (*recDecryptor, error) {
|
||||||
|
adminPrivBytes, err := base64.StdEncoding.DecodeString(adminPrivKeyB64)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("decode admin private key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
adminPriv, err := ecdh.X25519().NewPrivateKey(adminPrivBytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse admin X25519 private key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ephPubBytes, err := base64.StdEncoding.DecodeString(ephemeralPubB64)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("decode ephemeral public key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ephPub, err := ecdh.X25519().NewPublicKey(ephPubBytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse ephemeral public key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
shared, err := adminPriv.ECDH(ephPub)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("ECDH: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
aesKey, err := deriveKey(shared, ephPubBytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("derive key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
block, err := aes.NewCipher(aesKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create AES cipher: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
gcm, err := cipher.NewGCM(block)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create GCM: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &recDecryptor{gcm: gcm}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type recDecryptor struct {
|
||||||
|
gcm cipher.AEAD
|
||||||
|
frameCounter uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decrypt decrypts a frame. Must be called in the same order as encryption.
|
||||||
|
func (d *recDecryptor) Decrypt(ciphertext []byte) ([]byte, error) {
|
||||||
|
nonce := make([]byte, gcmNonceSize)
|
||||||
|
binary.LittleEndian.PutUint64(nonce, d.frameCounter)
|
||||||
|
d.frameCounter++
|
||||||
|
return d.gcm.Open(nil, nonce, ciphertext, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func deriveKey(shared, ephemeralPub []byte) ([]byte, error) {
|
||||||
|
hkdfReader := hkdf.New(sha256.New, shared, ephemeralPub, []byte("netbird-recording"))
|
||||||
|
key := make([]byte, aesKeySize)
|
||||||
|
if _, err := io.ReadFull(hkdfReader, key); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return key, nil
|
||||||
|
}
|
||||||
129
client/vnc/server/crypto_test.go
Normal file
129
client/vnc/server/crypto_test.go
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ecdh"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCryptoRoundtrip(t *testing.T) {
|
||||||
|
// Generate admin keypair
|
||||||
|
adminPriv, err := ecdh.X25519().GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
adminPubB64 := base64.StdEncoding.EncodeToString(adminPriv.PublicKey().Bytes())
|
||||||
|
adminPrivB64 := base64.StdEncoding.EncodeToString(adminPriv.Bytes())
|
||||||
|
|
||||||
|
// Create encryptor (recording side)
|
||||||
|
enc, err := newRecCrypto(adminPubB64)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, enc.ephemeralPub, 32)
|
||||||
|
|
||||||
|
ephPubB64 := base64.StdEncoding.EncodeToString(enc.ephemeralPub)
|
||||||
|
|
||||||
|
// Encrypt some frames
|
||||||
|
plaintext1 := []byte("frame data one - PNG bytes would go here")
|
||||||
|
plaintext2 := []byte("frame data two - different content")
|
||||||
|
plaintext3 := make([]byte, 1024*100) // 100KB frame
|
||||||
|
rand.Read(plaintext3)
|
||||||
|
|
||||||
|
ct1 := enc.encrypt(plaintext1)
|
||||||
|
ct2 := enc.encrypt(plaintext2)
|
||||||
|
ct3 := enc.encrypt(plaintext3)
|
||||||
|
|
||||||
|
// Ciphertext should differ from plaintext
|
||||||
|
assert.NotEqual(t, plaintext1, ct1)
|
||||||
|
// Ciphertext is larger (GCM tag overhead)
|
||||||
|
assert.Greater(t, len(ct1), len(plaintext1))
|
||||||
|
|
||||||
|
// Create decryptor (playback side)
|
||||||
|
dec, err := DecryptRecording(adminPrivB64, ephPubB64)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Decrypt in same order
|
||||||
|
got1, err := dec.Decrypt(ct1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, plaintext1, got1)
|
||||||
|
|
||||||
|
got2, err := dec.Decrypt(ct2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, plaintext2, got2)
|
||||||
|
|
||||||
|
got3, err := dec.Decrypt(ct3)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, plaintext3, got3)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCryptoWrongKey(t *testing.T) {
|
||||||
|
// Admin key
|
||||||
|
adminPriv, err := ecdh.X25519().GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
adminPubB64 := base64.StdEncoding.EncodeToString(adminPriv.PublicKey().Bytes())
|
||||||
|
|
||||||
|
// Encrypt with admin's public key
|
||||||
|
enc, err := newRecCrypto(adminPubB64)
|
||||||
|
require.NoError(t, err)
|
||||||
|
ephPubB64 := base64.StdEncoding.EncodeToString(enc.ephemeralPub)
|
||||||
|
|
||||||
|
ct := enc.encrypt([]byte("secret frame data"))
|
||||||
|
|
||||||
|
// Try to decrypt with a different private key
|
||||||
|
wrongPriv, err := ecdh.X25519().GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
wrongPrivB64 := base64.StdEncoding.EncodeToString(wrongPriv.Bytes())
|
||||||
|
|
||||||
|
dec, err := DecryptRecording(wrongPrivB64, ephPubB64)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = dec.Decrypt(ct)
|
||||||
|
assert.Error(t, err, "decryption with wrong key should fail")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCryptoInvalidKey(t *testing.T) {
|
||||||
|
_, err := newRecCrypto("")
|
||||||
|
assert.Error(t, err, "empty key should fail")
|
||||||
|
|
||||||
|
_, err = newRecCrypto("not-base64!!!")
|
||||||
|
assert.Error(t, err, "invalid base64 should fail")
|
||||||
|
|
||||||
|
_, err = newRecCrypto(base64.StdEncoding.EncodeToString([]byte("too-short")))
|
||||||
|
assert.Error(t, err, "wrong-length key should fail")
|
||||||
|
|
||||||
|
_, err = DecryptRecording("", "validbutirrelevant")
|
||||||
|
assert.Error(t, err, "empty private key should fail")
|
||||||
|
|
||||||
|
_, err = DecryptRecording("not-base64!!!", base64.StdEncoding.EncodeToString(make([]byte, 32)))
|
||||||
|
assert.Error(t, err, "invalid base64 private key should fail")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCryptoOutOfOrderFails(t *testing.T) {
|
||||||
|
adminPriv, err := ecdh.X25519().GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
adminPubB64 := base64.StdEncoding.EncodeToString(adminPriv.PublicKey().Bytes())
|
||||||
|
adminPrivB64 := base64.StdEncoding.EncodeToString(adminPriv.Bytes())
|
||||||
|
|
||||||
|
enc, err := newRecCrypto(adminPubB64)
|
||||||
|
require.NoError(t, err)
|
||||||
|
ephPubB64 := base64.StdEncoding.EncodeToString(enc.ephemeralPub)
|
||||||
|
|
||||||
|
ct0 := enc.encrypt([]byte("frame 0"))
|
||||||
|
ct1 := enc.encrypt([]byte("frame 1"))
|
||||||
|
|
||||||
|
dec, err := DecryptRecording(adminPrivB64, ephPubB64)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Skip frame 0, try to decrypt frame 1 first (wrong nonce)
|
||||||
|
_, err = dec.Decrypt(ct1)
|
||||||
|
assert.Error(t, err, "out-of-order decryption should fail due to nonce mismatch")
|
||||||
|
|
||||||
|
// But frame 0 with a fresh decryptor should work
|
||||||
|
dec2, err := DecryptRecording(adminPrivB64, ephPubB64)
|
||||||
|
require.NoError(t, err)
|
||||||
|
got, err := dec2.Decrypt(ct0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []byte("frame 0"), got)
|
||||||
|
}
|
||||||
540
client/vnc/server/input_darwin.go
Normal file
540
client/vnc/server/input_darwin.go
Normal file
@@ -0,0 +1,540 @@
|
|||||||
|
//go:build darwin && !ios
|
||||||
|
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/ebitengine/purego"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Core Graphics event constants.
|
||||||
|
const (
|
||||||
|
kCGEventSourceStateCombinedSessionState int32 = 0
|
||||||
|
|
||||||
|
kCGEventLeftMouseDown int32 = 1
|
||||||
|
kCGEventLeftMouseUp int32 = 2
|
||||||
|
kCGEventRightMouseDown int32 = 3
|
||||||
|
kCGEventRightMouseUp int32 = 4
|
||||||
|
kCGEventMouseMoved int32 = 5
|
||||||
|
kCGEventLeftMouseDragged int32 = 6
|
||||||
|
kCGEventRightMouseDragged int32 = 7
|
||||||
|
kCGEventKeyDown int32 = 10
|
||||||
|
kCGEventKeyUp int32 = 11
|
||||||
|
kCGEventOtherMouseDown int32 = 25
|
||||||
|
kCGEventOtherMouseUp int32 = 26
|
||||||
|
|
||||||
|
kCGMouseButtonLeft int32 = 0
|
||||||
|
kCGMouseButtonRight int32 = 1
|
||||||
|
kCGMouseButtonCenter int32 = 2
|
||||||
|
|
||||||
|
kCGHIDEventTap int32 = 0
|
||||||
|
|
||||||
|
// IOKit power management constants.
|
||||||
|
kIOPMUserActiveLocal int32 = 0
|
||||||
|
kIOPMAssertionLevelOn uint32 = 255
|
||||||
|
kCFStringEncodingUTF8 uint32 = 0x08000100
|
||||||
|
)
|
||||||
|
|
||||||
|
var darwinInputOnce sync.Once
|
||||||
|
|
||||||
|
var (
|
||||||
|
cgEventSourceCreate func(int32) uintptr
|
||||||
|
cgEventCreateKeyboardEvent func(uintptr, uint16, bool) uintptr
|
||||||
|
// CGEventCreateMouseEvent takes CGPoint as two separate float64 args.
|
||||||
|
// purego can't handle array/struct types but individual float64s work.
|
||||||
|
cgEventCreateMouseEvent func(uintptr, int32, float64, float64, int32) uintptr
|
||||||
|
cgEventPost func(int32, uintptr)
|
||||||
|
|
||||||
|
// CGEventCreateScrollWheelEvent is variadic, call via SyscallN.
|
||||||
|
cgEventCreateScrollWheelEventAddr uintptr
|
||||||
|
|
||||||
|
axIsProcessTrusted func() bool
|
||||||
|
|
||||||
|
// IOKit power-management bindings used to wake the display and inhibit
|
||||||
|
// idle sleep while a VNC client is driving input.
|
||||||
|
iopmAssertionDeclareUserActivity func(uintptr, int32, *uint32) int32
|
||||||
|
iopmAssertionCreateWithName func(uintptr, uint32, uintptr, *uint32) int32
|
||||||
|
iopmAssertionRelease func(uint32) int32
|
||||||
|
cfStringCreateWithCString func(uintptr, string, uint32) uintptr
|
||||||
|
|
||||||
|
// Cached CFStrings for assertion name and idle-sleep type.
|
||||||
|
pmAssertionNameCFStr uintptr
|
||||||
|
pmPreventIdleDisplayCFStr uintptr
|
||||||
|
|
||||||
|
// Assertion IDs. userActivityID is reused across input events so repeated
|
||||||
|
// calls refresh the same assertion rather than create new ones.
|
||||||
|
pmMu sync.Mutex
|
||||||
|
userActivityID uint32
|
||||||
|
preventSleepID uint32
|
||||||
|
preventSleepHeld bool
|
||||||
|
|
||||||
|
darwinInputReady bool
|
||||||
|
darwinEventSource uintptr
|
||||||
|
)
|
||||||
|
|
||||||
|
func initDarwinInput() {
|
||||||
|
darwinInputOnce.Do(func() {
|
||||||
|
cg, err := purego.Dlopen("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("load CoreGraphics for input: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
purego.RegisterLibFunc(&cgEventSourceCreate, cg, "CGEventSourceCreate")
|
||||||
|
purego.RegisterLibFunc(&cgEventCreateKeyboardEvent, cg, "CGEventCreateKeyboardEvent")
|
||||||
|
purego.RegisterLibFunc(&cgEventCreateMouseEvent, cg, "CGEventCreateMouseEvent")
|
||||||
|
purego.RegisterLibFunc(&cgEventPost, cg, "CGEventPost")
|
||||||
|
|
||||||
|
sym, err := purego.Dlsym(cg, "CGEventCreateScrollWheelEvent")
|
||||||
|
if err == nil {
|
||||||
|
cgEventCreateScrollWheelEventAddr = sym
|
||||||
|
}
|
||||||
|
|
||||||
|
if ax, err := purego.Dlopen("/System/Library/Frameworks/ApplicationServices.framework/ApplicationServices", purego.RTLD_NOW|purego.RTLD_GLOBAL); err == nil {
|
||||||
|
if sym, err := purego.Dlsym(ax, "AXIsProcessTrusted"); err == nil {
|
||||||
|
purego.RegisterFunc(&axIsProcessTrusted, sym)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
initPowerAssertions()
|
||||||
|
|
||||||
|
darwinInputReady = true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func initPowerAssertions() {
|
||||||
|
iokit, err := purego.Dlopen("/System/Library/Frameworks/IOKit.framework/IOKit", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("load IOKit: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cf, err := purego.Dlopen("/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("load CoreFoundation for power assertions: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
purego.RegisterLibFunc(&cfStringCreateWithCString, cf, "CFStringCreateWithCString")
|
||||||
|
purego.RegisterLibFunc(&iopmAssertionDeclareUserActivity, iokit, "IOPMAssertionDeclareUserActivity")
|
||||||
|
purego.RegisterLibFunc(&iopmAssertionCreateWithName, iokit, "IOPMAssertionCreateWithName")
|
||||||
|
purego.RegisterLibFunc(&iopmAssertionRelease, iokit, "IOPMAssertionRelease")
|
||||||
|
|
||||||
|
pmAssertionNameCFStr = cfStringCreateWithCString(0, "NetBird VNC input", kCFStringEncodingUTF8)
|
||||||
|
pmPreventIdleDisplayCFStr = cfStringCreateWithCString(0, "PreventUserIdleDisplaySleep", kCFStringEncodingUTF8)
|
||||||
|
}
|
||||||
|
|
||||||
|
// wakeDisplay declares user activity so macOS treats the synthesized input as
|
||||||
|
// real HID activity, waking the display if it is asleep. Called on every key
|
||||||
|
// and pointer event; the kernel coalesces repeated calls cheaply.
|
||||||
|
func wakeDisplay() {
|
||||||
|
if iopmAssertionDeclareUserActivity == nil || pmAssertionNameCFStr == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
pmMu.Lock()
|
||||||
|
id := userActivityID
|
||||||
|
pmMu.Unlock()
|
||||||
|
r := iopmAssertionDeclareUserActivity(pmAssertionNameCFStr, kIOPMUserActiveLocal, &id)
|
||||||
|
if r != 0 {
|
||||||
|
log.Tracef("IOPMAssertionDeclareUserActivity returned %d", r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
pmMu.Lock()
|
||||||
|
userActivityID = id
|
||||||
|
pmMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// holdPreventIdleSleep creates an assertion that keeps the display from going
|
||||||
|
// idle-to-sleep while a VNC session is active. Safe to call repeatedly.
|
||||||
|
func holdPreventIdleSleep() {
|
||||||
|
if iopmAssertionCreateWithName == nil || pmPreventIdleDisplayCFStr == 0 || pmAssertionNameCFStr == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
pmMu.Lock()
|
||||||
|
defer pmMu.Unlock()
|
||||||
|
if preventSleepHeld {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var id uint32
|
||||||
|
r := iopmAssertionCreateWithName(pmPreventIdleDisplayCFStr, kIOPMAssertionLevelOn, pmAssertionNameCFStr, &id)
|
||||||
|
if r != 0 {
|
||||||
|
log.Debugf("IOPMAssertionCreateWithName returned %d", r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
preventSleepID = id
|
||||||
|
preventSleepHeld = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// releasePreventIdleSleep drops the idle-sleep assertion.
|
||||||
|
func releasePreventIdleSleep() {
|
||||||
|
if iopmAssertionRelease == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
pmMu.Lock()
|
||||||
|
defer pmMu.Unlock()
|
||||||
|
if !preventSleepHeld {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if r := iopmAssertionRelease(preventSleepID); r != 0 {
|
||||||
|
log.Debugf("IOPMAssertionRelease returned %d", r)
|
||||||
|
}
|
||||||
|
preventSleepHeld = false
|
||||||
|
preventSleepID = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func ensureEventSource() uintptr {
|
||||||
|
if darwinEventSource != 0 {
|
||||||
|
return darwinEventSource
|
||||||
|
}
|
||||||
|
darwinEventSource = cgEventSourceCreate(kCGEventSourceStateCombinedSessionState)
|
||||||
|
return darwinEventSource
|
||||||
|
}
|
||||||
|
|
||||||
|
// MacInputInjector injects keyboard and mouse events via Core Graphics.
|
||||||
|
type MacInputInjector struct {
|
||||||
|
lastButtons uint8
|
||||||
|
pbcopyPath string
|
||||||
|
pbpastePath string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMacInputInjector creates a macOS input injector.
|
||||||
|
func NewMacInputInjector() (*MacInputInjector, error) {
|
||||||
|
initDarwinInput()
|
||||||
|
if !darwinInputReady {
|
||||||
|
return nil, fmt.Errorf("CoreGraphics not available for input injection")
|
||||||
|
}
|
||||||
|
checkMacPermissions()
|
||||||
|
|
||||||
|
m := &MacInputInjector{}
|
||||||
|
if path, err := exec.LookPath("pbcopy"); err == nil {
|
||||||
|
m.pbcopyPath = path
|
||||||
|
}
|
||||||
|
if path, err := exec.LookPath("pbpaste"); err == nil {
|
||||||
|
m.pbpastePath = path
|
||||||
|
}
|
||||||
|
if m.pbcopyPath == "" || m.pbpastePath == "" {
|
||||||
|
log.Debugf("clipboard tools not found (pbcopy=%q, pbpaste=%q)", m.pbcopyPath, m.pbpastePath)
|
||||||
|
}
|
||||||
|
|
||||||
|
holdPreventIdleSleep()
|
||||||
|
|
||||||
|
log.Info("macOS input injector ready")
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkMacPermissions warns and opens the Privacy pane if Accessibility is
|
||||||
|
// missing. Uses AXIsProcessTrusted which returns immediately; the previous
|
||||||
|
// osascript probe blocked for 120s (AppleEvent timeout) when access was
|
||||||
|
// denied, which delayed VNC server startup past client deadlines.
|
||||||
|
func checkMacPermissions() {
|
||||||
|
if axIsProcessTrusted != nil && !axIsProcessTrusted() {
|
||||||
|
openPrivacyPane("Privacy_Accessibility")
|
||||||
|
log.Warn("Accessibility permission not granted. Input injection will not work. " +
|
||||||
|
"Opened System Settings > Privacy & Security > Accessibility; enable netbird.")
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("Screen Recording permission is required for screen capture. " +
|
||||||
|
"If the screen appears black, grant in System Settings > Privacy & Security > Screen Recording.")
|
||||||
|
}
|
||||||
|
|
||||||
|
// openPrivacyPane opens the given Privacy pane in System Settings so the user
|
||||||
|
// can toggle the permission without navigating manually.
|
||||||
|
func openPrivacyPane(pane string) {
|
||||||
|
url := "x-apple.systempreferences:com.apple.preference.security?" + pane
|
||||||
|
if err := exec.Command("open", url).Start(); err != nil {
|
||||||
|
log.Debugf("open privacy pane %s: %v", pane, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// InjectKey simulates a key press or release.
|
||||||
|
func (m *MacInputInjector) InjectKey(keysym uint32, down bool) {
|
||||||
|
wakeDisplay()
|
||||||
|
src := ensureEventSource()
|
||||||
|
if src == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
keycode := keysymToMacKeycode(keysym)
|
||||||
|
if keycode == 0xFFFF {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
event := cgEventCreateKeyboardEvent(src, keycode, down)
|
||||||
|
if event == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cgEventPost(kCGHIDEventTap, event)
|
||||||
|
cfRelease(event)
|
||||||
|
}
|
||||||
|
|
||||||
|
// InjectPointer simulates mouse movement and button events.
|
||||||
|
func (m *MacInputInjector) InjectPointer(buttonMask uint8, px, py, serverW, serverH int) {
|
||||||
|
wakeDisplay()
|
||||||
|
if serverW == 0 || serverH == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
src := ensureEventSource()
|
||||||
|
if src == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Framebuffer is in physical pixels (Retina). CGEventCreateMouseEvent
|
||||||
|
// expects logical points, so scale down by the display's pixel/point ratio.
|
||||||
|
x := float64(px)
|
||||||
|
y := float64(py)
|
||||||
|
if cgDisplayPixelsWide != nil && cgMainDisplayID != nil {
|
||||||
|
displayID := cgMainDisplayID()
|
||||||
|
logicalW := int(cgDisplayPixelsWide(displayID))
|
||||||
|
logicalH := int(cgDisplayPixelsHigh(displayID))
|
||||||
|
if logicalW > 0 && logicalH > 0 {
|
||||||
|
x = float64(px) * float64(logicalW) / float64(serverW)
|
||||||
|
y = float64(py) * float64(logicalH) / float64(serverH)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
leftDown := buttonMask&0x01 != 0
|
||||||
|
rightDown := buttonMask&0x04 != 0
|
||||||
|
middleDown := buttonMask&0x02 != 0
|
||||||
|
scrollUp := buttonMask&0x08 != 0
|
||||||
|
scrollDown := buttonMask&0x10 != 0
|
||||||
|
|
||||||
|
wasLeft := m.lastButtons&0x01 != 0
|
||||||
|
wasRight := m.lastButtons&0x04 != 0
|
||||||
|
wasMiddle := m.lastButtons&0x02 != 0
|
||||||
|
|
||||||
|
if leftDown {
|
||||||
|
m.postMouse(src, kCGEventLeftMouseDragged, x, y, kCGMouseButtonLeft)
|
||||||
|
} else if rightDown {
|
||||||
|
m.postMouse(src, kCGEventRightMouseDragged, x, y, kCGMouseButtonRight)
|
||||||
|
} else {
|
||||||
|
m.postMouse(src, kCGEventMouseMoved, x, y, kCGMouseButtonLeft)
|
||||||
|
}
|
||||||
|
|
||||||
|
if leftDown && !wasLeft {
|
||||||
|
m.postMouse(src, kCGEventLeftMouseDown, x, y, kCGMouseButtonLeft)
|
||||||
|
} else if !leftDown && wasLeft {
|
||||||
|
m.postMouse(src, kCGEventLeftMouseUp, x, y, kCGMouseButtonLeft)
|
||||||
|
}
|
||||||
|
if rightDown && !wasRight {
|
||||||
|
m.postMouse(src, kCGEventRightMouseDown, x, y, kCGMouseButtonRight)
|
||||||
|
} else if !rightDown && wasRight {
|
||||||
|
m.postMouse(src, kCGEventRightMouseUp, x, y, kCGMouseButtonRight)
|
||||||
|
}
|
||||||
|
if middleDown && !wasMiddle {
|
||||||
|
m.postMouse(src, kCGEventOtherMouseDown, x, y, kCGMouseButtonCenter)
|
||||||
|
} else if !middleDown && wasMiddle {
|
||||||
|
m.postMouse(src, kCGEventOtherMouseUp, x, y, kCGMouseButtonCenter)
|
||||||
|
}
|
||||||
|
|
||||||
|
if scrollUp {
|
||||||
|
m.postScroll(src, 3)
|
||||||
|
}
|
||||||
|
if scrollDown {
|
||||||
|
m.postScroll(src, -3)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.lastButtons = buttonMask
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MacInputInjector) postMouse(src uintptr, eventType int32, x, y float64, button int32) {
|
||||||
|
if cgEventCreateMouseEvent == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
event := cgEventCreateMouseEvent(src, eventType, x, y, button)
|
||||||
|
if event == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cgEventPost(kCGHIDEventTap, event)
|
||||||
|
cfRelease(event)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MacInputInjector) postScroll(src uintptr, deltaY int32) {
|
||||||
|
if cgEventCreateScrollWheelEventAddr == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// CGEventCreateScrollWheelEvent(source, units, wheelCount, wheel1delta)
|
||||||
|
// units=0 (pixel), wheelCount=1, wheel1delta=deltaY
|
||||||
|
// Variadic C function: pass args as uintptr via SyscallN.
|
||||||
|
r1, _, _ := purego.SyscallN(cgEventCreateScrollWheelEventAddr,
|
||||||
|
src, 0, 1, uintptr(uint32(deltaY)))
|
||||||
|
if r1 == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cgEventPost(kCGHIDEventTap, r1)
|
||||||
|
cfRelease(r1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetClipboard sets the macOS clipboard using pbcopy.
|
||||||
|
func (m *MacInputInjector) SetClipboard(text string) {
|
||||||
|
if m.pbcopyPath == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cmd := exec.Command(m.pbcopyPath)
|
||||||
|
cmd.Stdin = strings.NewReader(text)
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
log.Tracef("set clipboard via pbcopy: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClipboard reads the macOS clipboard using pbpaste.
|
||||||
|
func (m *MacInputInjector) GetClipboard() string {
|
||||||
|
if m.pbpastePath == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
out, err := exec.Command(m.pbpastePath).Output()
|
||||||
|
if err != nil {
|
||||||
|
log.Tracef("get clipboard via pbpaste: %v", err)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return string(out)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close releases the idle-sleep assertion held for the injector's lifetime.
|
||||||
|
func (m *MacInputInjector) Close() {
|
||||||
|
releasePreventIdleSleep()
|
||||||
|
}
|
||||||
|
|
||||||
|
func keysymToMacKeycode(keysym uint32) uint16 {
|
||||||
|
if keysym >= 0x61 && keysym <= 0x7a {
|
||||||
|
return asciiToMacKey[keysym-0x61]
|
||||||
|
}
|
||||||
|
if keysym >= 0x41 && keysym <= 0x5a {
|
||||||
|
return asciiToMacKey[keysym-0x41]
|
||||||
|
}
|
||||||
|
if keysym >= 0x30 && keysym <= 0x39 {
|
||||||
|
return digitToMacKey[keysym-0x30]
|
||||||
|
}
|
||||||
|
if code, ok := specialKeyMap[keysym]; ok {
|
||||||
|
return code
|
||||||
|
}
|
||||||
|
return 0xFFFF
|
||||||
|
}
|
||||||
|
|
||||||
|
var asciiToMacKey = [26]uint16{
|
||||||
|
0x00, 0x0B, 0x08, 0x02, 0x0E, 0x03, 0x05, 0x04,
|
||||||
|
0x22, 0x26, 0x28, 0x25, 0x2E, 0x2D, 0x1F, 0x23,
|
||||||
|
0x0C, 0x0F, 0x01, 0x11, 0x20, 0x09, 0x0D, 0x07,
|
||||||
|
0x10, 0x06,
|
||||||
|
}
|
||||||
|
|
||||||
|
var digitToMacKey = [10]uint16{
|
||||||
|
0x1D, 0x12, 0x13, 0x14, 0x15, 0x17, 0x16, 0x1A, 0x1C, 0x19,
|
||||||
|
}
|
||||||
|
|
||||||
|
var specialKeyMap = map[uint32]uint16{
|
||||||
|
// Whitespace and editing
|
||||||
|
0x0020: 0x31, // space
|
||||||
|
0xff08: 0x33, // BackSpace
|
||||||
|
0xff09: 0x30, // Tab
|
||||||
|
0xff0d: 0x24, // Return
|
||||||
|
0xff1b: 0x35, // Escape
|
||||||
|
0xffff: 0x75, // Delete (forward)
|
||||||
|
|
||||||
|
// Navigation
|
||||||
|
0xff50: 0x73, // Home
|
||||||
|
0xff51: 0x7B, // Left
|
||||||
|
0xff52: 0x7E, // Up
|
||||||
|
0xff53: 0x7C, // Right
|
||||||
|
0xff54: 0x7D, // Down
|
||||||
|
0xff55: 0x74, // Page_Up
|
||||||
|
0xff56: 0x79, // Page_Down
|
||||||
|
0xff57: 0x77, // End
|
||||||
|
0xff63: 0x72, // Insert (Help on Mac)
|
||||||
|
|
||||||
|
// Modifiers
|
||||||
|
0xffe1: 0x38, // Shift_L
|
||||||
|
0xffe2: 0x3C, // Shift_R
|
||||||
|
0xffe3: 0x3B, // Control_L
|
||||||
|
0xffe4: 0x3E, // Control_R
|
||||||
|
0xffe5: 0x39, // Caps_Lock
|
||||||
|
0xffe9: 0x3A, // Alt_L (Option)
|
||||||
|
0xffea: 0x3D, // Alt_R (Option)
|
||||||
|
0xffe7: 0x37, // Meta_L (Command)
|
||||||
|
0xffe8: 0x36, // Meta_R (Command)
|
||||||
|
0xffeb: 0x37, // Super_L (Command) - noVNC sends this
|
||||||
|
0xffec: 0x36, // Super_R (Command)
|
||||||
|
|
||||||
|
// Mode_switch / ISO_Level3_Shift (sent by noVNC for macOS Option remap)
|
||||||
|
0xff7e: 0x3A, // Mode_switch -> Option
|
||||||
|
0xfe03: 0x3D, // ISO_Level3_Shift -> Right Option
|
||||||
|
|
||||||
|
// Function keys
|
||||||
|
0xffbe: 0x7A, // F1
|
||||||
|
0xffbf: 0x78, // F2
|
||||||
|
0xffc0: 0x63, // F3
|
||||||
|
0xffc1: 0x76, // F4
|
||||||
|
0xffc2: 0x60, // F5
|
||||||
|
0xffc3: 0x61, // F6
|
||||||
|
0xffc4: 0x62, // F7
|
||||||
|
0xffc5: 0x64, // F8
|
||||||
|
0xffc6: 0x65, // F9
|
||||||
|
0xffc7: 0x6D, // F10
|
||||||
|
0xffc8: 0x67, // F11
|
||||||
|
0xffc9: 0x6F, // F12
|
||||||
|
0xffca: 0x69, // F13
|
||||||
|
0xffcb: 0x6B, // F14
|
||||||
|
0xffcc: 0x71, // F15
|
||||||
|
0xffcd: 0x6A, // F16
|
||||||
|
0xffce: 0x40, // F17
|
||||||
|
0xffcf: 0x4F, // F18
|
||||||
|
0xffd0: 0x50, // F19
|
||||||
|
0xffd1: 0x5A, // F20
|
||||||
|
|
||||||
|
// Punctuation (US keyboard layout, keysym = ASCII code)
|
||||||
|
0x002d: 0x1B, // minus -
|
||||||
|
0x003d: 0x18, // equal =
|
||||||
|
0x005b: 0x21, // bracketleft [
|
||||||
|
0x005d: 0x1E, // bracketright ]
|
||||||
|
0x005c: 0x2A, // backslash
|
||||||
|
0x003b: 0x29, // semicolon ;
|
||||||
|
0x0027: 0x27, // apostrophe '
|
||||||
|
0x0060: 0x32, // grave `
|
||||||
|
0x002c: 0x2B, // comma ,
|
||||||
|
0x002e: 0x2F, // period .
|
||||||
|
0x002f: 0x2C, // slash /
|
||||||
|
|
||||||
|
// Shifted punctuation (noVNC sends these as separate keysyms)
|
||||||
|
0x005f: 0x1B, // underscore _ (shift+minus)
|
||||||
|
0x002b: 0x18, // plus + (shift+equal)
|
||||||
|
0x007b: 0x21, // braceleft { (shift+[)
|
||||||
|
0x007d: 0x1E, // braceright } (shift+])
|
||||||
|
0x007c: 0x2A, // bar | (shift+\)
|
||||||
|
0x003a: 0x29, // colon : (shift+;)
|
||||||
|
0x0022: 0x27, // quotedbl " (shift+')
|
||||||
|
0x007e: 0x32, // tilde ~ (shift+`)
|
||||||
|
0x003c: 0x2B, // less < (shift+,)
|
||||||
|
0x003e: 0x2F, // greater > (shift+.)
|
||||||
|
0x003f: 0x2C, // question ? (shift+/)
|
||||||
|
0x0021: 0x12, // exclam ! (shift+1)
|
||||||
|
0x0040: 0x13, // at @ (shift+2)
|
||||||
|
0x0023: 0x14, // numbersign # (shift+3)
|
||||||
|
0x0024: 0x15, // dollar $ (shift+4)
|
||||||
|
0x0025: 0x17, // percent % (shift+5)
|
||||||
|
0x005e: 0x16, // asciicircum ^ (shift+6)
|
||||||
|
0x0026: 0x1A, // ampersand & (shift+7)
|
||||||
|
0x002a: 0x1C, // asterisk * (shift+8)
|
||||||
|
0x0028: 0x19, // parenleft ( (shift+9)
|
||||||
|
0x0029: 0x1D, // parenright ) (shift+0)
|
||||||
|
|
||||||
|
// Numpad
|
||||||
|
0xffb0: 0x52, // KP_0
|
||||||
|
0xffb1: 0x53, // KP_1
|
||||||
|
0xffb2: 0x54, // KP_2
|
||||||
|
0xffb3: 0x55, // KP_3
|
||||||
|
0xffb4: 0x56, // KP_4
|
||||||
|
0xffb5: 0x57, // KP_5
|
||||||
|
0xffb6: 0x58, // KP_6
|
||||||
|
0xffb7: 0x59, // KP_7
|
||||||
|
0xffb8: 0x5B, // KP_8
|
||||||
|
0xffb9: 0x5C, // KP_9
|
||||||
|
0xffae: 0x41, // KP_Decimal
|
||||||
|
0xffaa: 0x43, // KP_Multiply
|
||||||
|
0xffab: 0x45, // KP_Add
|
||||||
|
0xffad: 0x4E, // KP_Subtract
|
||||||
|
0xffaf: 0x4B, // KP_Divide
|
||||||
|
0xff8d: 0x4C, // KP_Enter
|
||||||
|
0xffbd: 0x51, // KP_Equal
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ InputInjector = (*MacInputInjector)(nil)
|
||||||
398
client/vnc/server/input_windows.go
Normal file
398
client/vnc/server/input_windows.go
Normal file
@@ -0,0 +1,398 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"runtime"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
procOpenEventW = kernel32.NewProc("OpenEventW")
|
||||||
|
procSendInput = user32.NewProc("SendInput")
|
||||||
|
procVkKeyScanA = user32.NewProc("VkKeyScanA")
|
||||||
|
)
|
||||||
|
|
||||||
|
const eventModifyState = 0x0002
|
||||||
|
|
||||||
|
const (
|
||||||
|
inputMouse = 0
|
||||||
|
inputKeyboard = 1
|
||||||
|
|
||||||
|
mouseeventfMove = 0x0001
|
||||||
|
mouseeventfLeftDown = 0x0002
|
||||||
|
mouseeventfLeftUp = 0x0004
|
||||||
|
mouseeventfRightDown = 0x0008
|
||||||
|
mouseeventfRightUp = 0x0010
|
||||||
|
mouseeventfMiddleDown = 0x0020
|
||||||
|
mouseeventfMiddleUp = 0x0040
|
||||||
|
mouseeventfWheel = 0x0800
|
||||||
|
mouseeventfAbsolute = 0x8000
|
||||||
|
|
||||||
|
wheelDelta = 120
|
||||||
|
|
||||||
|
keyeventfKeyUp = 0x0002
|
||||||
|
keyeventfScanCode = 0x0008
|
||||||
|
)
|
||||||
|
|
||||||
|
type mouseInput struct {
|
||||||
|
Dx int32
|
||||||
|
Dy int32
|
||||||
|
MouseData uint32
|
||||||
|
DwFlags uint32
|
||||||
|
Time uint32
|
||||||
|
DwExtraInfo uintptr
|
||||||
|
}
|
||||||
|
|
||||||
|
type keybdInput struct {
|
||||||
|
WVk uint16
|
||||||
|
WScan uint16
|
||||||
|
DwFlags uint32
|
||||||
|
Time uint32
|
||||||
|
DwExtraInfo uintptr
|
||||||
|
_ [8]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type inputUnion [32]byte
|
||||||
|
|
||||||
|
type winInput struct {
|
||||||
|
Type uint32
|
||||||
|
_ [4]byte
|
||||||
|
Data inputUnion
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendMouseInput(flags uint32, dx, dy int32, mouseData uint32) {
|
||||||
|
mi := mouseInput{
|
||||||
|
Dx: dx,
|
||||||
|
Dy: dy,
|
||||||
|
MouseData: mouseData,
|
||||||
|
DwFlags: flags,
|
||||||
|
}
|
||||||
|
inp := winInput{Type: inputMouse}
|
||||||
|
copy(inp.Data[:], (*[unsafe.Sizeof(mi)]byte)(unsafe.Pointer(&mi))[:])
|
||||||
|
r, _, err := procSendInput.Call(1, uintptr(unsafe.Pointer(&inp)), unsafe.Sizeof(inp))
|
||||||
|
if r == 0 {
|
||||||
|
log.Tracef("SendInput(mouse flags=0x%x): %v", flags, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendKeyInput(vk uint16, scanCode uint16, flags uint32) {
|
||||||
|
ki := keybdInput{
|
||||||
|
WVk: vk,
|
||||||
|
WScan: scanCode,
|
||||||
|
DwFlags: flags,
|
||||||
|
}
|
||||||
|
inp := winInput{Type: inputKeyboard}
|
||||||
|
copy(inp.Data[:], (*[unsafe.Sizeof(ki)]byte)(unsafe.Pointer(&ki))[:])
|
||||||
|
r, _, err := procSendInput.Call(1, uintptr(unsafe.Pointer(&inp)), unsafe.Sizeof(inp))
|
||||||
|
if r == 0 {
|
||||||
|
log.Tracef("SendInput(key vk=0x%x): %v", vk, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const sasEventName = `Global\NetBirdVNC_SAS`
|
||||||
|
|
||||||
|
type inputCmd struct {
|
||||||
|
isKey bool
|
||||||
|
keysym uint32
|
||||||
|
down bool
|
||||||
|
buttonMask uint8
|
||||||
|
x, y int
|
||||||
|
serverW int
|
||||||
|
serverH int
|
||||||
|
}
|
||||||
|
|
||||||
|
// WindowsInputInjector delivers input events from a dedicated OS thread that
|
||||||
|
// calls switchToInputDesktop before each injection. SendInput targets the
|
||||||
|
// calling thread's desktop, so the injection thread must be on the same
|
||||||
|
// desktop the user sees.
|
||||||
|
type WindowsInputInjector struct {
|
||||||
|
ch chan inputCmd
|
||||||
|
prevButtonMask uint8
|
||||||
|
ctrlDown bool
|
||||||
|
altDown bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWindowsInputInjector creates a desktop-aware input injector.
|
||||||
|
func NewWindowsInputInjector() *WindowsInputInjector {
|
||||||
|
w := &WindowsInputInjector{ch: make(chan inputCmd, 64)}
|
||||||
|
go w.loop()
|
||||||
|
return w
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *WindowsInputInjector) loop() {
|
||||||
|
runtime.LockOSThread()
|
||||||
|
|
||||||
|
for cmd := range w.ch {
|
||||||
|
// Switch to the current input desktop so SendInput reaches the right target.
|
||||||
|
switchToInputDesktop()
|
||||||
|
|
||||||
|
if cmd.isKey {
|
||||||
|
w.doInjectKey(cmd.keysym, cmd.down)
|
||||||
|
} else {
|
||||||
|
w.doInjectPointer(cmd.buttonMask, cmd.x, cmd.y, cmd.serverW, cmd.serverH)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// InjectKey queues a key event for injection on the input desktop thread.
|
||||||
|
func (w *WindowsInputInjector) InjectKey(keysym uint32, down bool) {
|
||||||
|
w.ch <- inputCmd{isKey: true, keysym: keysym, down: down}
|
||||||
|
}
|
||||||
|
|
||||||
|
// InjectPointer queues a pointer event for injection on the input desktop thread.
|
||||||
|
func (w *WindowsInputInjector) InjectPointer(buttonMask uint8, x, y, serverW, serverH int) {
|
||||||
|
w.ch <- inputCmd{buttonMask: buttonMask, x: x, y: y, serverW: serverW, serverH: serverH}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *WindowsInputInjector) doInjectKey(keysym uint32, down bool) {
|
||||||
|
switch keysym {
|
||||||
|
case 0xffe3, 0xffe4:
|
||||||
|
w.ctrlDown = down
|
||||||
|
case 0xffe9, 0xffea:
|
||||||
|
w.altDown = down
|
||||||
|
}
|
||||||
|
|
||||||
|
if (keysym == 0xff9f || keysym == 0xffff) && w.ctrlDown && w.altDown && down {
|
||||||
|
signalSAS()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
vk, _, extended := keysym2VK(keysym)
|
||||||
|
if vk == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var flags uint32
|
||||||
|
if !down {
|
||||||
|
flags |= keyeventfKeyUp
|
||||||
|
}
|
||||||
|
if extended {
|
||||||
|
flags |= keyeventfScanCode
|
||||||
|
}
|
||||||
|
sendKeyInput(vk, 0, flags)
|
||||||
|
}
|
||||||
|
|
||||||
|
// signalSAS signals the SAS named event. A listener in Session 0
|
||||||
|
// (startSASListener) calls SendSAS to trigger the Secure Attention Sequence.
|
||||||
|
func signalSAS() {
|
||||||
|
namePtr, err := windows.UTF16PtrFromString(sasEventName)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("SAS UTF16: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h, _, lerr := procOpenEventW.Call(
|
||||||
|
uintptr(eventModifyState),
|
||||||
|
0,
|
||||||
|
uintptr(unsafe.Pointer(namePtr)),
|
||||||
|
)
|
||||||
|
if h == 0 {
|
||||||
|
log.Warnf("OpenEvent(%s): %v", sasEventName, lerr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ev := windows.Handle(h)
|
||||||
|
defer windows.CloseHandle(ev)
|
||||||
|
if err := windows.SetEvent(ev); err != nil {
|
||||||
|
log.Warnf("SetEvent SAS: %v", err)
|
||||||
|
} else {
|
||||||
|
log.Info("SAS event signaled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *WindowsInputInjector) doInjectPointer(buttonMask uint8, x, y, serverW, serverH int) {
|
||||||
|
if serverW == 0 || serverH == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
absX := int32(x * 65535 / serverW)
|
||||||
|
absY := int32(y * 65535 / serverH)
|
||||||
|
|
||||||
|
sendMouseInput(mouseeventfMove|mouseeventfAbsolute, absX, absY, 0)
|
||||||
|
|
||||||
|
changed := buttonMask ^ w.prevButtonMask
|
||||||
|
w.prevButtonMask = buttonMask
|
||||||
|
|
||||||
|
type btnMap struct {
|
||||||
|
bit uint8
|
||||||
|
down uint32
|
||||||
|
up uint32
|
||||||
|
}
|
||||||
|
buttons := [...]btnMap{
|
||||||
|
{0x01, mouseeventfLeftDown, mouseeventfLeftUp},
|
||||||
|
{0x02, mouseeventfMiddleDown, mouseeventfMiddleUp},
|
||||||
|
{0x04, mouseeventfRightDown, mouseeventfRightUp},
|
||||||
|
}
|
||||||
|
for _, b := range buttons {
|
||||||
|
if changed&b.bit == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var flags uint32
|
||||||
|
if buttonMask&b.bit != 0 {
|
||||||
|
flags = b.down
|
||||||
|
} else {
|
||||||
|
flags = b.up
|
||||||
|
}
|
||||||
|
sendMouseInput(flags|mouseeventfAbsolute, absX, absY, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
negWheelDelta := ^uint32(wheelDelta - 1)
|
||||||
|
if changed&0x08 != 0 && buttonMask&0x08 != 0 {
|
||||||
|
sendMouseInput(mouseeventfWheel|mouseeventfAbsolute, absX, absY, wheelDelta)
|
||||||
|
}
|
||||||
|
if changed&0x10 != 0 && buttonMask&0x10 != 0 {
|
||||||
|
sendMouseInput(mouseeventfWheel|mouseeventfAbsolute, absX, absY, negWheelDelta)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// keysym2VK converts an X11 keysym to a Windows virtual key code.
|
||||||
|
func keysym2VK(keysym uint32) (vk uint16, scan uint16, extended bool) {
|
||||||
|
if keysym >= 0x20 && keysym <= 0x7e {
|
||||||
|
r, _, _ := procVkKeyScanA.Call(uintptr(keysym))
|
||||||
|
vk = uint16(r & 0xff)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if keysym >= 0xffbe && keysym <= 0xffc9 {
|
||||||
|
vk = uint16(0x70 + keysym - 0xffbe)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch keysym {
|
||||||
|
case 0xff08:
|
||||||
|
vk = 0x08 // Backspace
|
||||||
|
case 0xff09:
|
||||||
|
vk = 0x09 // Tab
|
||||||
|
case 0xff0d:
|
||||||
|
vk = 0x0d // Return
|
||||||
|
case 0xff1b:
|
||||||
|
vk = 0x1b // Escape
|
||||||
|
case 0xff63:
|
||||||
|
vk, extended = 0x2d, true // Insert
|
||||||
|
case 0xff9f, 0xffff:
|
||||||
|
vk, extended = 0x2e, true // Delete
|
||||||
|
case 0xff50:
|
||||||
|
vk, extended = 0x24, true // Home
|
||||||
|
case 0xff57:
|
||||||
|
vk, extended = 0x23, true // End
|
||||||
|
case 0xff55:
|
||||||
|
vk, extended = 0x21, true // PageUp
|
||||||
|
case 0xff56:
|
||||||
|
vk, extended = 0x22, true // PageDown
|
||||||
|
case 0xff51:
|
||||||
|
vk, extended = 0x25, true // Left
|
||||||
|
case 0xff52:
|
||||||
|
vk, extended = 0x26, true // Up
|
||||||
|
case 0xff53:
|
||||||
|
vk, extended = 0x27, true // Right
|
||||||
|
case 0xff54:
|
||||||
|
vk, extended = 0x28, true // Down
|
||||||
|
case 0xffe1, 0xffe2:
|
||||||
|
vk = 0x10 // Shift
|
||||||
|
case 0xffe3, 0xffe4:
|
||||||
|
vk = 0x11 // Control
|
||||||
|
case 0xffe9, 0xffea:
|
||||||
|
vk = 0x12 // Alt
|
||||||
|
case 0xffe5:
|
||||||
|
vk = 0x14 // CapsLock
|
||||||
|
case 0xffe7, 0xffeb:
|
||||||
|
vk, extended = 0x5B, true // Meta_L / Super_L -> Left Windows
|
||||||
|
case 0xffe8, 0xffec:
|
||||||
|
vk, extended = 0x5C, true // Meta_R / Super_R -> Right Windows
|
||||||
|
case 0xff61:
|
||||||
|
vk = 0x2c // PrintScreen
|
||||||
|
case 0xff13:
|
||||||
|
vk = 0x13 // Pause
|
||||||
|
case 0xff14:
|
||||||
|
vk = 0x91 // ScrollLock
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
procOpenClipboard = user32.NewProc("OpenClipboard")
|
||||||
|
procCloseClipboard = user32.NewProc("CloseClipboard")
|
||||||
|
procEmptyClipboard = user32.NewProc("EmptyClipboard")
|
||||||
|
procSetClipboardData = user32.NewProc("SetClipboardData")
|
||||||
|
procGetClipboardData = user32.NewProc("GetClipboardData")
|
||||||
|
procIsClipboardFormatAvailable = user32.NewProc("IsClipboardFormatAvailable")
|
||||||
|
|
||||||
|
procGlobalAlloc = kernel32.NewProc("GlobalAlloc")
|
||||||
|
procGlobalLock = kernel32.NewProc("GlobalLock")
|
||||||
|
procGlobalUnlock = kernel32.NewProc("GlobalUnlock")
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
cfUnicodeText = 13
|
||||||
|
gmemMoveable = 0x0002
|
||||||
|
)
|
||||||
|
|
||||||
|
// SetClipboard sets the Windows clipboard to the given UTF-8 text.
|
||||||
|
func (w *WindowsInputInjector) SetClipboard(text string) {
|
||||||
|
utf16, err := windows.UTF16FromString(text)
|
||||||
|
if err != nil {
|
||||||
|
log.Tracef("clipboard UTF16 encode: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
size := uintptr(len(utf16) * 2)
|
||||||
|
hMem, _, _ := procGlobalAlloc.Call(gmemMoveable, size)
|
||||||
|
if hMem == 0 {
|
||||||
|
log.Tracef("GlobalAlloc for clipboard: allocation returned nil")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ptr, _, _ := procGlobalLock.Call(hMem)
|
||||||
|
if ptr == 0 {
|
||||||
|
log.Tracef("GlobalLock for clipboard: lock returned nil")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
copy(unsafe.Slice((*uint16)(unsafe.Pointer(ptr)), len(utf16)), utf16)
|
||||||
|
procGlobalUnlock.Call(hMem)
|
||||||
|
|
||||||
|
r, _, lerr := procOpenClipboard.Call(0)
|
||||||
|
if r == 0 {
|
||||||
|
log.Tracef("OpenClipboard: %v", lerr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer procCloseClipboard.Call()
|
||||||
|
|
||||||
|
procEmptyClipboard.Call()
|
||||||
|
r, _, lerr = procSetClipboardData.Call(cfUnicodeText, hMem)
|
||||||
|
if r == 0 {
|
||||||
|
log.Tracef("SetClipboardData: %v", lerr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClipboard reads the Windows clipboard as UTF-8 text.
|
||||||
|
func (w *WindowsInputInjector) GetClipboard() string {
|
||||||
|
r, _, _ := procIsClipboardFormatAvailable.Call(cfUnicodeText)
|
||||||
|
if r == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
r, _, lerr := procOpenClipboard.Call(0)
|
||||||
|
if r == 0 {
|
||||||
|
log.Tracef("OpenClipboard for read: %v", lerr)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
defer procCloseClipboard.Call()
|
||||||
|
|
||||||
|
hData, _, _ := procGetClipboardData.Call(cfUnicodeText)
|
||||||
|
if hData == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
ptr, _, _ := procGlobalLock.Call(hData)
|
||||||
|
if ptr == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
defer procGlobalUnlock.Call(hData)
|
||||||
|
|
||||||
|
return windows.UTF16PtrToString((*uint16)(unsafe.Pointer(ptr)))
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ InputInjector = (*WindowsInputInjector)(nil)
|
||||||
|
|
||||||
|
var _ ScreenCapturer = (*DesktopCapturer)(nil)
|
||||||
242
client/vnc/server/input_x11.go
Normal file
242
client/vnc/server/input_x11.go
Normal file
@@ -0,0 +1,242 @@
|
|||||||
|
//go:build (linux && !android) || freebsd
|
||||||
|
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/jezek/xgb"
|
||||||
|
"github.com/jezek/xgb/xproto"
|
||||||
|
"github.com/jezek/xgb/xtest"
|
||||||
|
)
|
||||||
|
|
||||||
|
// X11InputInjector injects keyboard and mouse events via the XTest extension.
|
||||||
|
type X11InputInjector struct {
|
||||||
|
conn *xgb.Conn
|
||||||
|
root xproto.Window
|
||||||
|
screen *xproto.ScreenInfo
|
||||||
|
display string
|
||||||
|
keysymMap map[uint32]byte
|
||||||
|
lastButtons uint8
|
||||||
|
clipboardTool string
|
||||||
|
clipboardToolName string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewX11InputInjector connects to the X11 display and initializes XTest.
|
||||||
|
func NewX11InputInjector(display string) (*X11InputInjector, error) {
|
||||||
|
detectX11Display()
|
||||||
|
|
||||||
|
if display == "" {
|
||||||
|
display = os.Getenv("DISPLAY")
|
||||||
|
}
|
||||||
|
if display == "" {
|
||||||
|
return nil, fmt.Errorf("DISPLAY not set and no Xorg process found")
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := xgb.NewConnDisplay(display)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("connect to X11 display %s: %w", display, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := xtest.Init(conn); err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, fmt.Errorf("init XTest extension: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
setup := xproto.Setup(conn)
|
||||||
|
if len(setup.Roots) == 0 {
|
||||||
|
conn.Close()
|
||||||
|
return nil, fmt.Errorf("no X11 screens")
|
||||||
|
}
|
||||||
|
screen := setup.Roots[0]
|
||||||
|
|
||||||
|
inj := &X11InputInjector{
|
||||||
|
conn: conn,
|
||||||
|
root: screen.Root,
|
||||||
|
screen: &screen,
|
||||||
|
display: display,
|
||||||
|
}
|
||||||
|
inj.cacheKeyboardMapping()
|
||||||
|
inj.resolveClipboardTool()
|
||||||
|
|
||||||
|
log.Infof("X11 input injector ready (display=%s)", display)
|
||||||
|
return inj, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// InjectKey simulates a key press or release. keysym is an X11 KeySym.
|
||||||
|
func (x *X11InputInjector) InjectKey(keysym uint32, down bool) {
|
||||||
|
keycode := x.keysymToKeycode(keysym)
|
||||||
|
if keycode == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var eventType byte
|
||||||
|
if down {
|
||||||
|
eventType = xproto.KeyPress
|
||||||
|
} else {
|
||||||
|
eventType = xproto.KeyRelease
|
||||||
|
}
|
||||||
|
|
||||||
|
xtest.FakeInput(x.conn, eventType, keycode, 0, x.root, 0, 0, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// InjectPointer simulates mouse movement and button events.
|
||||||
|
func (x *X11InputInjector) InjectPointer(buttonMask uint8, px, py, serverW, serverH int) {
|
||||||
|
if serverW == 0 || serverH == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scale to actual screen coordinates.
|
||||||
|
screenW := int(x.screen.WidthInPixels)
|
||||||
|
screenH := int(x.screen.HeightInPixels)
|
||||||
|
absX := px * screenW / serverW
|
||||||
|
absY := py * screenH / serverH
|
||||||
|
|
||||||
|
// Move pointer.
|
||||||
|
xtest.FakeInput(x.conn, xproto.MotionNotify, 0, 0, x.root, int16(absX), int16(absY), 0)
|
||||||
|
|
||||||
|
// Handle button events. RFB button mask: bit0=left, bit1=middle, bit2=right,
|
||||||
|
// bit3=scrollUp, bit4=scrollDown. X11 buttons: 1=left, 2=middle, 3=right,
|
||||||
|
// 4=scrollUp, 5=scrollDown.
|
||||||
|
type btnMap struct {
|
||||||
|
rfbBit uint8
|
||||||
|
x11Btn byte
|
||||||
|
}
|
||||||
|
buttons := [...]btnMap{
|
||||||
|
{0x01, 1}, // left
|
||||||
|
{0x02, 2}, // middle
|
||||||
|
{0x04, 3}, // right
|
||||||
|
{0x08, 4}, // scroll up
|
||||||
|
{0x10, 5}, // scroll down
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, b := range buttons {
|
||||||
|
pressed := buttonMask&b.rfbBit != 0
|
||||||
|
wasPressed := x.lastButtons&b.rfbBit != 0
|
||||||
|
if b.x11Btn >= 4 {
|
||||||
|
// Scroll: send press+release on each scroll event.
|
||||||
|
if pressed {
|
||||||
|
xtest.FakeInput(x.conn, xproto.ButtonPress, b.x11Btn, 0, x.root, 0, 0, 0)
|
||||||
|
xtest.FakeInput(x.conn, xproto.ButtonRelease, b.x11Btn, 0, x.root, 0, 0, 0)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if pressed && !wasPressed {
|
||||||
|
xtest.FakeInput(x.conn, xproto.ButtonPress, b.x11Btn, 0, x.root, 0, 0, 0)
|
||||||
|
} else if !pressed && wasPressed {
|
||||||
|
xtest.FakeInput(x.conn, xproto.ButtonRelease, b.x11Btn, 0, x.root, 0, 0, 0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
x.lastButtons = buttonMask
|
||||||
|
}
|
||||||
|
|
||||||
|
// cacheKeyboardMapping fetches the X11 keyboard mapping once and stores it
|
||||||
|
// as a keysym-to-keycode map, avoiding a round-trip per keystroke.
|
||||||
|
func (x *X11InputInjector) cacheKeyboardMapping() {
|
||||||
|
setup := xproto.Setup(x.conn)
|
||||||
|
minKeycode := setup.MinKeycode
|
||||||
|
maxKeycode := setup.MaxKeycode
|
||||||
|
|
||||||
|
reply, err := xproto.GetKeyboardMapping(x.conn, minKeycode,
|
||||||
|
byte(maxKeycode-minKeycode+1)).Reply()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("cache keyboard mapping: %v", err)
|
||||||
|
x.keysymMap = make(map[uint32]byte)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
m := make(map[uint32]byte, int(maxKeycode-minKeycode+1)*int(reply.KeysymsPerKeycode))
|
||||||
|
keysymsPerKeycode := int(reply.KeysymsPerKeycode)
|
||||||
|
for i := int(minKeycode); i <= int(maxKeycode); i++ {
|
||||||
|
offset := (i - int(minKeycode)) * keysymsPerKeycode
|
||||||
|
for j := 0; j < keysymsPerKeycode; j++ {
|
||||||
|
ks := uint32(reply.Keysyms[offset+j])
|
||||||
|
if ks != 0 {
|
||||||
|
if _, exists := m[ks]; !exists {
|
||||||
|
m[ks] = byte(i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
x.keysymMap = m
|
||||||
|
}
|
||||||
|
|
||||||
|
// keysymToKeycode looks up a cached keysym-to-keycode mapping.
|
||||||
|
// Returns 0 if the keysym is not mapped.
|
||||||
|
func (x *X11InputInjector) keysymToKeycode(keysym uint32) byte {
|
||||||
|
return x.keysymMap[keysym]
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetClipboard sets the X11 clipboard using xclip or xsel.
|
||||||
|
func (x *X11InputInjector) SetClipboard(text string) {
|
||||||
|
if x.clipboardTool == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var cmd *exec.Cmd
|
||||||
|
if x.clipboardToolName == "xclip" {
|
||||||
|
cmd = exec.Command(x.clipboardTool, "-selection", "clipboard")
|
||||||
|
} else {
|
||||||
|
cmd = exec.Command(x.clipboardTool, "--clipboard", "--input")
|
||||||
|
}
|
||||||
|
cmd.Env = x.clipboardEnv()
|
||||||
|
cmd.Stdin = strings.NewReader(text)
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
log.Debugf("set clipboard via %s: %v", x.clipboardToolName, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *X11InputInjector) resolveClipboardTool() {
|
||||||
|
for _, name := range []string{"xclip", "xsel"} {
|
||||||
|
path, err := exec.LookPath(name)
|
||||||
|
if err == nil {
|
||||||
|
x.clipboardTool = path
|
||||||
|
x.clipboardToolName = name
|
||||||
|
log.Debugf("clipboard tool resolved to %s", path)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.Debugf("no clipboard tool (xclip/xsel) found, clipboard sync disabled")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClipboard reads the X11 clipboard using xclip or xsel.
|
||||||
|
func (x *X11InputInjector) GetClipboard() string {
|
||||||
|
if x.clipboardTool == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var cmd *exec.Cmd
|
||||||
|
if x.clipboardToolName == "xclip" {
|
||||||
|
cmd = exec.Command(x.clipboardTool, "-selection", "clipboard", "-o")
|
||||||
|
} else {
|
||||||
|
cmd = exec.Command(x.clipboardTool, "--clipboard", "--output")
|
||||||
|
}
|
||||||
|
cmd.Env = x.clipboardEnv()
|
||||||
|
out, err := cmd.Output()
|
||||||
|
if err != nil {
|
||||||
|
log.Tracef("get clipboard via %s: %v", x.clipboardToolName, err)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return string(out)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *X11InputInjector) clipboardEnv() []string {
|
||||||
|
env := []string{"DISPLAY=" + x.display}
|
||||||
|
if auth := os.Getenv("XAUTHORITY"); auth != "" {
|
||||||
|
env = append(env, "XAUTHORITY="+auth)
|
||||||
|
}
|
||||||
|
return env
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close releases X11 resources.
|
||||||
|
func (x *X11InputInjector) Close() {
|
||||||
|
x.conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ InputInjector = (*X11InputInjector)(nil)
|
||||||
|
var _ ScreenCapturer = (*X11Poller)(nil)
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user