Compare commits

..

11 Commits

Author SHA1 Message Date
Viktor Liu
f49ec5db2e Recover from tun device read/write panics and restart the client 2026-06-12 15:41:54 +02:00
Riccardo Manfrin
2bcea9d582 [client] add MDM configuration profile support (Windows registry + macOS plist) (#6374)
* Initial scaffolding

* Applies MDM override

* Unit tests

* Helpers business logic

* Return error if trying to modify any config that is gated by MDM

* Add ManagedFields to returned config over GetConfig

* Adds initial 101 MDM policy business logic testing

* gRPC MDM changes

* MDM Name scoping for clarity

* Implements windows loading of MDM policy

* Adds missing WGPort config

* Cleanup setupKey to align to linear

* Align split tunnel code

* Adds some log

* Prefix every log with MDM

* Adds debug config cobra command

This can be useful for troubleshooting and checking config
now that its resolution is not trivial

defaults > config > env vars > CLI/UI > MDM

* Adds MDM 1m diff checker & reloader

* Adds also up/start after cancel

* Publishes event for UI to sync upon MDM changes

* Add events to resync UI to actual config

This also provides a fixup for the UI not aligning to a changed config when coming from cli up with config flags.

* UI behavior conflicts relaxation

UI sends full config snapshot with all values. It doesn't
make sense to block it if the values are aligned with the
values constrained by the MDM policy. It's just simpler
to allow values that are compliant. (this goes for the CLI
as well at this point)

* Lock toggle Settings

* Advanced Settings locking

* Fixup presharedkey

* Apply MDM locks

* Toggle gray in/out for Advanced Settings

* Adds support for disabling of Profiles and UpdateSettings feature flags

* Adds Gate Login as well when --disable-update-settings=true is given to service

This commit tries to settle things with an old PR-4237 which had relaxed
the case where the SetConfig returned an `Unavailable` code error.

Under this circumstance the PR allowed the upFunc to just emit a warning and
progress further with the login gRPC. Since the login call is consuming
the --management-url coming from the `up` command, it might be possible
to abuse the "Unavailable" code to inject a management URL that is different
from the configured one even though the --disable-update-settings is set
to true (?)

* Evaluate disable-update-settings errors only when there's an actual override

* [UI] Fixup advanced Settings

* [UI] Fixup for preshared key

* [UI] Fixup for profile enable/disable toggle

We need to align the initial state to evaluate the delta in case.

The initial state has to be "true" since the profile starts visible.
Then we receive MDM and transition the cache bool value to the actual
MDM imposed state

* Enforces disable networks

* [UI] Aligns to "enable/disable once on change only"

* Fixup: MDM wins. always

* Removes --disable-advanced-settings

It was a typo in our meetings. The actual flag is --disable-update-settings

* [PROTO] Removes --disable-advanced-settings

* [UI] Removes --disable-advanced-settings

* Pins feat profile retrieval to notif event

* [UI] Fix for "hide" not working when propagating to parent with children

* Adds dep for reading plist files

* Introduces support for darwin plist loading

* Tests MDM config reload via ticker

* [PROVISIONING] ADMX/ADML/PS/bash scripts/templates

* CI fixes

- Add docstrings to `mdm_integration`
- refactor for cognitive complexity
- mod tidy

* Linting

* Add docstrings to `mdm_integration`

* nil,nil is no policy and no error. Allow it

* exclude MDM profile administered keys data from debug bundle

* Fixes Rosenpass left disabled after MDM unlock

* Partial revert coderabbit added docstrings

* Renaming fix

* Avoid locking on clientRunning bool when the connection is aborted for whatever reason

We want to just signal this through the giveUpChan; we will manage the signal from
the waiter side and, if needed, set it to false there. This way we avoid locking,
which should allow the MDM down+wait_for_term_chan_signal_+up procedure

clientRunning is used to signal two different conditions here:

1. the initialization procedure is over (we have an engine)
2. the connection being up (or being attempted)

Probably these two functionalities should not alias, and the failure of the second condition
(because of any error) should just drive a reconnection (currently it's not happening,
and we silently go idle).
OR, more probably, the two things are the SAME and there should not exist a case where
we did the "Up" initialization and connection attempt but we are not still attempting it.

* Moves test helper to the very bottom

* Addresses github comments

* No lock no copy

* Prevents engine not stopping within 10 secs from being paired by another instance

We instead just SKIP updating the policy, so
1. the MDM ticker will kick in after one minute,
2. find the policy misaligned,
3. enter the onMDMPolicyChange,
4. find the s.clientRunning == true
   (because it is set to false only in server cleanupConnection,
   and not by s.actCancel())
5. call s.actCancel() again if not nil
6. immediately return from <-s.clientGiveUpChan
7. finally call s.restartEngineForMDMLocked()

* Since we ARE running there should be a config

If the config was cancelled midflight, connect will abort later on

* DisableAutoConnect should not stop a running connection.

DisableAutoConnect should just avoid the connection attempts *when the service starts*.
If we are started and we are up and running, DisableAutoConnect should not kick in.

Another PR will follow about this topic

* Removes unused vars

* Moves callback into Run method arg

* align comment to removal of DisableAutoConnect

DisableAutoConnect should just avoid the connection attempts *when the service starts*.
If we are started and we are up and running, DisableAutoConnect should not kick in

* Removes unused managed_fields data.

This was initially used to drive the UI but approach changed
to reload config/features upon notifications which makes this data redundant.

* Reorder stuff

* Unexport unrequired vars/functions

PoliciesEqual → policiesEqual
AllKeys → allKeys

* Adds list of MDM managed fields in the debug bundle
2026-06-12 12:28:49 +02:00
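
The resolution order noted in the commit message above (defaults, then config file, then environment variables, then CLI/UI input, then MDM) can be pictured as a last-writer-wins merge over ordered layers. The following is a minimal sketch under that assumption only; `layer`, `effective`, `resolve`, and `ptr` are hypothetical names for illustration and are not taken from the NetBird code.

```go
package main

import "fmt"

// layer is a hypothetical config source. A nil field means "not set by this source".
type layer struct {
	name          string
	managementURL *string
	wgPort        *int
}

// effective is the merged result plus the name of the layer that supplied each field.
type effective struct {
	ManagementURL string
	WGPort        int
	Source        map[string]string
}

// resolve applies layers in order, so later layers override earlier ones.
// Passing MDM last makes it win over every other source.
func resolve(layers ...layer) effective {
	out := effective{Source: map[string]string{}}
	for _, l := range layers {
		if l.managementURL != nil {
			out.ManagementURL = *l.managementURL
			out.Source["ManagementURL"] = l.name
		}
		if l.wgPort != nil {
			out.WGPort = *l.wgPort
			out.Source["WGPort"] = l.name
		}
	}
	return out
}

// ptr returns a pointer to v, which keeps the literals below short.
func ptr[T any](v T) *T { return &v }

func main() {
	cfg := resolve(
		layer{name: "defaults", managementURL: ptr("https://default.example"), wgPort: ptr(51820)},
		layer{name: "config", wgPort: ptr(51821)},
		layer{name: "env"},
		layer{name: "cli", managementURL: ptr("https://cli.example")},
		layer{name: "mdm", managementURL: ptr("https://mdm.example")},
	)
	fmt.Println(cfg.ManagementURL, cfg.WGPort, cfg.Source)
}
```

Recording the winning layer per field is what makes a dump such as `netbird debug config` useful for troubleshooting: the value alone does not tell the operator which source produced it.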
Maycon Santos
8ff3b06cf1 [client] Index peer tunnel IPs for faster PeerStateByIP lookup (#6412)
* [client] Index peer tunnel IPs for O(1) PeerStateByIP lookup

Replace the linear scan over all peers with an ipToKey map maintained
by AddPeer/RemovePeer, covering both IPv4 and IPv6 tunnel addresses.

Offline peers are intentionally no longer resolvable by IP: only active
peers can carry traffic, so IdentityForIP and the DNS disconnected-peer
filter now treat them as unknown, same as foreign IPs.

Skip the DNS answer filter for single-record responses; dropping the
only answer was always restored by the empty-answer escape hatch, so
the fast path is behavior-neutral.

* Ensure `ipToKey` entries are only removed if they match the peer being deleted, preventing accidental removal of unrelated mappings.
2026-06-12 10:24:15 +02:00
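
As an illustration of the indexing idea in the commit above, here is a minimal sketch of an IP-to-key map kept in sync by add and remove operations, including the guard that only deletes a mapping when it still points at the peer being removed. The type `peerIndex` and its methods are hypothetical stand-ins; the real status recorder is not shown on this page and may differ.

```go
package main

import (
	"fmt"
	"net/netip"
	"sync"
)

// peerIndex is a hypothetical, simplified stand-in for the peer status store.
type peerIndex struct {
	mu      sync.RWMutex
	peers   map[string][]netip.Addr // public key -> tunnel addresses (IPv4 and IPv6)
	ipToKey map[netip.Addr]string   // tunnel address -> public key
}

func newPeerIndex() *peerIndex {
	return &peerIndex{
		peers:   map[string][]netip.Addr{},
		ipToKey: map[netip.Addr]string{},
	}
}

// AddPeer registers a peer and indexes all of its tunnel addresses.
func (p *peerIndex) AddPeer(key string, addrs ...netip.Addr) {
	p.mu.Lock()
	defer p.mu.Unlock()
	p.removeLocked(key) // drop stale mappings if the peer is re-added
	p.peers[key] = addrs
	for _, a := range addrs {
		p.ipToKey[a] = key
	}
}

// RemovePeer deletes a peer and its index entries.
func (p *peerIndex) RemovePeer(key string) {
	p.mu.Lock()
	defer p.mu.Unlock()
	p.removeLocked(key)
}

// removeLocked only deletes an index entry if it still belongs to key,
// so an address that was reassigned to another peer is left alone.
func (p *peerIndex) removeLocked(key string) {
	for _, a := range p.peers[key] {
		if p.ipToKey[a] == key {
			delete(p.ipToKey, a)
		}
	}
	delete(p.peers, key)
}

// KeyByIP is a constant-time lookup replacing a scan over all peers.
func (p *peerIndex) KeyByIP(a netip.Addr) (string, bool) {
	p.mu.RLock()
	defer p.mu.RUnlock()
	key, ok := p.ipToKey[a]
	return key, ok
}

// mustAddr is a small helper for the demo below.
func mustAddr(s string) netip.Addr { return netip.MustParseAddr(s) }

func main() {
	idx := newPeerIndex()
	ip := mustAddr("100.64.0.1")
	idx.AddPeer("peerA", ip)
	idx.AddPeer("peerB", ip) // address reassigned to peerB
	idx.RemovePeer("peerA")  // must not remove peerB's mapping
	fmt.Println(idx.KeyByIP(ip))
}
```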
Maycon Santos
d7703767d5 [client, proxy] cancel context before stopping engine on embedded client (#6397)
- Engine.Start takes syncMsgMux with a deferred unlock (engine.go:445) and parks in receiveSignalEvents → WaitStreamConnected (engine.go:1762), which only wakes on
  signal-stream connect or client-context cancellation.
  - When signal never connects, the 30s startup timeout fires and embed.Client.Start's rollback (embed.go:281) called client.Stop() → Engine.Stop, which blocks acquiring
  syncMsgMux (engine.go:318). The cancel() that would unpark Start was deferred until Start returned — permanent cycle. RemovePeer calls (g43/g385) then queue behind the
  lifecycle mutex.
  - Notably, embed.Client.Stop and the daemon's cleanupConnection both cancel before stopping — the startup rollback was the only path that didn't.
2026-06-10 21:26:54 +02:00
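
The ordering problem described in the commit above can be reproduced in a few lines. The sketch below is a simplified model, not the NetBird engine: `engine`, `rollback`, `startWithTimeout`, and `runDemo` are hypothetical names, and the mutex stands in for syncMsgMux. `Start` holds the mutex while waiting on the context, so the rollback must cancel first and only then call `Stop`; with the two calls swapped, `Stop` would block forever on the mutex.

```go
package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// engine is a hypothetical model: Start holds mu for its whole duration.
type engine struct {
	mu sync.Mutex
}

// Start parks until the context is cancelled, modelling a signal
// stream that never connects.
func (e *engine) Start(ctx context.Context) error {
	e.mu.Lock()
	defer e.mu.Unlock()
	<-ctx.Done()
	return ctx.Err()
}

// Stop needs the same mutex, so it cannot proceed while Start is parked.
func (e *engine) Stop() {
	e.mu.Lock()
	defer e.mu.Unlock()
}

// rollback cancels BEFORE stopping. Cancelling unparks Start, which
// releases the mutex and lets Stop acquire it.
func rollback(e *engine, cancel context.CancelFunc) {
	cancel()
	e.Stop()
}

// startWithTimeout runs Start and rolls back if it does not return in time.
func startWithTimeout(timeout time.Duration) error {
	ctx, cancel := context.WithCancel(context.Background())
	e := &engine{}
	done := make(chan error, 1)
	go func() { done <- e.Start(ctx) }()

	select {
	case err := <-done:
		cancel()
		return err
	case <-time.After(timeout):
		rollback(e, cancel)
		return fmt.Errorf("startup timed out after %s: %w", timeout, <-done)
	}
}

// runDemo exercises the timeout path with a short timeout.
func runDemo() error {
	return startWithTimeout(50 * time.Millisecond)
}

func main() {
	fmt.Println(runDemo())
}
```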
Maycon Santos
7feda907ca [management] fix L4 service update when no custom port (#6396)
This fixes an issue where L4 service update is not possible when proxy clusters don't support custom ports
2026-06-10 18:55:24 +02:00
Maycon Santos
62da482133 [management] Add version gate to stop sending deprecated RemotePeers field (#6371)
* [management] Add version gate to stop sending deprecated RemotePeers field

Don't send top-level remote peers to peers running v0.29.3 or newer

* precompute deprecated remote peers version constraint

* [management] update tests to validate network map-based remote peers

* [management] move deprecatedRemotePeersVersion constant closer to its usage

* fix misplaced precomputed constraint definition

* ensure top-level RemotePeers is empty for v0.29.3+ clients
2026-06-10 16:59:09 +02:00
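
A version gate of the kind described above could look roughly like the following. This is a sketch under stated assumptions: it uses only the standard library and a plain major.minor.patch comparison, whereas the commit mentions a precomputed version constraint whose implementation is not shown here. The names `gateVersion`, `parseVersion`, and `sendDeprecatedRemotePeers` are hypothetical, and treating unparseable versions as legacy clients is an assumption of this sketch.

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// gateVersion is the first client version that no longer receives the
// deprecated top-level RemotePeers field (v0.29.3, per the commit message).
// It is computed once instead of being parsed on every sync.
var gateVersion = [3]int{0, 29, 3}

// parseVersion parses "major.minor.patch" with an optional "v" prefix.
func parseVersion(v string) ([3]int, bool) {
	var out [3]int
	parts := strings.Split(strings.TrimPrefix(v, "v"), ".")
	if len(parts) != 3 {
		return out, false
	}
	for i, p := range parts {
		n, err := strconv.Atoi(p)
		if err != nil || n < 0 {
			return out, false
		}
		out[i] = n
	}
	return out, true
}

// sendDeprecatedRemotePeers reports whether a client still needs the field.
func sendDeprecatedRemotePeers(clientVersion string) bool {
	v, ok := parseVersion(clientVersion)
	if !ok {
		// Assumption of this sketch: unknown versions are treated as legacy.
		return true
	}
	for i := range v {
		if v[i] != gateVersion[i] {
			return v[i] < gateVersion[i]
		}
	}
	return false // exactly the gate version: field is no longer sent
}

func main() {
	for _, v := range []string{"0.29.2", "0.29.3", "v0.30.0", "development"} {
		fmt.Printf("%s -> send deprecated field: %v\n", v, sendDeprecatedRemotePeers(v))
	}
}
```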
Philip Laine
079bce3c2f Add commands to discover and write Kubernetes configuration (#6260) 2026-06-10 15:00:10 +02:00
Maycon Santos
1a09aa6715 [misc] Update Go toolchain version in go.mod (#6377) 2026-06-10 14:50:57 +02:00
Maycon Santos
61abf5b9ea [proxy] Use UUID for proxy ID generation (#6391)
Use UUID for proxy ID instead of the second to avoid race conditions when running multiple nodes at the same time.
2026-06-10 13:35:26 +02:00
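
To illustrate why a random UUID avoids the collision, the sketch below generates a version 4 UUID from `crypto/rand` using only the standard library. It assumes that "the second" in the commit message refers to an ID derived from a second-resolution timestamp; the actual proxy code and the UUID library it uses are not shown on this page, and `newProxyID` and `timestampID` are hypothetical names.

```go
package main

import (
	"crypto/rand"
	"fmt"
	"time"
)

// timestampID models the assumed old scheme: two nodes started within
// the same second end up with the same ID.
func timestampID(t time.Time) string {
	return fmt.Sprintf("proxy-%d", t.Unix())
}

// newProxyID returns a random (version 4) UUID string.
func newProxyID() (string, error) {
	var b [16]byte
	if _, err := rand.Read(b[:]); err != nil {
		return "", fmt.Errorf("read random bytes: %w", err)
	}
	b[6] = (b[6] & 0x0f) | 0x40 // version 4
	b[8] = (b[8] & 0x3f) | 0x80 // RFC 4122 variant
	return fmt.Sprintf("%x-%x-%x-%x-%x", b[0:4], b[4:6], b[6:8], b[8:10], b[10:16]), nil
}

func main() {
	now := time.Now()
	fmt.Println("timestamp IDs collide:", timestampID(now) == timestampID(now))

	a, err := newProxyID()
	if err != nil {
		panic(err)
	}
	b, err := newProxyID()
	if err != nil {
		panic(err)
	}
	fmt.Println("UUIDs collide:", a == b)
}
```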
Boris Dolgov
e229050ba3 [proxy] Notify certificate ready for domains covered by the static certificate (#6389) 2026-06-10 12:05:34 +02:00
Zoltan Papp
e919b2d55d [client] Preserve posture checks on config-only sync updates (#6373)
* [client] Preserve posture checks on config-only sync updates

When management sends a MessageTypeControlConfig update (e.g. relay token
rotation), the SyncResponse carries no NetworkMap and no Checks. Moving the
updateChecksIfNew call after the nm == nil guard ensures posture checks are
only updated when a full network map is present, preventing relay token
rotation from silently clearing the previously applied posture check state.

* [client] Clarify posture check update logic with explicit comment

* [client] Extract NetBird config and sync persistence into helpers

Move the NetbirdConfig handling block out of handleSync into
updateNetbirdConfig and the sync response persistence into
persistSyncResponse, mirroring updateChecksIfNew. This flattens
handleSync and makes the individual update steps unit-testable.
2026-06-10 11:43:24 +02:00
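
The effect of moving the posture check update behind the nil guard can be shown with a reduced model. In the sketch below, `syncResponse`, `networkMap`, and `engine` are hypothetical simplifications, and `updateChecksIfNew` only shares its name with the real helper; the point is the ordering inside `handleSync`, where a config-only update returns before the checks are touched.

```go
package main

import (
	"fmt"
	"slices"
)

// networkMap and syncResponse are simplified stand-ins for the proto types.
type networkMap struct{ Serial uint64 }

type syncResponse struct {
	NetworkMap *networkMap // nil for config-only updates such as relay token rotation
	Checks     []string
}

type engine struct {
	checks []string
}

// updateChecksIfNew replaces the applied checks only when they changed.
func (e *engine) updateChecksIfNew(checks []string) {
	if slices.Equal(e.checks, checks) {
		return
	}
	e.checks = slices.Clone(checks)
}

// handleSync processes config first, then returns early when there is no
// network map. Checks are updated only after that guard, so a config-only
// update cannot clear previously applied posture checks.
func (e *engine) handleSync(resp *syncResponse) {
	// Config handling (for example relay tokens) would happen here.
	if resp.NetworkMap == nil {
		return
	}
	e.updateChecksIfNew(resp.Checks)
}

func main() {
	e := &engine{}
	e.handleSync(&syncResponse{NetworkMap: &networkMap{Serial: 1}, Checks: []string{"os-version"}})
	e.handleSync(&syncResponse{}) // config-only update carries no map and no checks
	fmt.Println(e.checks)
}
```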
177 changed files with 6357 additions and 22969 deletions


@@ -3,12 +3,14 @@ package cmd
import (
"context"
"fmt"
"os/user"
"strings"
"time"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/types/known/durationpb"
"github.com/netbirdio/netbird/client/internal"
@@ -85,6 +87,73 @@ var persistenceCmd = &cobra.Command{
RunE: setSyncResponsePersistence,
}
var debugConfigCmd = &cobra.Command{
Use: "config",
Example: " netbird debug config",
Short: "Dump the effective configuration",
Long: "Prints the daemon's resolved configuration (after applying defaults, file, env, CLI input, and MDM policy overrides) as JSON. Includes the list of MDM-managed fields.",
RunE: debugConfigDump,
}
// debugConfigDump implements `netbird debug config`. It resolves the
// active profile, queries the daemon for the effective configuration
// via GetConfig, and prints the resulting GetConfigResponse as JSON
// (via protojson with EmitUnpopulated=true so the output is stable
// across runs and includes zero-valued fields).
//
// Useful for verifying MDM enforcement end-to-end: the response's
// mDMManagedFields array is the single source of truth for "which
// fields is the daemon currently enforcing from the MDM source", and
// every config field side-by-side with that list confirms the merge
// result. Secrets in the response (e.g. PreSharedKey) are already
// redacted by the daemon-side handler.
func debugConfigDump(cmd *cobra.Command, _ []string) error {
pm := profilemanager.NewProfileManager()
activeProf, err := pm.GetActiveProfile()
if err != nil {
return fmt.Errorf("get active profile: %v", err)
}
currUser, err := user.Current()
if err != nil {
return fmt.Errorf("get current user: %v", err)
}
conn, err := getClient(cmd)
if err != nil {
return err
}
defer func() {
if err := conn.Close(); err != nil {
log.Errorf(errCloseConnection, err)
}
}()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.GetConfig(cmd.Context(), &proto.GetConfigRequest{
ProfileName: activeProf.Name,
Username: currUser.Username,
})
if err != nil {
return fmt.Errorf("failed to get config: %v", status.Convert(err).Message())
}
// Use protojson so well-known fields render correctly; emit defaults so
// the operator sees every field even when zero/empty.
m := protojson.MarshalOptions{Multiline: true, Indent: " ", EmitUnpopulated: true}
out, err := m.Marshal(resp)
if err != nil {
return fmt.Errorf("marshal config: %w", err)
}
cmd.Println(string(out))
return nil
}
// debugBundle requests the daemon to create a debug bundle and prints
// the resulting local file path and, if uploaded, the uploaded file
// key. It uses the package flags (anonymize, system info, log file
// count, CLI version, optional upload URL) to configure the bundle
// request. Returns an error if the RPC fails or if the daemon reports
// an upload failure reason.
func debugBundle(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd)
if err != nil {

301
client/cmd/kubernetes.go Normal file

@@ -0,0 +1,301 @@
package cmd
import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"slices"
"strings"
"github.com/goccy/go-yaml"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/proto"
)
const (
KubernetesDNSSuffix = "netbird-kubeapi-proxy"
)
var kubernetesCmd = &cobra.Command{
Use: "kubernetes",
Short: "Kubernetes cluster commands.",
Long: "Kubernetes cluster commands.",
}
var kubernetesListCmd = &cobra.Command{
Use: "list",
RunE: kubernetesList,
Short: "List Kubernetes clusters.",
Long: "List Kubernetes clusters by discovering NetBird peers running netbird-kubeapi-proxy.",
}
var kubernetesWriteKubeconfigCmd = &cobra.Command{
Use: "write-kubeconfig",
RunE: kubernetesWriteKubeconfig,
Args: cobra.ExactArgs(1),
Short: "Write kubeconfig for a Kubernetes cluster.",
Long: "Updates kubeconfig in place to allow token-less access to the Kubernetes cluster through NetBird.",
}
func init() {
kubernetesWriteKubeconfigCmd.Flags().String("kubeconfig", "", "path to kubeconfig file")
}
func kubernetesList(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
statusResp, err := client.Status(cmd.Context(), &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil {
return err
}
kcs, err := getKubernetesClusters(cmd.Context(), statusResp.FullStatus.Peers, "")
if err != nil {
return err
}
if len(kcs) == 0 {
cmd.Println("No Kubernetes clusters available.")
return nil
}
cmd.Println("Available Kubernetes clusters:")
for _, k := range kcs {
cmd.Printf("\n - Name: %s\n FQDN: %s\n Version: %s\n", k.name, k.url.Host, k.version)
}
return nil
}
func kubernetesWriteKubeconfig(cmd *cobra.Command, args []string) error {
kubeconfigPath, err := resolveKubeconfigPath(cmd)
if err != nil {
return err
}
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
statusResp, err := client.Status(cmd.Context(), &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil {
return err
}
clusterName := args[0]
kcs, err := getKubernetesClusters(cmd.Context(), statusResp.FullStatus.Peers, clusterName)
if err != nil {
return err
}
if len(kcs) == 0 {
return fmt.Errorf("kubernetes cluster named %s not found", clusterName)
}
if len(kcs) > 1 {
return fmt.Errorf("too many Kubernetes clusters returned")
}
err = writeKubeconfig(kubeconfigPath, kcs[0])
if err != nil {
return err
}
return nil
}
type kubernetesCluster struct {
name string
url *url.URL
version string
}
func getKubernetesClusters(ctx context.Context, peers []*proto.PeerState, nameFilter string) ([]kubernetesCluster, error) {
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.TLSClientConfig = &tls.Config{
InsecureSkipVerify: true,
}
httpClient := &http.Client{
Transport: transport,
}
resolver := net.Resolver{
// Required so both DNS records are returned.
// https://github.com/golang/go/issues/17093
PreferGo: true,
}
kcs := []kubernetesCluster{}
attempted := map[string]struct{}{}
for _, peer := range peers {
fqdns, err := resolver.LookupAddr(ctx, peer.IP)
if err != nil {
return nil, err
}
for _, fqdn := range fqdns {
if _, ok := attempted[fqdn]; ok {
continue
}
attempted[fqdn] = struct{}{}
comps := strings.Split(fqdn, ".")
if len(comps) < 2 {
continue
}
if comps[1] != KubernetesDNSSuffix {
continue
}
if nameFilter != "" && nameFilter != comps[0] {
continue
}
clusterURL, clusterVersion, err := fingerprintClusters(ctx, httpClient, fqdn)
if err != nil {
log.Debugf("could not fingerprint Kubernetes cluster %s %q", fqdn, err)
continue
}
kc := kubernetesCluster{
name: comps[0],
url: clusterURL,
version: clusterVersion,
}
if nameFilter != "" {
return []kubernetesCluster{kc}, nil
}
kcs = append(kcs, kc)
}
}
return kcs, nil
}
func fingerprintClusters(ctx context.Context, httpClient *http.Client, fqdn string) (*url.URL, string, error) {
clusterURL, err := url.Parse("https://" + fqdn)
if err != nil {
return nil, "", err
}
versionURL, err := clusterURL.Parse("/version")
if err != nil {
return nil, "", err
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, versionURL.String(), nil)
if err != nil {
return nil, "", err
}
resp, err := httpClient.Do(req)
if err != nil {
return nil, "", err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, "", fmt.Errorf("expected %d response but got %s", http.StatusOK, resp.Status)
}
b, err := io.ReadAll(resp.Body)
if err != nil {
return nil, "", err
}
versionData := map[string]string{}
err = json.Unmarshal(b, &versionData)
if err != nil {
return nil, "", err
}
version, ok := versionData["gitVersion"]
if !ok {
return nil, "", errors.New("no version found in response")
}
return clusterURL, version, nil
}
func resolveKubeconfigPath(cmd *cobra.Command) (string, error) {
if cmd.Flags().Changed("kubeconfig") {
path, err := cmd.Flags().GetString("kubeconfig")
if err != nil {
return "", err
}
return path, nil
}
if env := os.Getenv("KUBECONFIG"); env != "" {
return env, nil
}
home, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("could not determine home directory: %w", err)
}
return filepath.Join(home, ".kube", "config"), nil
}
func writeKubeconfig(kubeconfigPath string, kc kubernetesCluster) error {
b, err := os.ReadFile(kubeconfigPath)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return err
}
var cfg map[string]any
if err := yaml.Unmarshal(b, &cfg); err != nil {
return err
}
if cfg == nil {
cfg = map[string]any{
"apiVersion": "v1",
"kind": "Config",
}
}
cfg["clusters"] = appendWithName(cfg["clusters"], map[string]any{
"name": kc.name,
"cluster": map[string]any{
"server": kc.url.String(),
"insecure-skip-tls-verify": true,
},
})
cfg["users"] = appendWithName(cfg["users"], map[string]any{
"name": "netbird",
"user": map[string]any{
"token": "none",
},
})
cfg["contexts"] = appendWithName(cfg["contexts"], map[string]any{
"name": kc.name,
"context": map[string]any{
"cluster": kc.name,
"user": "netbird",
"namespace": "default",
},
})
cfg["current-context"] = kc.name
out, err := yaml.Marshal(cfg)
if err != nil {
return err
}
if err := os.WriteFile(kubeconfigPath, out, 0o600); err != nil {
return err
}
return nil
}
func appendWithName(data any, add map[string]any) any {
if data == nil {
return []any{add}
}
v, ok := data.([]any)
if !ok {
return []any{add}
}
i := slices.IndexFunc(v, func(item any) bool {
m, ok := item.(map[string]any)
if !ok {
return false
}
return m["name"] == add["name"]
})
if i == -1 {
return append(v, add)
}
v[i] = add
return v
}


@@ -0,0 +1,120 @@
package cmd
import (
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"testing"
"github.com/spf13/cobra"
"github.com/stretchr/testify/require"
)
func TestFingerprintClusters(t *testing.T) {
t.Parallel()
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
//nolint: errcheck
w.Write([]byte(`{"gitVersion": "foobar"}`))
}))
defer srv.Close()
clusterURL, clusterVersion, err := fingerprintClusters(t.Context(), srv.Client(), srv.Listener.Addr().String())
require.NoError(t, err)
require.Equal(t, srv.URL, clusterURL.String())
require.Equal(t, "foobar", clusterVersion)
}
func TestResolveKubeconfigPath(t *testing.T) {
home, err := os.UserHomeDir()
if err != nil {
t.Fatalf("could not determine home directory: %v", err)
}
defaultPath := filepath.Join(home, ".kube", "config")
path, err := resolveKubeconfigPath(&cobra.Command{})
require.NoError(t, err)
require.Equal(t, defaultPath, path)
flagPath := "flag-path"
cmd := &cobra.Command{}
cmd.Flags().String("kubeconfig", "", "")
err = cmd.Flags().Set("kubeconfig", flagPath)
require.NoError(t, err)
path, err = resolveKubeconfigPath(cmd)
require.NoError(t, err)
require.Equal(t, flagPath, path)
envPath := "env-path"
t.Setenv("KUBECONFIG", envPath)
path, err = resolveKubeconfigPath(&cobra.Command{})
require.NoError(t, err)
require.Equal(t, envPath, path)
}
func TestWriteKubeconfig(t *testing.T) {
t.Parallel()
tests := []struct {
name string
existing string
}{
{
name: "empty file",
},
{
name: "existing content",
existing: `apiVersion: v1
clusters:
- cluster:
insecure-skip-tls-verify: true
server: https://foobar.com
name: foo
current-context: test
kind: Config
users: []
`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
kubeconfigPath := filepath.Join(t.TempDir(), "config")
err := os.WriteFile(kubeconfigPath, []byte(tt.existing), 0o644)
require.NoError(t, err)
kc := kubernetesCluster{
name: "foo",
url: &url.URL{Scheme: "https", Host: "example.com"},
}
err = writeKubeconfig(kubeconfigPath, kc)
require.NoError(t, err)
b, err := os.ReadFile(kubeconfigPath)
require.NoError(t, err)
expected := `apiVersion: v1
clusters:
- cluster:
insecure-skip-tls-verify: true
server: https://example.com
name: foo
contexts:
- context:
cluster: foo
namespace: default
user: netbird
name: foo
current-context: foo
kind: Config
users:
- name: netbird
user:
token: none
`
require.Equal(t, expected, string(b))
})
}
}


@@ -95,7 +95,9 @@ var (
}
)
// Execute executes the root command.
// Execute runs the appropriate Cobra command for the CLI.
// If the process is the update binary it delegates to updateCmd; otherwise it runs the root command.
// It returns any error produced during command execution.
func Execute() error {
if isUpdateBinary() {
return updateCmd.Execute()
@@ -103,6 +105,16 @@ func Execute() error {
return rootCmd.Execute()
}
// init initialises package-level defaults and configures the root
// Cobra command tree. Sets platform-specific config / log directory
// paths (including legacy Wiretrustee fallbacks) and a default daemon
// address; registers persistent CLI flags (daemon address,
// management / admin URLs, logging, setup key (file and inline,
// mutually exclusive), preshared key, hostname, anonymise, config
// path); attaches top-level and nested subcommands to the root
// command; and registers `up`-specific persistent flags (external IP
// maps, custom DNS resolver address, Rosenpass options, auto-connect
// disabling, lazy connection).
func init() {
defaultConfigPathDir = "/etc/netbird/"
defaultLogFileDir = "/var/log/netbird/"
@@ -168,6 +180,12 @@ func init() {
logCmd.AddCommand(logLevelCmd)
debugCmd.AddCommand(forCmd)
debugCmd.AddCommand(persistenceCmd)
debugCmd.AddCommand(debugConfigCmd)
// kubernetes commands
rootCmd.AddCommand(kubernetesCmd)
kubernetesCmd.AddCommand(kubernetesListCmd)
kubernetesCmd.AddCommand(kubernetesWriteKubeconfigCmd)
// profile commands
profileCmd.AddCommand(profileListCmd)


@@ -361,12 +361,6 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
if cmd.Flag(serverSSHAllowedFlag).Changed {
req.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(serverVNCAllowedFlag).Changed {
req.ServerVNCAllowed = &serverVNCAllowed
}
if cmd.Flag(disableVNCApprovalFlag).Changed {
req.DisableVNCApproval = &disableVNCApproval
}
if cmd.Flag(enableSSHRootFlag).Changed {
req.EnableSSHRoot = &enableSSHRoot
}
@@ -473,14 +467,30 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
if cmd.Flag(serverSSHAllowedFlag).Changed {
ic.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(serverVNCAllowedFlag).Changed {
ic.ServerVNCAllowed = &serverVNCAllowed
}
if cmd.Flag(disableVNCApprovalFlag).Changed {
ic.DisableVNCApproval = &disableVNCApproval
if cmd.Flag(enableSSHRootFlag).Changed {
ic.EnableSSHRoot = &enableSSHRoot
}
applySSHFlagsToConfig(cmd, &ic)
if cmd.Flag(enableSSHSFTPFlag).Changed {
ic.EnableSSHSFTP = &enableSSHSFTP
}
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
ic.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
}
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
ic.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
}
if cmd.Flag(disableSSHAuthFlag).Changed {
ic.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
ic.SSHJWTCacheTTL = &sshJWTCacheTTL
}
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil {
@@ -556,49 +566,6 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
return &ic, nil
}
func applySSHFlagsToConfig(cmd *cobra.Command, ic *profilemanager.ConfigInput) {
if cmd.Flag(enableSSHRootFlag).Changed {
ic.EnableSSHRoot = &enableSSHRoot
}
if cmd.Flag(enableSSHSFTPFlag).Changed {
ic.EnableSSHSFTP = &enableSSHSFTP
}
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
ic.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
}
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
ic.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
}
if cmd.Flag(disableSSHAuthFlag).Changed {
ic.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
ic.SSHJWTCacheTTL = &sshJWTCacheTTL
}
}
func applySSHFlagsToLogin(cmd *cobra.Command, req *proto.LoginRequest) {
if cmd.Flag(enableSSHRootFlag).Changed {
req.EnableSSHRoot = &enableSSHRoot
}
if cmd.Flag(enableSSHSFTPFlag).Changed {
req.EnableSSHSFTP = &enableSSHSFTP
}
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
req.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
}
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
req.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
}
if cmd.Flag(disableSSHAuthFlag).Changed {
req.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
ttl := int32(sshJWTCacheTTL)
req.SshJWTCacheTTL = &ttl
}
}
func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte, cmd *cobra.Command) (*proto.LoginRequest, error) {
loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey,
@@ -628,14 +595,31 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
if cmd.Flag(serverSSHAllowedFlag).Changed {
loginRequest.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(serverVNCAllowedFlag).Changed {
loginRequest.ServerVNCAllowed = &serverVNCAllowed
}
if cmd.Flag(disableVNCApprovalFlag).Changed {
loginRequest.DisableVNCApproval = &disableVNCApproval
if cmd.Flag(enableSSHRootFlag).Changed {
loginRequest.EnableSSHRoot = &enableSSHRoot
}
applySSHFlagsToLogin(cmd, &loginRequest)
if cmd.Flag(enableSSHSFTPFlag).Changed {
loginRequest.EnableSSHSFTP = &enableSSHSFTP
}
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
loginRequest.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
}
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
loginRequest.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
}
if cmd.Flag(disableSSHAuthFlag).Changed {
loginRequest.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
loginRequest.SshJWTCacheTTL = &sshJWTCacheTTL32
}
if cmd.Flag(disableAutoConnectFlag).Changed {
loginRequest.DisableAutoConnect = &autoConnectDisabled


@@ -1,100 +0,0 @@
//go:build windows || (darwin && !ios)
package cmd
import (
"fmt"
"net"
"net/netip"
"os"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
)
var (
vncAgentSocket string
vncAgentTargetUID uint32
)
func init() {
vncAgentCmd.Flags().StringVar(&vncAgentSocket, "socket", "", "Unix-domain socket path the agent listens on (required)")
vncAgentCmd.Flags().Uint32Var(&vncAgentTargetUID, "target-uid", 0, "uid the agent should drop privileges to before listening (darwin only; 0 = stay as current uid)")
rootCmd.AddCommand(vncAgentCmd)
}
// vncAgentCmd runs a VNC server inside the user's interactive session,
// listening on a Unix-domain socket. The NetBird service spawns it: on
// Windows via CreateProcessAsUser into the console session, on macOS via
// launchctl asuser into the Aqua session.
var vncAgentCmd = &cobra.Command{
Use: "vnc-agent",
Short: "Run VNC capture agent (internal, spawned by service)",
Hidden: true,
RunE: func(cmd *cobra.Command, args []string) error {
log.SetReportCaller(true)
log.SetFormatter(&log.JSONFormatter{})
log.SetOutput(os.Stderr)
if vncAgentSocket == "" {
return fmt.Errorf("--socket is required")
}
token := os.Getenv("NB_VNC_AGENT_TOKEN")
if token == "" {
return fmt.Errorf("NB_VNC_AGENT_TOKEN not set; agent requires a token from the service")
}
// Purge the token from env so it doesn't leak via /proc/<pid>/environ.
if err := os.Unsetenv("NB_VNC_AGENT_TOKEN"); err != nil {
log.Debugf("unset NB_VNC_AGENT_TOKEN: %v", err)
}
// Drop root privileges to the target console user BEFORE creating
// the listening socket: keeps a post-auth bug in the encoder /
// input / capture paths confined to the user's own privileges
// rather than escalating to host root, and makes the daemon's
// LOCAL_PEERCRED check see the right uid. No-op on Windows
// (both processes run as SYSTEM) and when --target-uid is 0.
if vncAgentTargetUID != 0 {
if err := dropAgentPrivileges(vncAgentTargetUID); err != nil {
return fmt.Errorf("drop privileges to uid %d: %w", vncAgentTargetUID, err)
}
}
if err := os.Remove(vncAgentSocket); err != nil && !os.IsNotExist(err) {
log.Debugf("remove stale socket %s: %v", vncAgentSocket, err)
}
ln, err := net.Listen("unix", vncAgentSocket)
if err != nil {
return fmt.Errorf("listen on %s: %w", vncAgentSocket, err)
}
if err := os.Chmod(vncAgentSocket, 0o600); err != nil {
log.Debugf("chmod %s: %v", vncAgentSocket, err)
}
capturer, injector, err := newAgentResources()
if err != nil {
_ = ln.Close()
return err
}
srv := vncserver.New(vncserver.Config{
Capturer: capturer,
Injector: injector,
DisableAuth: true,
AgentTokenHex: token,
Listener: ln,
})
if err := srv.Start(cmd.Context(), netip.AddrPort{}, netip.Prefix{}); err != nil {
return fmt.Errorf("start vnc server: %w", err)
}
log.Infof("vnc-agent listening on %s, ready", vncAgentSocket)
<-cmd.Context().Done()
log.Info("vnc-agent context cancelled, shutting down")
return srv.Stop()
},
SilenceUsage: true,
}


@@ -1,18 +0,0 @@
//go:build darwin && !ios
package cmd
import (
"fmt"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
)
func newAgentResources() (vncserver.ScreenCapturer, vncserver.InputInjector, error) {
capturer := vncserver.NewMacPoller()
injector, err := vncserver.NewMacInputInjector()
if err != nil {
return nil, nil, fmt.Errorf("macOS input injector: %w", err)
}
return capturer, injector, nil
}


@@ -1,74 +0,0 @@
//go:build darwin && !ios
package cmd
import (
"fmt"
"os"
"os/user"
"strconv"
"syscall"
)
// dropAgentPrivileges drops the vnc-agent process from root (its
// launchctl-asuser-inherited starting uid) to the target console user
// before any other initialisation runs. Without this the agent runs as
// root for the lifetime of the session; any post-auth memory-safety
// issue in the capture/input/encode paths would then be a root-level
// RCE on the host instead of a user-level one. Also makes the daemon's
// LOCAL_PEERCRED check correctly identify the agent as the console user,
// not as root.
//
// Returns an error when the agent is running as a non-root uid that
// differs from targetUID: non-root can only setuid to itself, so a
// mismatch here means the spawn went to the wrong session.
func dropAgentPrivileges(targetUID uint32) error {
if targetUID == 0 {
return fmt.Errorf("refusing to keep agent running as root (target uid 0)")
}
cur := uint32(os.Getuid())
if cur == targetUID {
return nil
}
if cur != 0 {
return fmt.Errorf("agent uid %d does not match expected %d and we lack root to fix it", cur, targetUID)
}
// Resolve the target user's real primary group rather than reusing
// targetUID as the gid: a user's primary group on macOS is typically
// staff(20), not gid==uid. Fail closed if the lookup fails.
targetGID, err := primaryGroupID(targetUID)
if err != nil {
return err
}
// Drop supplementary groups first: setgid alone doesn't touch the
// auxiliary group list, leaving root's groups attached would let the
// dropped process write to root-only group-writable files.
if err := syscall.Setgroups([]int{}); err != nil {
return fmt.Errorf("setgroups([]): %w", err)
}
if err := syscall.Setgid(targetGID); err != nil {
return fmt.Errorf("setgid(%d): %w", targetGID, err)
}
if err := syscall.Setuid(int(targetUID)); err != nil {
return fmt.Errorf("setuid(%d): %w", targetUID, err)
}
if uint32(os.Getuid()) != targetUID || uint32(os.Geteuid()) != targetUID {
return fmt.Errorf("setuid verification: uid=%d euid=%d, expected %d", os.Getuid(), os.Geteuid(), targetUID)
}
return nil
}
// primaryGroupID resolves the real primary group id of the user with the
// given uid. Fails closed: a lookup or parse error returns an error so the
// caller never falls back to using uid as the gid.
func primaryGroupID(targetUID uint32) (int, error) {
u, err := user.LookupId(strconv.Itoa(int(targetUID)))
if err != nil {
return 0, fmt.Errorf("look up uid %d: %w", targetUID, err)
}
gid, err := strconv.Atoi(u.Gid)
if err != nil {
return 0, fmt.Errorf("parse gid %q for uid %d: %w", u.Gid, targetUID, err)
}
return gid, nil
}


@@ -1,55 +0,0 @@
//go:build darwin && !ios
package cmd
import (
"strings"
"testing"
)
// TestDropAgentPrivileges_RefusesRootTarget locks in the contract that
// dropAgentPrivileges must never be a no-op when asked to keep the
// agent as root (target uid 0). A future caller that passes 0 by
// mistake would otherwise leave the post-auth attack surface running
// with full root privileges.
func TestDropAgentPrivileges_RefusesRootTarget(t *testing.T) {
err := dropAgentPrivileges(0)
if err == nil {
t.Fatal("expected refusal for target uid 0, got nil")
}
if !strings.Contains(err.Error(), "root") {
t.Fatalf("error should mention root, got: %v", err)
}
}
// TestDropAgentPrivileges_NoOpWhenAlreadyTarget covers the dev path
// where the agent is launched by hand as the target user (no root
// available, no setuid needed). The helper must succeed silently
// instead of trying (and failing) a setuid to its current uid.
func TestDropAgentPrivileges_NoOpWhenAlreadyTarget(t *testing.T) {
// Skip when running as root: the early-return path we want to
// cover only fires when current uid == target uid.
uid := currentUIDForTest()
if uid == 0 {
t.Skip("test must not run as root; cannot exercise the no-op early-return")
}
if err := dropAgentPrivileges(uid); err != nil {
t.Fatalf("expected no-op when current uid == target, got: %v", err)
}
}
// TestDropAgentPrivileges_RefusesMismatchedNonRoot guards the "non-root
// caller tries to setuid to a different uid" path: setuid would fail
// with EPERM anyway, but the helper should surface a clear error
// before issuing the syscall so a misconfigured spawn (wrong --target-uid
// flag) is debuggable.
func TestDropAgentPrivileges_RefusesMismatchedNonRoot(t *testing.T) {
uid := currentUIDForTest()
if uid == 0 {
t.Skip("test must not run as root; covered case requires non-root caller")
}
err := dropAgentPrivileges(uid + 1)
if err == nil {
t.Fatal("expected refusal when non-root caller asks to setuid elsewhere")
}
}


@@ -1,11 +0,0 @@
//go:build darwin && !ios
package cmd
import "os"
// currentUIDForTest exposes os.Getuid for the darwin dropprivs tests
// without leaking an os import into the test file itself.
func currentUIDForTest() uint32 {
return uint32(os.Getuid())
}


@@ -1,14 +0,0 @@
//go:build windows
package cmd
// dropAgentPrivileges is a no-op on Windows: the agent and the daemon
// both run as SYSTEM (the daemon spawns the agent into the interactive
// session via CreateProcessAsUser with an impersonation token, but the
// resulting process still runs under SYSTEM, not under the user's
// account). The Windows path relies on the DACL-restricted socket
// directory, the unpredictable per-spawn socket name, the listen-readiness
// gate, and the per-spawn token for integrity instead.
func dropAgentPrivileges(_ uint32) error {
return nil
}


@@ -1,15 +0,0 @@
//go:build windows
package cmd
import (
log "github.com/sirupsen/logrus"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
)
func newAgentResources() (vncserver.ScreenCapturer, vncserver.InputInjector, error) {
sessionID := vncserver.GetCurrentSessionID()
log.Infof("VNC agent running in Windows session %d", sessionID)
return vncserver.NewDesktopCapturer(), vncserver.NewWindowsInputInjector(), nil
}


@@ -1,16 +0,0 @@
package cmd
const (
serverVNCAllowedFlag = "allow-server-vnc"
disableVNCApprovalFlag = "disable-vnc-approval"
)
var (
serverVNCAllowed bool
disableVNCApproval bool
)
func init() {
upCmd.PersistentFlags().BoolVar(&serverVNCAllowed, serverVNCAllowedFlag, false, "Allow embedded VNC server on peer")
upCmd.PersistentFlags().BoolVar(&disableVNCApproval, disableVNCApprovalFlag, false, "Disable per-connection user approval prompts for the embedded VNC server")
}


@@ -6,30 +6,19 @@ import (
"runtime"
)
var (
// StateDir holds persistent state (config, profiles, install metadata).
StateDir string
// RuntimeDir holds ephemeral artifacts that should not survive reboot,
// such as Unix sockets for daemon and per-session IPC. Empty on
// platforms without a conventional /var/run-style location.
RuntimeDir string
)
var StateDir string
func init() {
StateDir = os.Getenv("NB_STATE_DIR")
if StateDir != "" {
return
}
switch runtime.GOOS {
case "windows":
StateDir = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird")
case "darwin", "linux":
StateDir = "/var/lib/netbird"
RuntimeDir = "/var/run/netbird"
case "freebsd", "openbsd", "netbsd", "dragonfly":
StateDir = "/var/db/netbird"
RuntimeDir = "/var/run/netbird"
}
if v := os.Getenv("NB_STATE_DIR"); v != "" {
StateDir = v
}
if v := os.Getenv("NB_RUNTIME_DIR"); v != "" {
RuntimeDir = v
}
}
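
This hunk interleaves two variants of the directory setup. Assuming the shorter one (a single StateDir with an early return when NB_STATE_DIR is set) is the resulting code, its resolution order can be sketched as a pure function. `resolveStateDir` and `mapEnv` are hypothetical helpers introduced for illustration; the real code reads `runtime.GOOS` and `os.Getenv` directly inside `init`.

```go
package main

import (
	"fmt"
	"path/filepath"
)

// resolveStateDir is a hypothetical pure-function version of the init logic:
// NB_STATE_DIR wins when set, otherwise a per-OS default applies.
func resolveStateDir(goos string, getenv func(string) string) string {
	if v := getenv("NB_STATE_DIR"); v != "" {
		return v
	}
	switch goos {
	case "windows":
		return filepath.Join(getenv("PROGRAMDATA"), "Netbird")
	case "darwin", "linux":
		return "/var/lib/netbird"
	case "freebsd", "openbsd", "netbsd", "dragonfly":
		return "/var/db/netbird"
	}
	return "" // unknown platform: left empty, as in the original switch
}

// mapEnv adapts a map to the getenv signature so no real environment is touched.
func mapEnv(m map[string]string) func(string) string {
	return func(k string) string { return m[k] }
}

func main() {
	fmt.Println(resolveStateDir("linux", mapEnv(nil)))
	fmt.Println(resolveStateDir("freebsd", mapEnv(map[string]string{"NB_STATE_DIR": "/tmp/nb"})))
}
```

Passing the OS name and the environment lookup as parameters is a design choice of this sketch only; it lets every platform branch be exercised from a single host.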


@@ -279,6 +279,10 @@ func (c *Client) Start(startCtx context.Context) error {
select {
case <-startCtx.Done():
// Cancel the client context before stopping: Engine.Start blocks on the
// signal stream while holding the engine mutex and only unblocks on
// cancellation. Stopping first would deadlock on that mutex.
cancel()
if stopErr := client.Stop(); stopErr != nil {
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
}
@@ -442,8 +446,8 @@ func (c *Client) Expose(ctx context.Context, req ExposeRequest) (*ExposeSession,
// IdentityForIP looks up a remote peer by its tunnel IP using the
// embedded client's status recorder. Returns the peer's WireGuard public
// key and FQDN. ok=false means the IP isn't in this client's peer
// roster — callers should treat that as "unknown peer".
// key and FQDN. ok=false means the IP doesn't belong to an active peer
// — offline roster peers are treated as unknown, same as foreign IPs.
func (c *Client) IdentityForIP(ip netip.Addr) (pubKey, fqdn string, ok bool) {
if !ip.IsValid() || c.recorder == nil {
return "", "", false

client/embed/embed_test.go

@@ -0,0 +1,168 @@
package embed
import (
"context"
"net"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/server/config"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
mgmt "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"
)
const testSetupKey = "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
// TestClientStartTimeoutRollback reproduces a deadlock between Engine.Start and
// Engine.Stop. The signal endpoint accepts gRPC connections but never serves the
// SignalExchange service, so Engine.Start parks in WaitStreamConnected while
// holding the engine mutex. When the Start context expires, the rollback path
// calls ConnectClient.Stop, which must not block forever acquiring that mutex.
func TestClientStartTimeoutRollback(t *testing.T) {
signalAddr := startBlackholeSignal(t)
mgmAddr := startManagement(t, signalAddr)
wgPort := 0
client, err := New(Options{
DeviceName: "embed-rollback-test",
SetupKey: testSetupKey,
ManagementURL: "http://" + mgmAddr,
WireguardPort: &wgPort,
})
require.NoError(t, err, "embed client creation must succeed")
startCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
startErr := make(chan error, 1)
go func() {
startErr <- client.Start(startCtx)
}()
select {
case err := <-startErr:
require.ErrorIs(t, err, context.DeadlineExceeded)
case <-time.After(60 * time.Second):
t.Fatal("client.Start did not return after its context expired: Engine.Stop deadlocked against Engine.Start waiting for the signal stream")
}
}
// startBlackholeSignal starts a gRPC server without the SignalExchange service
// registered. Connections succeed, but the signal stream can never be
// established, which keeps Engine.Start parked in WaitStreamConnected.
func startBlackholeSignal(t *testing.T) string {
t.Helper()
lis, err := net.Listen("tcp", "localhost:0")
require.NoError(t, err)
s := grpc.NewServer()
go func() {
if err := s.Serve(lis); err != nil {
t.Error(err)
}
}()
t.Cleanup(s.Stop)
return lis.Addr().String()
}
func startManagement(t *testing.T, signalAddr string) string {
t.Helper()
cfg := &config.Config{
Stuns: []*config.Host{},
TURNConfig: &config.TURNConfig{},
Relay: &config.Relay{
Addresses: []string{"127.0.0.1:1234"},
CredentialsTTL: util.Duration{Duration: time.Hour},
Secret: "222222222222222222",
},
Signal: &config.Host{
Proto: "http",
URI: signalAddr,
},
Datadir: t.TempDir(),
HttpConfig: nil,
}
lis, err := net.Listen("tcp", "localhost:0")
require.NoError(t, err)
s := grpc.NewServer()
testStore, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", cfg.Datadir)
require.NoError(t, err)
t.Cleanup(cleanUp)
eventStore := &activity.InMemoryEventStore{}
permissionsManager := permissions.NewManager(testStore)
peersManager := peers.NewManager(testStore, permissionsManager)
jobManager := job.NewJobManager(nil, testStore, peersManager)
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
require.NoError(t, err)
iv, err := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
require.NoError(t, err)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
settingsMockManager.EXPECT().
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
Return(&types.Settings{}, nil).
AnyTimes()
settingsMockManager.EXPECT().
GetExtraSettings(gomock.Any(), gomock.Any()).
Return(&types.ExtraSettings{}, nil).
AnyTimes()
groupsManager := groups.NewManagerMock()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := mgmt.NewAccountRequestBuffer(context.Background(), testStore)
networkMapController := controller.NewController(context.Background(), testStore, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(testStore, peersManager), cfg)
accountManager, err := mgmt.BuildManager(context.Background(), cfg, testStore, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
require.NoError(t, err)
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, cfg.TURNConfig, cfg.Relay, settingsMockManager, groupsManager)
require.NoError(t, err)
mgmtServer, err := nbgrpc.NewServer(cfg, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil, nil)
require.NoError(t, err)
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
go func() {
if err := s.Serve(lis); err != nil {
t.Error(err)
}
}()
t.Cleanup(s.Stop)
return lis.Addr().String()
}


@@ -1,10 +1,13 @@
package device
import (
"fmt"
"net/netip"
"runtime/debug"
"sync"
"sync/atomic"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun"
)
@@ -41,10 +44,13 @@ type PacketCapture interface {
type FilteredDevice struct {
tun.Device
filter PacketFilter
capture atomic.Pointer[PacketCapture]
mutex sync.RWMutex
closeOnce sync.Once
filter PacketFilter
capture atomic.Pointer[PacketCapture]
// panicHandler is invoked after a panic in the underlying device is
// recovered in Read or Write.
panicHandler atomic.Pointer[func()]
mutex sync.RWMutex
closeOnce sync.Once
}
// newDeviceFilter constructor function
@@ -70,7 +76,7 @@ func (d *FilteredDevice) Close() error {
// Read wraps read method with filtering feature
func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
if n, err = d.Device.Read(bufs, sizes, offset); err != nil {
if n, err = d.deviceRead(bufs, sizes, offset); err != nil {
return 0, err
}
@@ -112,7 +118,7 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
d.mutex.RUnlock()
if filter == nil {
return d.Device.Write(bufs, offset)
return d.deviceWrite(bufs, offset)
}
filteredBufs := make([][]byte, 0, len(bufs))
@@ -125,9 +131,44 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
}
}
n, err := d.Device.Write(filteredBufs, offset)
n += dropped
return n, err
n, err := d.deviceWrite(filteredBufs, offset)
if err != nil {
return n, err
}
return n + dropped, nil
}
// deviceRead calls the underlying device Read, recovering from panics in the
// wintun read path and converting them into errors.
func (d *FilteredDevice) deviceRead(bufs [][]byte, sizes []int, offset int) (n int, err error) {
defer d.recoverFromPanic("read", &n, &err)
return d.Device.Read(bufs, sizes, offset)
}
// deviceWrite calls the underlying device Write, recovering from panics in the
// wintun write path and converting them into errors.
func (d *FilteredDevice) deviceWrite(bufs [][]byte, offset int) (n int, err error) {
defer d.recoverFromPanic("write", &n, &err)
return d.Device.Write(bufs, offset)
}
// recoverFromPanic converts a panic in the underlying device into a regular
// error and invokes the registered panic handler. The wintun read path is
// known to panic on zero-length packets that third-party filter drivers can
// place in the ring.
func (d *FilteredDevice) recoverFromPanic(op string, n *int, err *error) {
r := recover()
if r == nil {
return
}
log.Errorf("recovered panic in tun device %s: %v\n%s", op, r, debug.Stack())
*n = 0
*err = fmt.Errorf("tun device %s panic: %v", op, r)
if handler := d.panicHandler.Load(); handler != nil {
(*handler)()
}
}
// SetFilter sets packet filter to device
@@ -137,6 +178,17 @@ func (d *FilteredDevice) SetFilter(filter PacketFilter) {
d.mutex.Unlock()
}
// SetPanicHandler registers a handler invoked after a recovered panic in Read
// or Write. The device is unusable after such a panic; the handler should
// trigger recreation of the interface. Pass nil to remove.
func (d *FilteredDevice) SetPanicHandler(handler func()) {
if handler == nil {
d.panicHandler.Store(nil)
return
}
d.panicHandler.Store(&handler)
}
// SetCapture sets or clears the packet capture sink. Pass nil to disable.
// Uses atomic store so the hot path (Read/Write) is a single pointer load
// with no locking overhead when capture is off.


@@ -221,3 +221,60 @@ func TestDeviceWrapperRead(t *testing.T) {
}
})
}
func TestDeviceWrapperReadPanic(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
tun := mocks.NewMockDevice(ctrl)
tun.EXPECT().Read(gomock.Any(), gomock.Any(), gomock.Any()).
DoAndReturn(func(bufs [][]byte, sizes []int, offset int) (int, error) {
// Reproduce the wintun zero-length packet panic (index out of range).
packet := make([]byte, 0)
return int(packet[0]), nil
})
wrapped := newDeviceFilter(tun)
handlerCalled := false
wrapped.SetPanicHandler(func() { handlerCalled = true })
n, err := wrapped.Read([][]byte{{}}, []int{0}, 0)
if err == nil {
t.Errorf("expected error from recovered panic, got nil")
}
if n != 0 {
t.Errorf("expected n=0, got %d", n)
}
if !handlerCalled {
t.Errorf("expected panic handler to be called")
}
}
func TestDeviceWrapperWritePanic(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
tun := mocks.NewMockDevice(ctrl)
tun.EXPECT().Write(gomock.Any(), gomock.Any()).
DoAndReturn(func(bufs [][]byte, offset int) (int, error) {
packet := make([]byte, 0)
return int(packet[0]), nil
})
wrapped := newDeviceFilter(tun)
handlerCalled := false
wrapped.SetPanicHandler(func() { handlerCalled = true })
n, err := wrapped.Write([][]byte{{0x45, 0x00}}, 0)
if err == nil {
t.Errorf("expected error from recovered panic, got nil")
}
if n != 0 {
t.Errorf("expected n=0, got %d", n)
}
if !handlerCalled {
t.Errorf("expected panic handler to be called")
}
}


@@ -1,219 +0,0 @@
// Package approval brokers per-attempt user-accept prompts for inbound
// remote access (VNC today, SSH and others in the future). A caller pushes
// a Prompt; the broker emits a SystemEvent on the daemon→UI stream and
// blocks until the UI calls the daemon's RespondApproval RPC or the
// per-request timeout fires. When no subscriber is connected the broker
// fails closed immediately, so a backgrounded UI cannot silently bypass
// the gate.
package approval
import (
"context"
"errors"
"fmt"
"sync"
"time"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/proto"
)
// Metadata keys the broker reserves on the emitted SystemEvent. Callers
// should not set these themselves; values in Prompt.Metadata that collide
// are overwritten by the broker.
const (
MetaRequestID = "request_id"
MetaKind = "kind"
MetaExpiresAt = "expires_at"
)
// ShortKeyFingerprint formats a hex-encoded Noise_IK static pubkey as a
// short, eyeball-able fingerprint to display in the approval dialog.
// The dashboard-supplied display name attached to a SessionPubKey isn't
// cryptographically asserted by the connecting client, so the prompt
// must also show something that IS: the key fingerprint, taken from the
// leading hex characters of the static public key the client just proved
// possession of during the Noise handshake. Returns the empty string when
// the input is shorter than 8 characters, and therefore too short to
// plausibly be a hex pubkey, so the row is omitted rather than rendered
// as a misleading partial.
//
// Output format: the first 16 hex chars grouped as XXXX-XXXX-XXXX-XXXX (64
// bits of fingerprint, resistant to random-prefix collisions and easy for a
// human to compare with an out-of-band reference). Inputs of 8 to 15
// characters yield a correspondingly shorter string.
func ShortKeyFingerprint(hexKey string) string {
if len(hexKey) < 8 {
return ""
}
src := hexKey
if len(src) > 16 {
src = src[:16]
}
var out []byte
for i, c := range src {
if i > 0 && i%4 == 0 {
out = append(out, '-')
}
out = append(out, byte(c))
}
return string(out)
}
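
The grouping rule is easiest to see on concrete inputs. This is a minimal sketch that re-implements the same prefix-and-group behaviour under a hypothetical name, `shortFingerprint`, using `strings.Builder` instead of a byte slice. It assumes ASCII hex input, as the original does.

```go
package main

import (
	"fmt"
	"strings"
)

// shortFingerprint keeps at most the first 16 characters of hexKey and
// inserts a dash before every fourth character after the first group.
func shortFingerprint(hexKey string) string {
	if len(hexKey) < 8 {
		return "" // too short to plausibly be a hex pubkey
	}
	if len(hexKey) > 16 {
		hexKey = hexKey[:16]
	}
	var sb strings.Builder
	for i := 0; i < len(hexKey); i++ {
		if i > 0 && i%4 == 0 {
			sb.WriteByte('-')
		}
		sb.WriteByte(hexKey[i])
	}
	return sb.String()
}

func main() {
	fmt.Println(shortFingerprint("0123456789abcdef0123")) // 0123-4567-89ab-cdef
	fmt.Println(shortFingerprint("0123456789"))           // 0123-4567-89
	fmt.Printf("%q\n", shortFingerprint("abc"))           // ""
}
```

The second call shows that an input between 8 and 15 characters produces a shorter, partially filled last group.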
// Kind values for the well-known prompt subjects. New subsystems should
// add a constant here so the UI can dispatch on a known string.
const (
KindVNC = "vnc"
KindSSH = "ssh"
)
// DefaultTimeout is the wall-clock window the user has to accept or deny a
// pending approval before the broker fails closed and returns ErrTimeout.
// Kept well under typical VNC client and dashboard connection timeouts so
// the RFB rejection actually reaches the browser instead of racing the
// browser's own "connection timed out" message.
const DefaultTimeout = 15 * time.Second
// timeoutValue returns the active timeout. It's a var so tests in this
// package can shorten the wait without exposing a setter on the public
// API. Production code always sees DefaultTimeout.
var timeoutValue = func() time.Duration { return DefaultTimeout }
// ErrNoSubscriber indicates no UI is connected to consume the prompt.
// The caller must reject the underlying connection (fail-closed).
var ErrNoSubscriber = errors.New("no UI subscriber connected for approval")
// ErrTimeout indicates the user did not respond within DefaultTimeout.
var ErrTimeout = errors.New("approval timed out")
// ErrDenied indicates the user explicitly denied the connection.
var ErrDenied = errors.New("approval denied")
// EventPublisher is the subset of peer.Status used to emit prompts.
type EventPublisher interface {
PublishEvent(
severity proto.SystemEvent_Severity,
category proto.SystemEvent_Category,
msg string,
userMsg string,
metadata map[string]string,
)
HasEventSubscribers() bool
}
// Prompt describes the pending request shown to the user. Kind selects
// the UI dispatch path (e.g. "vnc", "ssh"). Subject is the human-readable
// one-liner the UI may show as a title or notification body. Metadata is
// passed through verbatim and is the subsystem-specific payload (peer
// name, source IP, mode, etc.).
type Prompt struct {
Kind string
Subject string
Metadata map[string]string
}
// Decision carries the user's response to an approval prompt. ViewOnly is
// only meaningful when Accept is true; it lets the host grant the
// connection but signal the requester that input control is withheld.
type Decision struct {
Accept bool
ViewOnly bool
}
// Broker holds in-flight approval requests keyed by request ID.
type Broker struct {
pub EventPublisher
mu sync.Mutex
pending map[string]chan Decision
}
// New returns a broker that publishes prompts via pub.
func New(pub EventPublisher) *Broker {
return &Broker{
pub: pub,
pending: make(map[string]chan Decision),
}
}
// Request emits a SystemEvent for p and blocks until the UI calls Respond,
// ctx is cancelled, or the active timeout (DefaultTimeout in production)
// elapses. Returns the Decision when the user accepted, and an error
// (ErrDenied, ErrTimeout, ErrNoSubscriber, ctx.Err(), or a configuration
// error) otherwise. Callers must treat any non-nil error as a deny.
func (b *Broker) Request(ctx context.Context, p Prompt) (Decision, error) {
var zero Decision
if b == nil || b.pub == nil {
return zero, fmt.Errorf("approval broker not configured")
}
if !b.pub.HasEventSubscribers() {
return zero, ErrNoSubscriber
}
id := uuid.NewString()
resp := make(chan Decision, 1)
b.mu.Lock()
b.pending[id] = resp
b.mu.Unlock()
defer b.dropPending(id)
timeout := timeoutValue()
expiresAt := time.Now().Add(timeout)
meta := make(map[string]string, len(p.Metadata)+3)
for k, v := range p.Metadata {
meta[k] = v
}
meta[MetaRequestID] = id
meta[MetaKind] = p.Kind
meta[MetaExpiresAt] = expiresAt.UTC().Format(time.RFC3339)
subject := p.Subject
if subject == "" {
subject = fmt.Sprintf("%s connection requires approval", p.Kind)
}
b.pub.PublishEvent(proto.SystemEvent_INFO, proto.SystemEvent_APPROVAL, subject, subject, meta)
log.Debugf("approval request %s (%s) emitted: %s", id, p.Kind, subject)
timer := time.NewTimer(timeout)
defer timer.Stop()
select {
case d := <-resp:
if !d.Accept {
return zero, ErrDenied
}
return d, nil
case <-timer.C:
return zero, ErrTimeout
case <-ctx.Done():
return zero, ctx.Err()
}
}
// Respond delivers the user's decision for id. Returns true when a pending
// request matched and was woken, false when id was unknown or already done.
func (b *Broker) Respond(id string, d Decision) bool {
if b == nil {
return false
}
b.mu.Lock()
ch, ok := b.pending[id]
if ok {
delete(b.pending, id)
}
b.mu.Unlock()
if !ok {
return false
}
select {
case ch <- d:
default:
}
return true
}
func (b *Broker) dropPending(id string) {
b.mu.Lock()
delete(b.pending, id)
b.mu.Unlock()
}


@@ -1,434 +0,0 @@
package approval
import (
"context"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/proto"
)
// fakePublisher records published events and reports whether subscribers
// are connected. The subscribers flag is the security-critical signal:
// when false the broker must refuse to emit and the gate must fail closed.
type fakePublisher struct {
mu sync.Mutex
subscribers bool
events []*proto.SystemEvent
}
func (p *fakePublisher) PublishEvent(
severity proto.SystemEvent_Severity,
category proto.SystemEvent_Category,
msg string,
userMsg string,
metadata map[string]string,
) {
p.mu.Lock()
p.events = append(p.events, &proto.SystemEvent{
Severity: severity,
Category: category,
Message: msg,
UserMessage: userMsg,
Metadata: metadata,
})
p.mu.Unlock()
}
func (p *fakePublisher) HasEventSubscribers() bool {
p.mu.Lock()
defer p.mu.Unlock()
return p.subscribers
}
func (p *fakePublisher) lastEvent(t *testing.T) *proto.SystemEvent {
t.Helper()
p.mu.Lock()
defer p.mu.Unlock()
require.NotEmpty(t, p.events, "publisher saw no events")
return p.events[len(p.events)-1]
}
func (p *fakePublisher) eventCount() int {
p.mu.Lock()
defer p.mu.Unlock()
return len(p.events)
}
// TestRequestNoSubscriberFailsClosed is the core fail-closed invariant:
// when the UI is not subscribed, the broker must refuse without emitting
// an event or arming a waiter. A regression here is a silent bypass.
func TestRequestNoSubscriberFailsClosed(t *testing.T) {
pub := &fakePublisher{subscribers: false}
b := New(pub)
_, err := b.Request(context.Background(), Prompt{Kind: KindVNC, Subject: "test"})
assert.ErrorIs(t, err, ErrNoSubscriber)
assert.Equal(t, 0, pub.eventCount(), "no event must be emitted when fail-closed")
b.mu.Lock()
pending := len(b.pending)
b.mu.Unlock()
assert.Equal(t, 0, pending, "no waiter must be registered on fail-closed")
}
// TestRequestTimeoutDenies verifies that a request without a UI response
// returns ErrTimeout (deny) rather than nil (silent accept). Shortens the
// broker timeout through the defaultTimeout test helper to keep the test fast.
func TestRequestTimeoutDenies(t *testing.T) {
// Shorten the active timeout for the lifetime of this test.
orig := DefaultTimeout
defaultTimeout(t, 60*time.Millisecond)
defer defaultTimeout(t, orig)
pub := &fakePublisher{subscribers: true}
b := New(pub)
start := time.Now()
_, err := b.Request(context.Background(), Prompt{Kind: KindVNC, Subject: "test"})
assert.ErrorIs(t, err, ErrTimeout, "missing user response must yield ErrTimeout, not nil")
assert.GreaterOrEqual(t, time.Since(start), 50*time.Millisecond, "timeout fired prematurely")
}
// TestRequestDenied verifies that Request returns ErrDenied when the UI
// responds with Accept: false.
func TestRequestDenied(t *testing.T) {
pub := &fakePublisher{subscribers: true}
b := New(pub)
var requestID string
done := make(chan error, 1)
go func() {
done <- requestErr(b, context.Background(), Prompt{Kind: KindVNC, Subject: "test"})
}()
requestID = waitForRequestID(t, pub)
require.True(t, b.Respond(requestID, Decision{Accept: false}))
select {
case err := <-done:
assert.ErrorIs(t, err, ErrDenied)
case <-time.After(time.Second):
t.Fatal("Request did not return after Respond(false)")
}
}
// TestRequestAccepted is the happy path. Failure here doesn't bypass the
// gate but breaks the feature.
func TestRequestAccepted(t *testing.T) {
pub := &fakePublisher{subscribers: true}
b := New(pub)
done := make(chan error, 1)
go func() {
done <- requestErr(b, context.Background(), Prompt{Kind: KindVNC, Subject: "test"})
}()
id := waitForRequestID(t, pub)
require.True(t, b.Respond(id, Decision{Accept: true}))
select {
case err := <-done:
assert.NoError(t, err)
case <-time.After(time.Second):
t.Fatal("Request did not return after Respond(true)")
}
}
// TestRequestCtxCancelDenies verifies that an upstream cancel (e.g. the
// engine shutting down mid-prompt) returns the cancel error rather than
// nil. A nil here would be a silent bypass on shutdown races.
func TestRequestCtxCancelDenies(t *testing.T) {
pub := &fakePublisher{subscribers: true}
b := New(pub)
ctx, cancel := context.WithCancel(context.Background())
done := make(chan error, 1)
go func() {
done <- requestErr(b, ctx, Prompt{Kind: KindVNC, Subject: "test"})
}()
// Wait until the prompt is in flight so cancel races a live waiter.
_ = waitForRequestID(t, pub)
cancel()
select {
case err := <-done:
assert.ErrorIs(t, err, context.Canceled)
case <-time.After(time.Second):
t.Fatal("Request did not return after ctx cancel")
}
}
// TestRespondUnknownIsNoop ensures a stray RespondApproval RPC cannot
// affect or accidentally accept any in-flight request whose id it doesn't
// match. Also confirms it doesn't panic.
func TestRespondUnknownIsNoop(t *testing.T) {
pub := &fakePublisher{subscribers: true}
b := New(pub)
// No in-flight prompts: Respond returns false.
assert.False(t, b.Respond("does-not-exist", Decision{Accept: true}))
// With an in-flight prompt, a wrong id still returns false and the
// prompt remains armed (eventually timing out as a deny).
defaultTimeout(t, 60*time.Millisecond)
defer defaultTimeout(t, DefaultTimeout)
done := make(chan error, 1)
go func() {
done <- requestErr(b, context.Background(), Prompt{Kind: KindVNC})
}()
realID := waitForRequestID(t, pub)
assert.False(t, b.Respond("totally-bogus", Decision{Accept: true}), "unknown id must not match")
assert.NotEqual(t, "totally-bogus", realID)
select {
case err := <-done:
assert.ErrorIs(t, err, ErrTimeout, "armed prompt must still time out, not accept")
case <-time.After(time.Second):
t.Fatal("prompt did not resolve")
}
}
// TestRespondAfterTimeoutNoop confirms a late accept response can't
// retroactively flip a denied (timed-out) request. The dropPending defer
// in Request must have removed the entry by the time Respond races in.
func TestRespondAfterTimeoutNoop(t *testing.T) {
defaultTimeout(t, 30*time.Millisecond)
defer defaultTimeout(t, DefaultTimeout)
pub := &fakePublisher{subscribers: true}
b := New(pub)
done := make(chan error, 1)
go func() {
done <- requestErr(b, context.Background(), Prompt{Kind: KindVNC})
}()
id := waitForRequestID(t, pub)
select {
case err := <-done:
require.ErrorIs(t, err, ErrTimeout)
case <-time.After(time.Second):
t.Fatal("prompt did not time out")
}
assert.False(t, b.Respond(id, Decision{Accept: true}), "late respond must be no-op")
}
// TestRespondDoubleNoop ensures a duplicate ack from the UI doesn't leak
// past the matched waiter or panic on a closed/full channel.
func TestRespondDoubleNoop(t *testing.T) {
pub := &fakePublisher{subscribers: true}
b := New(pub)
done := make(chan error, 1)
go func() {
done <- requestErr(b, context.Background(), Prompt{Kind: KindVNC})
}()
id := waitForRequestID(t, pub)
require.True(t, b.Respond(id, Decision{Accept: true}))
assert.False(t, b.Respond(id, Decision{Accept: false}), "second response must be no-op")
select {
case err := <-done:
assert.NoError(t, err)
case <-time.After(time.Second):
t.Fatal("prompt did not resolve")
}
}
// TestNilBrokerRequestErrors guards the engine pre-init path where the
// broker may not yet exist (or its publisher is nil): Request must
// error, never silently accept.
func TestNilBrokerRequestErrors(t *testing.T) {
var b *Broker
_, err := b.Request(context.Background(), Prompt{Kind: KindVNC})
assert.Error(t, err, "nil broker must error, never silently accept")
b2 := New(nil)
_, err = b2.Request(context.Background(), Prompt{Kind: KindVNC})
assert.Error(t, err, "broker with nil publisher must error, never silently accept")
}
// TestPromptMetadataInjected confirms the broker stamps request_id, kind,
// and expires_at on the emitted event. The UI relies on these keys; if
// they are dropped, the user cannot route the prompt and the response
// path breaks (which fails closed via timeout).
func TestPromptMetadataInjected(t *testing.T) {
pub := &fakePublisher{subscribers: true}
b := New(pub)
done := make(chan error, 1)
go func() {
done <- requestErr(b, context.Background(), Prompt{
Kind: KindVNC,
Subject: "VNC connection from peerA",
Metadata: map[string]string{"peer_name": "peerA"},
})
}()
id := waitForRequestID(t, pub)
ev := pub.lastEvent(t)
assert.Equal(t, proto.SystemEvent_APPROVAL, ev.Category)
assert.Equal(t, KindVNC, ev.Metadata[MetaKind])
assert.Equal(t, id, ev.Metadata[MetaRequestID])
assert.NotEmpty(t, ev.Metadata[MetaExpiresAt])
assert.Equal(t, "peerA", ev.Metadata["peer_name"], "caller metadata must pass through")
require.True(t, b.Respond(id, Decision{Accept: true}))
<-done
}
// TestConcurrentRequests verifies that two concurrent prompts are tracked
// independently. A bug that aliases ids would let one Respond unblock
// the wrong waiter (a silent accept across prompts).
func TestConcurrentRequests(t *testing.T) {
pub := &fakePublisher{subscribers: true}
b := New(pub)
const n = 20
results := make(chan error, n)
for i := 0; i < n; i++ {
go func() {
results <- requestErr(b, context.Background(), Prompt{Kind: KindVNC})
}()
}
ids := waitForNRequestIDs(t, pub, n)
require.Len(t, ids, n)
// Deny exactly half, accept the rest. Track outcome per id so we can
// match each Request's return value against the response we sent.
denySet := make(map[string]bool, n)
for i, id := range ids {
deny := i%2 == 0
denySet[id] = deny
require.True(t, b.Respond(id, Decision{Accept: !deny}))
}
// Collect all returns and check no nil errors slipped past a deny.
var accepted, denied atomic.Int32
for i := 0; i < n; i++ {
select {
case err := <-results:
if err == nil {
accepted.Add(1)
} else {
assert.ErrorIs(t, err, ErrDenied)
denied.Add(1)
}
case <-time.After(2 * time.Second):
t.Fatalf("only got %d/%d responses", i, n)
}
}
assert.Equal(t, int32(n/2), denied.Load())
assert.Equal(t, int32(n/2), accepted.Load())
}
// waitForRequestID blocks until the publisher sees its next event and
// returns the request_id stamped on it.
func waitForRequestID(t *testing.T, pub *fakePublisher) string {
t.Helper()
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
pub.mu.Lock()
count := len(pub.events)
var id string
if count > 0 {
id = pub.events[count-1].Metadata[MetaRequestID]
}
pub.mu.Unlock()
if id != "" {
return id
}
time.Sleep(2 * time.Millisecond)
}
t.Fatal("timeout waiting for emitted event")
return ""
}
func waitForNRequestIDs(t *testing.T, pub *fakePublisher, n int) []string {
t.Helper()
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
pub.mu.Lock()
count := len(pub.events)
pub.mu.Unlock()
if count >= n {
break
}
time.Sleep(2 * time.Millisecond)
}
pub.mu.Lock()
defer pub.mu.Unlock()
out := make([]string, 0, len(pub.events))
seen := make(map[string]struct{}, len(pub.events))
for _, ev := range pub.events {
id := ev.Metadata[MetaRequestID]
if id == "" {
continue
}
if _, dup := seen[id]; dup {
continue
}
seen[id] = struct{}{}
out = append(out, id)
}
if len(out) < n {
t.Fatalf("only got %d/%d request ids", len(out), n)
}
return out
}
// defaultTimeout swaps the broker's per-request wall-clock window so the
// timeout tests run quickly. It does not restore the previous value itself;
// callers reset it with a deferred defaultTimeout(t, DefaultTimeout).
func defaultTimeout(t *testing.T, d time.Duration) {
t.Helper()
if d <= 0 {
t.Fatal("defaultTimeout must be > 0")
}
timeoutValue = func() time.Duration { return d }
}
// requestErr wraps Broker.Request to drop the Decision when tests only
// care about the error path. Keeps the goroutine bodies tight.
func requestErr(b *Broker, ctx context.Context, p Prompt) error {
_, err := b.Request(ctx, p)
return err
}
// TestRequestViewOnly checks the view-only outcome flows through Request's
// Decision return without being silently swallowed.
func TestRequestViewOnly(t *testing.T) {
pub := &fakePublisher{subscribers: true}
b := New(pub)
type result struct {
d Decision
err error
}
done := make(chan result, 1)
go func() {
d, err := b.Request(context.Background(), Prompt{Kind: KindVNC})
done <- result{d, err}
}()
id := waitForRequestID(t, pub)
require.True(t, b.Respond(id, Decision{Accept: true, ViewOnly: true}))
select {
case r := <-done:
assert.NoError(t, r.err)
assert.True(t, r.d.Accept)
assert.True(t, r.d.ViewOnly, "ViewOnly must survive the round-trip")
case <-time.After(time.Second):
t.Fatal("view-only request did not resolve")
}
}

@@ -1,62 +0,0 @@
package approval
import "testing"
// TestShortKeyFingerprint locks in the format the VNC approval prompt
// shows to the user. The fingerprint is the user's only cryptographic
// anchor against a malicious management server that pushes a spoofed
// display name, so accidental changes to its format would silently
// undermine that defence.
func TestShortKeyFingerprint(t *testing.T) {
cases := []struct {
name string
in string
want string
}{
{
name: "full_32_byte_pubkey",
in: "0123456789abcdeffedcba9876543210ffeeddccbbaa99887766554433221100",
want: "0123-4567-89ab-cdef",
},
{
name: "exactly_16_chars",
in: "0123456789abcdef",
want: "0123-4567-89ab-cdef",
},
{
name: "borderline_8_chars",
in: "01234567",
want: "0123-4567",
},
{
name: "too_short_returns_empty",
in: "0123",
want: "",
},
{
name: "empty_returns_empty",
in: "",
want: "",
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got := ShortKeyFingerprint(tc.in)
if got != tc.want {
t.Fatalf("ShortKeyFingerprint(%q) = %q, want %q", tc.in, got, tc.want)
}
})
}
}
// TestShortKeyFingerprint_DistinctKeysDistinctOutputs guards against a
// formatting bug that would collapse different prefixes onto the same
// displayed fingerprint and let an attacker substitute their pubkey for
// a victim's while keeping the prompt visually identical.
func TestShortKeyFingerprint_DistinctKeysDistinctOutputs(t *testing.T) {
a := ShortKeyFingerprint("0123456789abcdef" + "rest_of_pubkey_ignored")
b := ShortKeyFingerprint("0123456789abcde0" + "rest_of_pubkey_ignored")
if a == b {
t.Fatalf("expected distinct outputs for distinct prefixes, both = %q", a)
}
}
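
The table above implies a simple format: take the first 16 characters of the key, group them in fours, and join the groups with dashes, returning an empty string for inputs shorter than 8 characters. A sketch that satisfies those five cases could look like the following. `shortFingerprint` is a hypothetical stand-in, not the removed `ShortKeyFingerprint`, and its handling of lengths between 9 and 15 (dropping the trailing partial group) is an assumption the table does not cover.

```go
package main

import (
	"fmt"
	"strings"
)

// shortFingerprint is a hypothetical reimplementation that matches the
// test table: at most 16 leading characters, dash-separated in groups
// of four, and "" for keys shorter than 8 characters.
func shortFingerprint(key string) string {
	const group, minLen, maxLen = 4, 8, 16
	if len(key) < minLen {
		return ""
	}
	if len(key) > maxLen {
		key = key[:maxLen]
	}
	// Assumption: drop a trailing partial group, since the table has no
	// case for lengths that are not a multiple of four.
	key = key[:len(key)-len(key)%group]

	parts := make([]string, 0, len(key)/group)
	for i := 0; i < len(key); i += group {
		parts = append(parts, key[i:i+group])
	}
	return strings.Join(parts, "-")
}

func main() {
	fmt.Println(shortFingerprint("0123456789abcdeffedcba9876543210"))
	fmt.Println(shortFingerprint("01234567"))
}
```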

@@ -315,7 +315,6 @@ func (a *Auth) setSystemInfoFlags(info *system.Info) {
a.config.RosenpassEnabled,
a.config.RosenpassPermissive,
a.config.ServerSSHAllowed,
a.config.ServerVNCAllowed,
a.config.DisableClientRoutes,
a.config.DisableServerRoutes,
a.config.DisableDNS,

@@ -568,8 +568,6 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
RosenpassEnabled: config.RosenpassEnabled,
RosenpassPermissive: config.RosenpassPermissive,
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
ServerVNCAllowed: config.ServerVNCAllowed != nil && *config.ServerVNCAllowed,
DisableVNCApproval: config.DisableVNCApproval,
EnableSSHRoot: config.EnableSSHRoot,
EnableSSHSFTP: config.EnableSSHSFTP,
EnableSSHLocalPortForwarding: config.EnableSSHLocalPortForwarding,
@@ -652,7 +650,6 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
config.RosenpassEnabled,
config.RosenpassPermissive,
config.ServerSSHAllowed,
config.ServerVNCAllowed,
config.DisableClientRoutes,
config.DisableServerRoutes,
config.DisableDNS,

@@ -516,6 +516,14 @@ func (g *BundleGenerator) addConfig() error {
}
}
// Surface the set of MDM-enforced keys so a support engineer reading
// the bundle can tell which field values are user-set vs MDM-overridden.
// Same semantics as the mDMManagedFields list returned by the
// GetConfig RPC consumed by `netbird debug config`.
if managed := g.internalConfig.Policy().ManagedKeys(); len(managed) > 0 {
configContent.WriteString(fmt.Sprintf("MDMManagedFields: %v\n", managed))
}
configReader := strings.NewReader(configContent.String())
if err := g.addFileToZip(configReader, "config.txt"); err != nil {
return fmt.Errorf("add config file to zip: %w", err)
@@ -644,12 +652,6 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
if g.internalConfig.SSHJWTCacheTTL != nil {
configContent.WriteString(fmt.Sprintf("SSHJWTCacheTTL: %d\n", *g.internalConfig.SSHJWTCacheTTL))
}
if g.internalConfig.ServerVNCAllowed != nil {
configContent.WriteString(fmt.Sprintf("ServerVNCAllowed: %v\n", *g.internalConfig.ServerVNCAllowed))
}
if g.internalConfig.DisableVNCApproval != nil {
configContent.WriteString(fmt.Sprintf("DisableVNCApproval: %v\n", *g.internalConfig.DisableVNCApproval))
}
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))

@@ -843,6 +843,7 @@ func TestAddConfig_AllFieldsCovered(t *testing.T) {
"PreSharedKey": "sensitive: WireGuard pre-shared key",
"SSHKey": "sensitive: SSH private key",
"ClientCertKeyPair": "non-config: parsed cert pair, not serialized",
"policy": "non-config: in-memory MDM policy snapshot, surfaced via Config.Policy() / GetConfigResponse.MDMManagedFields",
}
mURL, _ := url.Parse("https://api.example.com:443")
@@ -862,8 +863,6 @@ func TestAddConfig_AllFieldsCovered(t *testing.T) {
RosenpassEnabled: true,
RosenpassPermissive: true,
ServerSSHAllowed: &bTrue,
ServerVNCAllowed: &bTrue,
DisableVNCApproval: &bTrue,
EnableSSHRoot: &bTrue,
EnableSSHSFTP: &bTrue,
EnableSSHLocalPortForwarding: &bTrue,

@@ -482,7 +482,7 @@ func (d *Resolver) logDNSError(logger *log.Entry, hostname string, qtype uint16,
// completely when every proxy peer is offline (the upstream may still
// be reachable some other way, or the peerstore may be stale).
func (d *Resolver) filterDisconnectedPeerAnswers(logger *log.Entry, question dns.Question, records []dns.RR) []dns.RR {
if len(records) == 0 {
if len(records) < 2 {
return records
}
d.mu.RLock()

@@ -2738,6 +2738,17 @@ func TestLocalResolver_FilterDisconnectedPeerAnswers(t *testing.T) {
connByIP: nil,
wantInOrder: []string{"100.64.0.10", "100.64.0.11"},
},
{
// A single answer is never filtered: dropping it would only
// trigger the empty-answer escape hatch, so the fast path
// returns it untouched.
name: "single disconnected answer passes through",
records: []nbdns.SimpleRecord{disconnectedRec},
connByIP: map[string]ipState{
"100.64.0.11": {known: true, connected: false},
},
wantInOrder: []string{"100.64.0.11"},
},
}
for _, tc := range tests {
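
The new test case describes why a single answer skips filtering: removing it would leave an empty set, and the empty-answer escape hatch would hand the original back anyway. A reduced sketch of that reasoning is below. `filterDisconnected` and its `map[string]bool` peer state are hypothetical simplifications; the real resolver works on `dns.RR` records and a peer store, and keeping answers for unknown IPs is an assumption here.

```go
package main

import "fmt"

// filterDisconnected is a hypothetical simplification of the resolver
// filter. connected maps a known peer IP to its connection state; IPs
// missing from the map are treated as unknown and kept (assumption).
func filterDisconnected(ips []string, connected map[string]bool) []string {
	// Fast path: with zero or one answer, filtering can only return the
	// input or an empty set that the escape hatch below would undo.
	if len(ips) < 2 {
		return ips
	}

	kept := make([]string, 0, len(ips))
	for _, ip := range ips {
		if up, known := connected[ip]; known && !up {
			continue // known peer that is currently offline
		}
		kept = append(kept, ip)
	}

	// Escape hatch: never answer with nothing just because every
	// candidate looks offline.
	if len(kept) == 0 {
		return ips
	}
	return kept
}

func main() {
	state := map[string]bool{"100.64.0.10": true, "100.64.0.11": false}
	fmt.Println(filterDisconnected([]string{"100.64.0.10", "100.64.0.11"}, state))
	fmt.Println(filterDisconnected([]string{"100.64.0.11"}, state))
}
```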

@@ -34,7 +34,6 @@ import (
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/approval"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/internal/dns"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
@@ -125,8 +124,6 @@ type EngineConfig struct {
RosenpassPermissive bool
ServerSSHAllowed bool
ServerVNCAllowed bool
DisableVNCApproval *bool
EnableSSHRoot *bool
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
@@ -212,9 +209,7 @@ type Engine struct {
networkMonitor *networkmonitor.NetworkMonitor
sshServer sshServer
vncSrv vncServer
approvalBroker *approval.Broker
sshServer sshServer
statusRecorder *peer.Status
@@ -245,7 +240,7 @@ type Engine struct {
syncStore syncstore.Store
syncStoreDir string
flowManager nftypes.FlowManager
flowManager nftypes.FlowManager
// auto-update
updateManager *updater.Manager
@@ -300,7 +295,6 @@ func NewEngine(
TURNs: []*stun.URI{},
networkSerial: 0,
statusRecorder: services.StatusRecorder,
approvalBroker: approval.New(services.StatusRecorder),
stateManager: services.StateManager,
portForwardManager: portforward.NewManager(),
checks: services.Checks,
@@ -337,10 +331,6 @@ func (e *Engine) Stop() error {
log.Warnf("failed to stop SSH server: %v", err)
}
if err := e.stopVNCServer(); err != nil {
log.Warnf("failed to stop VNC server: %v", err)
}
e.cleanupSSHConfig()
if e.ingressGatewayMgr != nil {
@@ -541,6 +531,10 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
return fmt.Errorf("create wg interface: %w", err)
}
if filteredDevice := e.wgInterface.GetDevice(); filteredDevice != nil {
filteredDevice.SetPanicHandler(e.triggerClientRestart)
}
if err := e.createFirewall(); err != nil {
e.close()
return err
@@ -890,62 +884,25 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate)
}
if update.GetNetbirdConfig() != nil {
wCfg := update.GetNetbirdConfig()
err := e.updateTURNs(wCfg.GetTurns())
if err != nil {
return fmt.Errorf("update TURNs: %w", err)
}
if err := e.updateNetbirdConfig(update.GetNetbirdConfig()); err != nil {
return err
}
err = e.updateSTUNs(wCfg.GetStuns())
if err != nil {
return fmt.Errorf("update STUNs: %w", err)
}
var stunTurn []*stun.URI
stunTurn = append(stunTurn, e.STUNs...)
stunTurn = append(stunTurn, e.TURNs...)
e.stunTurn.Store(stunTurn)
err = e.handleRelayUpdate(wCfg.GetRelay())
if err != nil {
return err
}
err = e.handleFlowUpdate(wCfg.GetFlow())
if err != nil {
return fmt.Errorf("handle the flow configuration: %w", err)
}
if err := e.PopulateNetbirdConfig(wCfg, nil); err != nil {
log.Warnf("Failed to update DNS server config: %v", err)
}
// todo update signal
// Posture checks are bound to the network map presence:
// NetworkMap != nil, checks present -> apply the received checks
// NetworkMap != nil, checks nil -> posture checks were removed, clear them
// NetworkMap == nil -> config-only update (e.g. relay token rotation),
// leave the previously applied checks untouched
nm := update.GetNetworkMap()
if nm == nil {
return nil
}
if err := e.updateChecksIfNew(update.Checks); err != nil {
return err
}
nm := update.GetNetworkMap()
if nm == nil {
return nil
}
// Persist sync response under the dedicated lock (syncRespMux), not under syncMsgMux.
// A non-nil syncStore is what marks persistence as enabled. Hold the lock for
// the whole Set so the store cannot be cleared (disabled / engine close)
// mid-call and have this write resurrect a file that was just removed.
e.syncRespMux.RLock()
if e.syncStore != nil {
if err := e.syncStore.Set(update); err != nil {
log.Errorf("failed to persist sync response: %v", err)
} else {
log.Debugf("sync response persisted with serial %d", nm.GetSerial())
}
}
e.syncRespMux.RUnlock()
e.persistSyncResponse(update)
// only apply new changes and ignore old ones
if err := e.updateNetworkMap(nm); err != nil {
@@ -957,6 +914,64 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return nil
}
// updateNetbirdConfig applies the management-provided NetBird configuration:
// STUN/TURN and relay servers, flow logging and DNS settings. A nil config is a no-op,
// which is the case for sync updates carrying only a network map.
func (e *Engine) updateNetbirdConfig(wCfg *mgmProto.NetbirdConfig) error {
if wCfg == nil {
return nil
}
if err := e.updateTURNs(wCfg.GetTurns()); err != nil {
return fmt.Errorf("update TURNs: %w", err)
}
if err := e.updateSTUNs(wCfg.GetStuns()); err != nil {
return fmt.Errorf("update STUNs: %w", err)
}
var stunTurn []*stun.URI
stunTurn = append(stunTurn, e.STUNs...)
stunTurn = append(stunTurn, e.TURNs...)
e.stunTurn.Store(stunTurn)
if err := e.handleRelayUpdate(wCfg.GetRelay()); err != nil {
return err
}
if err := e.handleFlowUpdate(wCfg.GetFlow()); err != nil {
return fmt.Errorf("handle the flow configuration: %w", err)
}
if err := e.PopulateNetbirdConfig(wCfg, nil); err != nil {
log.Warnf("Failed to update DNS server config: %v", err)
}
// todo update signal
return nil
}
// persistSyncResponse stores the full sync response so it can be restored on the next
// startup. Persistence is enabled only when syncStore is set. The dedicated syncRespMux
// (not syncMsgMux) is held for the whole Set so the store cannot be cleared (disabled /
// engine close) mid-call and have this write resurrect a file that was just removed.
func (e *Engine) persistSyncResponse(update *mgmProto.SyncResponse) {
e.syncRespMux.RLock()
defer e.syncRespMux.RUnlock()
if e.syncStore == nil {
return
}
if err := e.syncStore.Set(update); err != nil {
log.Errorf("failed to persist sync response: %v", err)
return
}
log.Debugf("sync response persisted with serial %d", update.GetNetworkMap().GetSerial())
}
func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error {
if update != nil {
// when we receive token we expect valid address list too
@@ -1030,7 +1045,6 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
e.config.RosenpassEnabled,
e.config.RosenpassPermissive,
&e.config.ServerSSHAllowed,
&e.config.ServerVNCAllowed,
e.config.DisableClientRoutes,
e.config.DisableServerRoutes,
e.config.DisableDNS,
@@ -1078,10 +1092,6 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
}
}
if err := e.updateVNC(); err != nil {
log.Warnf("failed handling VNC server setup: %v", err)
}
state := e.statusRecorder.GetLocalPeerState()
state.IP = e.wgInterface.Address().String()
state.IPv6 = e.wgInterface.Address().IPv6String()
@@ -1209,7 +1219,6 @@ func (e *Engine) receiveManagementEvents() {
e.config.RosenpassEnabled,
e.config.RosenpassPermissive,
&e.config.ServerSSHAllowed,
&e.config.ServerVNCAllowed,
e.config.DisableClientRoutes,
e.config.DisableServerRoutes,
e.config.DisableDNS,
@@ -1399,11 +1408,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
e.updateSSHServerAuth(networkMap.GetSshAuth())
}
// VNC auth: always sync, including nil so cleared auth on the management
// side is applied locally, and so it isn't skipped on the RemotePeersIsEmpty
// cleanup path.
e.updateVNCServerAuth(networkMap.GetVncAuth())
// must set the exclude list after the peers are added. Without it the manager cannot figure out the peers' parameters from the store
excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, remotePeers)
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
@@ -1871,7 +1875,6 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
e.config.RosenpassEnabled,
e.config.RosenpassPermissive,
&e.config.ServerSSHAllowed,
&e.config.ServerVNCAllowed,
e.config.DisableClientRoutes,
e.config.DisableServerRoutes,
e.config.DisableDNS,
@@ -2656,16 +2659,3 @@ func decodeRelayIP(b []byte) netip.Addr {
}
return ip.Unmap()
}
// RespondApproval relays the user's decision for a pending approval to
// the broker. viewOnly is honoured only when accept is true. Returns
// true when the request_id matched a live prompt.
func (e *Engine) RespondApproval(requestID string, accept, viewOnly bool) bool {
if e == nil || e.approvalBroker == nil {
return false
}
return e.approvalBroker.Respond(requestID, approval.Decision{
Accept: accept,
ViewOnly: accept && viewOnly,
})
}
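
One detail of the `persistSyncResponse` hunk above: it logs the serial through `update.GetNetworkMap().GetSerial()` instead of a local `nm`. Chaining like this is safe on Go protobuf messages because generated getters guard against a nil receiver. The sketch below reproduces that guard with hand-written types; `syncResponse` and `networkMap` are hypothetical stand-ins, not the generated `mgmProto` types.

```go
package main

import "fmt"

// networkMap and syncResponse are hypothetical stand-ins for generated
// protobuf messages; only the getter shape matters here.
type networkMap struct{ serial uint64 }

// GetSerial mirrors the nil-receiver guard that protoc-gen-go emits.
func (m *networkMap) GetSerial() uint64 {
	if m == nil {
		return 0
	}
	return m.serial
}

type syncResponse struct{ networkMap *networkMap }

// GetNetworkMap returns nil for a nil receiver instead of panicking.
func (r *syncResponse) GetNetworkMap() *networkMap {
	if r == nil {
		return nil
	}
	return r.networkMap
}

func main() {
	// A config-only update carries no network map; the chain yields the
	// zero value rather than dereferencing nil.
	configOnly := &syncResponse{}
	fmt.Println(configOnly.GetNetworkMap().GetSerial())

	full := &syncResponse{networkMap: &networkMap{serial: 42}}
	fmt.Println(full.GetNetworkMap().GetSerial())
}
```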

@@ -12,10 +12,10 @@ import (
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/netstack"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
sshserver "github.com/netbirdio/netbird/client/ssh/server"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
sshauth "github.com/netbirdio/netbird/shared/sessionauth"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)
@@ -237,18 +237,22 @@ func (e *Engine) startSSHServer(jwtConfig *sshserver.JWTConfig) error {
return errors.New("wg interface not initialized")
}
wgAddr := e.wgInterface.Address()
serverConfig := &sshserver.Config{
HostKeyPEM: e.config.SSHKey,
JWT: jwtConfig,
NetstackNet: e.wgInterface.GetNet(),
NetworkValidation: wgAddr,
HostKeyPEM: e.config.SSHKey,
JWT: jwtConfig,
}
server := sshserver.New(serverConfig)
wgAddr := e.wgInterface.Address()
server.SetNetworkValidation(wgAddr)
netbirdIP := wgAddr.IP
listenAddr := netip.AddrPortFrom(netbirdIP, sshserver.InternalSSHPort)
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
server.SetNetstackNet(netstackNet)
}
e.configureSSHServer(server)
if err := server.Start(e.ctx, listenAddr); err != nil {

@@ -1,302 +0,0 @@
//go:build !js && !ios && !android
package internal
import (
"context"
"errors"
"fmt"
"net/netip"
log "github.com/sirupsen/logrus"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/approval"
"github.com/netbirdio/netbird/client/internal/metrics"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/vnc"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
sshauth "github.com/netbirdio/netbird/shared/sessionauth"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)
type vncServer interface {
Start(ctx context.Context, addr netip.AddrPort, network netip.Prefix) error
Stop() error
ActiveSessions() []vncserver.ActiveSessionInfo
}
func (e *Engine) setupVNCPortRedirection() error {
if e.firewall == nil || e.wgInterface == nil {
return nil
}
localAddr := e.wgInterface.Address().IP
if !localAddr.IsValid() {
return errors.New("invalid local NetBird address")
}
if err := e.firewall.AddInboundDNAT(localAddr, firewallManager.ProtocolTCP, vnc.ExternalPort, vnc.InternalPort); err != nil {
return fmt.Errorf("add VNC port redirection: %w", err)
}
log.Infof("VNC port redirection: %s:%d -> %s:%d", localAddr, vnc.ExternalPort, localAddr, vnc.InternalPort)
return nil
}
func (e *Engine) cleanupVNCPortRedirection() error {
if e.firewall == nil || e.wgInterface == nil {
return nil
}
localAddr := e.wgInterface.Address().IP
if !localAddr.IsValid() {
return errors.New("invalid local NetBird address")
}
if err := e.firewall.RemoveInboundDNAT(localAddr, firewallManager.ProtocolTCP, vnc.ExternalPort, vnc.InternalPort); err != nil {
return fmt.Errorf("remove VNC port redirection: %w", err)
}
return nil
}
// updateVNC handles starting/stopping the VNC server based on the config flag.
func (e *Engine) updateVNC() error {
if !e.config.ServerVNCAllowed {
if e.vncSrv != nil {
log.Info("VNC server disabled, stopping")
}
return e.stopVNCServer()
}
if e.config.BlockInbound {
log.Info("VNC server disabled because inbound connections are blocked")
return e.stopVNCServer()
}
if e.vncSrv != nil {
return nil
}
return e.startVNCServer()
}
func (e *Engine) startVNCServer() error {
if e.wgInterface == nil {
return errors.New("wg interface not initialized")
}
capturer, injector, ok := newPlatformVNC()
if !ok {
log.Debug("VNC server not supported on this platform")
return nil
}
netbirdIP := e.wgInterface.Address().IP
var sessionRecorder func(vncserver.SessionTick)
if e.clientMetrics != nil {
sessionRecorder = func(t vncserver.SessionTick) {
e.clientMetrics.RecordVNCSessionTick(e.ctx, metrics.VNCSessionTick{
Period: t.Period,
BytesOut: t.BytesOut,
Writes: t.Writes,
FBUs: t.FBUs,
MaxFBUBytes: t.MaxFBUBytes,
MaxFBURects: t.MaxFBURects,
MaxWriteBytes: t.MaxWriteBytes,
WriteNanos: t.WriteNanos,
})
}
}
serviceMode := vncNeedsServiceMode()
if serviceMode {
log.Info("VNC: running as system service, enabling service mode (per-session agent proxy)")
}
requireApproval := e.config.DisableVNCApproval == nil || !*e.config.DisableVNCApproval
srv := vncserver.New(vncserver.Config{
Capturer: capturer,
Injector: injector,
IdentityKey: e.config.WgPrivateKey[:],
ServiceMode: serviceMode,
SessionRecorder: sessionRecorder,
NetstackNet: e.wgInterface.GetNet(),
RequireApproval: requireApproval,
Approver: &vncApprover{broker: e.approvalBroker, statusRecorder: e.statusRecorder},
})
listenAddr := netip.AddrPortFrom(netbirdIP, vnc.InternalPort)
network := e.wgInterface.Address().Network
if err := srv.Start(e.ctx, listenAddr, network); err != nil {
return fmt.Errorf("start VNC server: %w", err)
}
e.vncSrv = srv
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
if registrar, ok := e.firewall.(interface {
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
}); ok {
registrar.RegisterNetstackService(nftypes.TCP, vnc.InternalPort)
log.Debugf("registered VNC service with netstack for TCP:%d", vnc.InternalPort)
}
}
if err := e.setupVNCPortRedirection(); err != nil {
log.Warnf("setup VNC port redirection: %v", err)
}
log.Info("VNC server enabled")
return nil
}
// updateVNCServerAuth updates VNC fine-grained access control from management.
// A nil vncAuth clears all authorized users and session pubkeys so management
// can revoke access by omitting the field on the next sync.
func (e *Engine) updateVNCServerAuth(vncAuth *mgmProto.VNCAuth) {
if e.vncSrv == nil {
return
}
vncSrv, ok := e.vncSrv.(*vncserver.Server)
if !ok {
return
}
if vncAuth == nil {
vncSrv.UpdateVNCAuth(&sshauth.Config{})
return
}
protoUsers := vncAuth.GetAuthorizedUsers()
authorizedUsers := make([]sshuserhash.UserIDHash, len(protoUsers))
for i, hash := range protoUsers {
if len(hash) != 16 {
log.Warnf("invalid VNC auth hash length %d, expected 16", len(hash))
return
}
authorizedUsers[i] = sshuserhash.UserIDHash(hash)
}
machineUsers := make(map[string][]uint32)
for osUser, indexes := range vncAuth.GetMachineUsers() {
machineUsers[osUser] = indexes.GetIndexes()
}
sessionPubKeys := make([]sshauth.SessionPubKey, 0, len(vncAuth.GetSessionPubKeys()))
for _, pk := range vncAuth.GetSessionPubKeys() {
pub := pk.GetPubKey()
if len(pub) != 32 {
log.Warnf("VNC session pubkey wrong length %d", len(pub))
continue
}
hash := pk.GetUserIdHash()
if len(hash) != 16 {
log.Warnf("VNC session user id hash wrong length %d", len(hash))
continue
}
sessionPubKeys = append(sessionPubKeys, sshauth.SessionPubKey{
PubKey: pub,
UserIDHash: sshuserhash.UserIDHash(hash),
DisplayName: pk.GetDisplayName(),
})
}
vncSrv.UpdateVNCAuth(&sshauth.Config{
AuthorizedUsers: authorizedUsers,
MachineUsers: machineUsers,
SessionPubKeys: sessionPubKeys,
})
}
// GetVNCServerStatus returns whether the VNC server is running and the list
// of active VNC sessions. The pointer is captured under syncMsgMux so a
// concurrent updateVNC/stopVNCServer cannot swap it out between the nil
// check and the ActiveSessions call.
func (e *Engine) GetVNCServerStatus() (enabled bool, sessions []vncserver.ActiveSessionInfo) {
e.syncMsgMux.Lock()
vncSrv := e.vncSrv
e.syncMsgMux.Unlock()
if vncSrv == nil {
return false, nil
}
return true, vncSrv.ActiveSessions()
}
func (e *Engine) stopVNCServer() error {
if e.vncSrv == nil {
return nil
}
if err := e.cleanupVNCPortRedirection(); err != nil {
log.Warnf("cleanup VNC port redirection: %v", err)
}
if e.wgInterface != nil && e.wgInterface.GetNet() != nil {
if registrar, ok := e.firewall.(interface {
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
}); ok {
registrar.UnregisterNetstackService(nftypes.TCP, vnc.InternalPort)
}
}
log.Info("stopping VNC server")
err := e.vncSrv.Stop()
e.vncSrv = nil
if err != nil {
return fmt.Errorf("stop VNC server: %w", err)
}
return nil
}
// vncApprover adapts the generic approval.Broker for the VNC server.
type vncApprover struct {
broker *approval.Broker
statusRecorder *peer.Status
}
func (a *vncApprover) Request(ctx context.Context, info vncserver.ApprovalInfo) (vncserver.ApprovalDecision, error) {
// Resolve the source overlay IP to a peer FQDN for the prompt label.
if info.PeerName == "" && info.SourceIP != "" && a.statusRecorder != nil {
if fqdn, ok := a.statusRecorder.PeerByIP(info.SourceIP); ok {
info.PeerName = fqdn
}
}
subject := fmt.Sprintf("VNC connection from %s", displayPeer(info))
meta := map[string]string{
"peer_name": info.PeerName,
"peer_pubkey": info.PeerPubKey,
"source_ip": info.SourceIP,
"mode": info.Mode,
"username": info.Username,
"initiator": info.Initiator,
}
d, err := a.broker.Request(ctx, approval.Prompt{
Kind: approval.KindVNC,
Subject: subject,
Metadata: meta,
})
if err != nil {
return vncserver.ApprovalDecision{}, err
}
return vncserver.ApprovalDecision{ViewOnly: d.ViewOnly}, nil
}
func displayPeer(info vncserver.ApprovalInfo) string {
if info.Initiator != "" {
return info.Initiator
}
if info.PeerName != "" {
return info.PeerName
}
if info.SourceIP != "" {
return info.SourceIP
}
if info.PeerPubKey != "" {
return info.PeerPubKey
}
return "unknown peer"
}
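
The removed code reads two optional `*bool` settings with opposite defaults: `ServerVNCAllowed` counts as off unless it is explicitly true, while `DisableVNCApproval` leaves approval on unless it is explicitly true. The sketch below isolates the two expressions; the helper names `allowedOrFalse` and `requireApproval` are hypothetical and do not appear in the diff.

```go
package main

import "fmt"

// allowedOrFalse treats an unset (nil) flag as false, matching
// `config.ServerVNCAllowed != nil && *config.ServerVNCAllowed`.
func allowedOrFalse(v *bool) bool {
	return v != nil && *v
}

// requireApproval keeps approval on for nil and false, matching
// `DisableVNCApproval == nil || !*DisableVNCApproval`. Only an explicit
// true turns the prompt off.
func requireApproval(disable *bool) bool {
	return disable == nil || !*disable
}

func main() {
	yes, no := true, false
	for _, v := range []*bool{nil, &no, &yes} {
		fmt.Println(allowedOrFalse(v), requireApproval(v))
	}
}
```

Both defaults lean toward the safer state: an absent setting neither starts the server nor skips the prompt.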

@@ -1,31 +0,0 @@
//go:build freebsd
package internal
import (
"fmt"
log "github.com/sirupsen/logrus"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
)
// newConsoleVNC builds the FreeBSD console fallback: vt(4) framebuffer
// for capture, /dev/uinput for input. The uinput device requires the
// `uinput` kernel module (`kldload uinput`); without it, input init
// fails and we drop to a stub injector so the user still gets a
// view-only screen mirror.
func newConsoleVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, error) {
poller := vncserver.NewFBPoller("")
w, h := poller.Width(), poller.Height()
if w == 0 || h == 0 {
poller.Close()
return nil, nil, fmt.Errorf("vt framebuffer init failed (vt may not allow mmap on this driver)")
}
if inj, err := vncserver.NewUInputInjector(w, h); err == nil {
return poller, inj, nil
} else {
log.Infof("VNC console: uinput unavailable (%v); view-only mode. Run `kldload uinput` to enable input.", err)
return poller, &vncserver.StubInputInjector{}, nil
}
}

@@ -1,30 +0,0 @@
//go:build linux && !android
package internal
import (
"fmt"
log "github.com/sirupsen/logrus"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
)
// newConsoleVNC builds a framebuffer + uinput VNC backend for boxes
// without a running X server. Used as the auto-fallback when
// newPlatformVNC can't reach X. Returns an error when /dev/fb0 or
// /dev/uinput aren't usable so the caller can drop back to a stub.
func newConsoleVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, error) {
poller := vncserver.NewFBPoller("")
w, h := poller.Width(), poller.Height()
if w == 0 || h == 0 {
poller.Close()
return nil, nil, fmt.Errorf("framebuffer capturer init failed (is /dev/fb0 readable?)")
}
inj, err := vncserver.NewUInputInjector(w, h)
if err != nil {
log.Debugf("uinput unavailable, falling back to view-only VNC: %v", err)
return poller, &vncserver.StubInputInjector{}, nil
}
return poller, inj, nil
}


@@ -1,34 +0,0 @@
//go:build darwin && !ios
package internal
import (
"os"
log "github.com/sirupsen/logrus"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
)
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, bool) {
capturer := vncserver.NewMacPoller()
// Prompt for Screen Recording at server-enable time rather than first
// client-connect. The native prompt is far easier for users to act on
// in the moment they toggled VNC on than later when "the screen looks
// like wallpaper" would otherwise be the only clue.
vncserver.PrimeScreenCapturePermission()
injector, err := vncserver.NewMacInputInjector()
if err != nil {
log.Debugf("VNC: macOS input injector: %v", err)
return capturer, &vncserver.StubInputInjector{}, true
}
return capturer, injector, true
}
// vncNeedsServiceMode reports whether the running process is a system
// LaunchDaemon (root, parented by launchd). Daemons sit in the global
// bootstrap namespace and cannot talk to WindowServer; we route capture
// through a per-user agent in that case.
func vncNeedsServiceMode() bool {
return os.Geteuid() == 0 && os.Getppid() == 1
}


@@ -1,23 +0,0 @@
//go:build js || ios || android
package internal
import (
log "github.com/sirupsen/logrus"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
type vncServer interface{}
func (e *Engine) updateVNC() error { return nil }
func (e *Engine) updateVNCServerAuth(auth *mgmProto.VNCAuth) {
if auth == nil {
return
}
log.Debugf("ignoring VNC auth push on platform without a VNC server: %d session pubkeys, %d authorized users",
len(auth.GetSessionPubKeys()), len(auth.GetAuthorizedUsers()))
}
func (e *Engine) stopVNCServer() error { return nil }


@@ -1,13 +0,0 @@
//go:build windows
package internal
import vncserver "github.com/netbirdio/netbird/client/vnc/server"
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, bool) {
return vncserver.NewDesktopCapturer(), vncserver.NewWindowsInputInjector(), true
}
func vncNeedsServiceMode() bool {
return vncserver.GetCurrentSessionID() == 0
}


@@ -1,35 +0,0 @@
//go:build (linux && !android) || freebsd
package internal
import (
log "github.com/sirupsen/logrus"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
)
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, bool) {
// Prefer X11 when an X server is reachable. NewX11InputInjector probes
// DISPLAY (and /proc) eagerly, so a non-nil error here means no X.
injector, err := vncserver.NewX11InputInjector("", "", "")
if err == nil {
return vncserver.NewX11Poller("", ""), injector, true
}
log.Debugf("VNC: X11 not available: %v", err)
// Fallback for headless / pre-X states (kernel console, login manager
// without X, physical server in recovery): stream the framebuffer and
// inject input via /dev/uinput.
consoleCap, consoleInj, err := newConsoleVNC()
if err == nil {
log.Infof("VNC: using framebuffer console capture (%dx%d)", consoleCap.Width(), consoleCap.Height())
return consoleCap, consoleInj, true
}
log.Debugf("VNC: framebuffer console fallback unavailable: %v", err)
return &vncserver.StubCapturer{}, &vncserver.StubInputInjector{}, false
}
func vncNeedsServiceMode() bool {
return false
}


@@ -120,36 +120,6 @@ func (m *influxDBMetrics) RecordSyncDuration(_ context.Context, agentInfo AgentI
m.trimLocked()
}
func (m *influxDBMetrics) RecordVNCSessionTick(_ context.Context, agentInfo AgentInfo, tick VNCSessionTick) {
tags := fmt.Sprintf("deployment_type=%s,version=%s,os=%s,arch=%s,peer_id=%s",
agentInfo.DeploymentType.String(),
agentInfo.Version,
agentInfo.OS,
agentInfo.Arch,
agentInfo.peerID,
)
m.mu.Lock()
defer m.mu.Unlock()
m.samples = append(m.samples, influxSample{
measurement: "netbird_vnc_traffic",
tags: tags,
fields: map[string]float64{
"period_seconds": tick.Period.Seconds(),
"bytes_out": float64(tick.BytesOut),
"writes": float64(tick.Writes),
"fbus": float64(tick.FBUs),
"max_fbu_bytes": float64(tick.MaxFBUBytes),
"max_fbu_rects": float64(tick.MaxFBURects),
"max_write_bytes": float64(tick.MaxWriteBytes),
"write_time_seconds": float64(tick.WriteNanos) / 1e9,
},
timestamp: time.Now(),
})
m.trimLocked()
}
func (m *influxDBMetrics) RecordLoginDuration(_ context.Context, agentInfo AgentInfo, duration time.Duration, success bool) {
result := "success"
if !success {


@@ -59,11 +59,6 @@ type metricsImplementation interface {
// RecordLoginDuration records how long the login to management took
RecordLoginDuration(ctx context.Context, agentInfo AgentInfo, duration time.Duration, success bool)
// RecordVNCSessionTick records a periodic snapshot of one VNC
// session's wire activity. Called once per metricsConn tick interval
// (and once at session close), only when the tick saw activity.
RecordVNCSessionTick(ctx context.Context, agentInfo AgentInfo, tick VNCSessionTick)
// Export exports metrics in InfluxDB line protocol format
Export(w io.Writer) error
@@ -83,21 +78,6 @@ type ClientMetrics struct {
pushCancel context.CancelFunc
}
// VNCSessionTick is one sampling slice of a VNC session's wire activity.
// BytesOut / Writes / FBUs / WriteNanos are deltas observed during this
// tick; Max* fields are the high-water marks observed during the tick.
// Period is the wall-clock duration the deltas cover.
type VNCSessionTick struct {
Period time.Duration
BytesOut uint64
Writes uint64
FBUs uint64
MaxFBUBytes uint64
MaxFBURects uint64
MaxWriteBytes uint64
WriteNanos uint64
}
// ConnectionStageTimestamps holds timestamps for each connection stage
type ConnectionStageTimestamps struct {
SignalingReceived time.Time // First signal received from remote peer (both initial and reconnection)
@@ -147,17 +127,6 @@ func (c *ClientMetrics) RecordSyncDuration(ctx context.Context, duration time.Du
c.impl.RecordSyncDuration(ctx, agentInfo, duration)
}
// RecordVNCSessionTick records a periodic snapshot of one VNC session.
func (c *ClientMetrics) RecordVNCSessionTick(ctx context.Context, tick VNCSessionTick) {
if c == nil {
return
}
c.mu.RLock()
agentInfo := c.agentInfo
c.mu.RUnlock()
c.impl.RecordVNCSessionTick(ctx, agentInfo, tick)
}
// RecordLoginDuration records how long the login to management server took
func (c *ClientMetrics) RecordLoginDuration(ctx context.Context, duration time.Duration, success bool) {
if c == nil {


@@ -73,9 +73,6 @@ func (m *mockMetrics) RecordSyncDuration(_ context.Context, _ AgentInfo, _ time.
func (m *mockMetrics) RecordLoginDuration(_ context.Context, _ AgentInfo, _ time.Duration, _ bool) {
}
func (m *mockMetrics) RecordVNCSessionTick(_ context.Context, _ AgentInfo, _ VNCSessionTick) {
}
func (m *mockMetrics) Export(w io.Writer) error {
if m.exportData != "" {
_, err := w.Write([]byte(m.exportData))


@@ -26,7 +26,6 @@ type connStatusInputs struct {
iceInProgress bool // a negotiation is currently in flight
}
// ConnStatus describes the status of a peer's connection
type ConnStatus int32


@@ -193,6 +193,7 @@ func (s *StatusChangeSubscription) Events() chan map[string]RouterState {
type Status struct {
mux sync.RWMutex
peers map[string]State
ipToKey map[string]string
changeNotify map[string]map[string]*StatusChangeSubscription // map[peerID]map[subscriptionID]*StatusChangeSubscription
signalState bool
signalError error
@@ -231,6 +232,7 @@ type Status struct {
func NewRecorder(mgmAddress string) *Status {
return &Status{
peers: make(map[string]State),
ipToKey: make(map[string]string),
changeNotify: make(map[string]map[string]*StatusChangeSubscription),
eventStreams: make(map[string]chan *proto.SystemEvent),
eventQueue: NewEventQueue(eventQueueSize),
@@ -282,6 +284,12 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string, ip string, ipv6 string)
Mux: new(sync.RWMutex),
}
d.peerListChangedForNotification = true
if ipv6 != "" {
d.ipToKey[ipv6] = peerPubKey
}
if ip != "" {
d.ipToKey[ip] = peerPubKey
}
return nil
}
@@ -311,28 +319,22 @@ func (d *Status) PeerByIP(ip string) (string, bool) {
// PeerStateByIP returns the full peer State for the given tunnel IP.
// Matches against either the IPv4 (State.IP) or IPv6 (State.IPv6) tunnel
// address so dual-stack peers are reachable on either family. Searches
// both d.peers and d.offlinePeers — peers that have been moved into
// the offline slice by ReplaceOfflinePeers are still part of the
// account's roster and callers (DNS filter, embed.Client.IdentityForIP)
// need to recognise them rather than treating them as unknown. Returns
// the zero State and false when no peer matches or the input is empty.
// address so dual-stack peers are reachable on either family. Only
// active peers are matched; peers moved into the offline slice by
// ReplaceOfflinePeers are intentionally treated as unknown.
func (d *Status) PeerStateByIP(ip string) (State, bool) {
if ip == "" {
return State{}, false
}
d.mux.RLock()
defer d.mux.RUnlock()
for _, state := range d.peers {
if (state.IP != "" && state.IP == ip) || (state.IPv6 != "" && state.IPv6 == ip) {
return state, true
}
key, ok := d.ipToKey[ip]
if !ok {
return State{}, false
}
for _, state := range d.offlinePeers {
if (state.IP != "" && state.IP == ip) || (state.IPv6 != "" && state.IPv6 == ip) {
return state, true
}
state, ok := d.peers[key]
if ok {
return state, true
}
return State{}, false
}
@@ -342,12 +344,18 @@ func (d *Status) RemovePeer(peerPubKey string) error {
d.mux.Lock()
defer d.mux.Unlock()
_, ok := d.peers[peerPubKey]
p, ok := d.peers[peerPubKey]
if !ok {
return errors.New("no peer to remove")
}
delete(d.peers, peerPubKey)
if mappedKey, exists := d.ipToKey[p.IP]; exists && mappedKey == peerPubKey {
delete(d.ipToKey, p.IP)
}
if mappedKey, exists := d.ipToKey[p.IPv6]; exists && mappedKey == peerPubKey {
delete(d.ipToKey, p.IPv6)
}
d.peerListChangedForNotification = true
return nil
}
@@ -1223,15 +1231,6 @@ func (d *Status) SubscribeToEvents() *EventSubscription {
}
}
// HasEventSubscribers reports whether any client is currently subscribed
// to the daemon's SystemEvent stream. Used by the VNC approval broker to
// fail closed when no UI is connected to prompt the user.
func (d *Status) HasEventSubscribers() bool {
d.eventMux.Lock()
defer d.eventMux.Unlock()
return len(d.eventStreams) > 0
}
// UnsubscribeFromEvents removes an event subscription
func (d *Status) UnsubscribeFromEvents(sub *EventSubscription) {
if sub == nil {


@@ -90,12 +90,11 @@ func TestStatus_PeerStateByIP_MatchesIPv6(t *testing.T) {
req.Equal("pk-1", state.PubKey, "matching state must carry the right pub key")
}
// TestStatus_PeerStateByIP_MatchesOfflinePeers covers peers that have
// been moved into the offline slice via ReplaceOfflinePeers. Callers
// (DNS filter, embed.Client.IdentityForIP) need to treat them as known
// rather than unknown — otherwise authentication / DNS filtering treats
// known-but-offline peers as foreign IPs.
func TestStatus_PeerStateByIP_MatchesOfflinePeers(t *testing.T) {
// TestStatus_PeerStateByIP_IgnoresOfflinePeers documents that peers
// moved into the offline slice via ReplaceOfflinePeers are intentionally
// not resolvable by IP: only active peers can carry traffic, so callers
// (DNS filter, embed.Client.IdentityForIP) treat them as unknown.
func TestStatus_PeerStateByIP_IgnoresOfflinePeers(t *testing.T) {
status := NewRecorder("https://mgm")
req := require.New(t)
@@ -103,13 +102,31 @@ func TestStatus_PeerStateByIP_MatchesOfflinePeers(t *testing.T) {
{PubKey: "pk-offline", FQDN: "offline.netbird", IP: "100.64.0.20", IPv6: "fd00::20"},
})
state, ok := status.PeerStateByIP("100.64.0.20")
req.True(ok, "offline peer must resolve by IPv4 tunnel address")
req.Equal("pk-offline", state.PubKey, "matching state must carry the offline peer's pub key")
_, ok := status.PeerStateByIP("100.64.0.20")
req.False(ok, "offline peer must not resolve by IPv4 tunnel address")
state, ok = status.PeerStateByIP("fd00::20")
req.True(ok, "offline peer must resolve by IPv6 tunnel address")
req.Equal("pk-offline", state.PubKey, "IPv6 match must carry the offline peer's pub key")
_, ok = status.PeerStateByIP("fd00::20")
req.False(ok, "offline peer must not resolve by IPv6 tunnel address")
}
// TestStatus_PeerStateByIP_RemovedPeer verifies RemovePeer drops the
// IP index entries for both address families.
func TestStatus_PeerStateByIP_RemovedPeer(t *testing.T) {
status := NewRecorder("https://mgm")
req := require.New(t)
req.NoError(status.AddPeer("pk-1", "peer-1.netbird", "100.64.0.10", "fd00::1"))
_, ok := status.PeerStateByIP("100.64.0.10")
req.True(ok, "active peer must resolve before removal")
req.NoError(status.RemovePeer("pk-1"))
_, ok = status.PeerStateByIP("100.64.0.10")
req.False(ok, "removed peer must not resolve by IPv4 tunnel address")
_, ok = status.PeerStateByIP("fd00::1")
req.False(ok, "removed peer must not resolve by IPv6 tunnel address")
}
func TestStatus_UpdatePeerFQDN(t *testing.T) {
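
The status.go hunks above replace a linear scan over peers with an IP-to-public-key index that AddPeer fills and RemovePeer prunes. The following is a minimal standalone sketch of that idea, not the NetBird implementation: the `peer` and `registry` types and their method names are hypothetical, and locking is omitted for brevity (the real Status guards its maps with a mutex).

```go
package main

import "fmt"

// peer is a hypothetical stand-in for the State struct in the diff.
type peer struct {
	PubKey string
	IP     string
	IPv6   string
}

// registry keeps peers by public key plus a secondary index by tunnel IP.
type registry struct {
	peers   map[string]peer   // public key -> peer
	ipToKey map[string]string // IPv4 or IPv6 tunnel address -> public key
}

func newRegistry() *registry {
	return &registry{peers: map[string]peer{}, ipToKey: map[string]string{}}
}

// add stores the peer and indexes every non-empty address it owns.
func (r *registry) add(p peer) {
	r.peers[p.PubKey] = p
	for _, ip := range []string{p.IP, p.IPv6} {
		if ip != "" {
			r.ipToKey[ip] = p.PubKey
		}
	}
}

// remove deletes the peer and only those index entries that still point
// at it, so an address reassigned to another peer is left untouched.
func (r *registry) remove(pubKey string) bool {
	p, ok := r.peers[pubKey]
	if !ok {
		return false
	}
	delete(r.peers, pubKey)
	for _, ip := range []string{p.IP, p.IPv6} {
		if owner, exists := r.ipToKey[ip]; exists && owner == pubKey {
			delete(r.ipToKey, ip)
		}
	}
	return true
}

// byIP resolves an address through the index in two map lookups.
func (r *registry) byIP(ip string) (peer, bool) {
	key, ok := r.ipToKey[ip]
	if !ok {
		return peer{}, false
	}
	p, ok := r.peers[key]
	return p, ok
}

func main() {
	r := newRegistry()
	r.add(peer{PubKey: "pk-1", IP: "100.64.0.10", IPv6: "fd00::1"})
	_, found := r.byIP("fd00::1")
	fmt.Println("before removal:", found)
	r.remove("pk-1")
	_, found = r.byIP("fd00::1")
	fmt.Println("after removal:", found)
}
```

The ownership check in `remove` mirrors the `mappedKey == peerPubKey` guard in the diff; without it, removing a stale peer could drop the index entry of a newer peer that took over the same address.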


@@ -22,6 +22,7 @@ import (
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/client/mdm"
"github.com/netbirdio/netbird/client/ssh"
mgm "github.com/netbirdio/netbird/shared/management/client"
"github.com/netbirdio/netbird/shared/management/domain"
@@ -57,6 +58,10 @@ var DefaultInterfaceBlacklist = []string{
"Tailscale", "tailscale", "docker", "veth", "br-", "lo",
}
// loadMDMPolicy is the package-level indirection used by apply() to read the
// active MDM policy. Tests override this to inject a fake policy.
var loadMDMPolicy = mdm.LoadPolicy
// ConfigInput carries configuration changes to the client
type ConfigInput struct {
ManagementURL string
@@ -65,8 +70,6 @@ type ConfigInput struct {
StateFilePath string
PreSharedKey *string
ServerSSHAllowed *bool
ServerVNCAllowed *bool
DisableVNCApproval *bool
EnableSSHRoot *bool
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
@@ -118,8 +121,6 @@ type Config struct {
RosenpassEnabled bool
RosenpassPermissive bool
ServerSSHAllowed *bool
ServerVNCAllowed *bool
DisableVNCApproval *bool
EnableSSHRoot *bool
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
@@ -178,6 +179,23 @@ type Config struct {
LazyConnectionEnabled bool
MTU uint16
// policy is the MDM policy that produced the currently-set values for
// any MDM-enforced fields. Set by applyMDMPolicy at the tail of apply()
// and reset on every apply() invocation. Never persisted to disk.
// Callers query enforcement state via Policy() and the mdm.Policy API
// (HasKey, ManagedKeys, IsEmpty).
policy *mdm.Policy `json:"-"`
}
// Policy returns the MDM policy applied to this Config. Returns a non-nil
// empty Policy when MDM enforcement is inactive; callers can always invoke
// HasKey / ManagedKeys / IsEmpty without a nil check.
func (config *Config) Policy() *mdm.Policy {
if config == nil || config.policy == nil {
return mdm.NewPolicy(nil)
}
return config.policy
}
var ConfigDirOverride string
@@ -422,33 +440,6 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
if input.ServerVNCAllowed != nil {
if config.ServerVNCAllowed == nil || *input.ServerVNCAllowed != *config.ServerVNCAllowed {
if *input.ServerVNCAllowed {
log.Infof("enabling VNC server")
} else {
log.Infof("disabling VNC server")
}
config.ServerVNCAllowed = input.ServerVNCAllowed
updated = true
}
} else if config.ServerVNCAllowed == nil {
config.ServerVNCAllowed = util.False()
updated = true
}
if input.DisableVNCApproval != nil {
if config.DisableVNCApproval == nil || *input.DisableVNCApproval != *config.DisableVNCApproval {
if *input.DisableVNCApproval {
log.Infof("disabling VNC connection approval prompt")
} else {
log.Infof("enabling VNC connection approval prompt")
}
config.DisableVNCApproval = input.DisableVNCApproval
updated = true
}
}
if input.EnableSSHRoot != nil && input.EnableSSHRoot != config.EnableSSHRoot {
if *input.EnableSSHRoot {
log.Infof("enabling SSH root login")
@@ -643,10 +634,93 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
// MDM is the last override layer: any key present in the policy
// supersedes defaults, on-disk config, env vars and CLI input.
config.applyMDMPolicy(loadMDMPolicy())
return updated, nil
}
// parseURL parses and validates a service URL
// applyMDMPolicy overlays MDM-supplied values on top of the resolved Config.
// The provided Policy is also stored on the Config so callers can later query
// which fields are enforced. Invalid values (e.g. malformed URLs) are logged
// and skipped to avoid bricking the client; the field keeps its previous
// resolved value but is still marked as managed (Policy.HasKey returns true
// for the key, so per-field rejection of user writes still applies).
func (config *Config) applyMDMPolicy(policy *mdm.Policy) {
config.policy = policy
if policy.IsEmpty() {
return
}
// Helper: log the application of a single MDM-managed key. Values for
// keys in mdm.SecretKeys are redacted.
logApplied := func(key string, displayValue any) {
if _, secret := mdm.SecretKeys[key]; secret {
log.Infof("MDM override %s = ********** (secret)", key)
return
}
log.Infof("MDM override %s = %v", key, displayValue)
}
if v, ok := policy.GetString(mdm.KeyManagementURL); ok {
if u, err := parseURL("Management URL", v); err != nil {
log.Warnf("MDM management URL %q invalid: %v; keeping previous value", v, err)
} else {
config.ManagementURL = u
logApplied(mdm.KeyManagementURL, u.String())
}
}
if v, ok := policy.GetString(mdm.KeyPreSharedKey); ok {
// Defensive: refuse the redaction mask in case it round-tripped
// through a manifest by mistake.
if !isPreSharedKeyHidden(&v) {
config.PreSharedKey = v
logApplied(mdm.KeyPreSharedKey, "")
}
}
// applyBool collapses the per-key "read + set + log" boilerplate
// for every plain bool MDM key into a single helper. Keeps the
// outer function's cognitive complexity below SonarQube's
// threshold; functional behaviour is identical to the inlined
// branches it replaces.
applyBool := func(key string, setter func(bool)) {
v, ok := policy.GetBool(key)
if !ok {
return
}
setter(v)
logApplied(key, v)
}
applyBool(mdm.KeyAllowServerSSH, func(v bool) { bv := v; config.ServerSSHAllowed = &bv })
applyBool(mdm.KeyDisableClientRoutes, func(v bool) { config.DisableClientRoutes = v })
applyBool(mdm.KeyDisableServerRoutes, func(v bool) { config.DisableServerRoutes = v })
applyBool(mdm.KeyBlockInbound, func(v bool) { config.BlockInbound = v })
applyBool(mdm.KeyDisableAutoConnect, func(v bool) { config.DisableAutoConnect = v })
applyBool(mdm.KeyRosenpassEnabled, func(v bool) { config.RosenpassEnabled = v })
applyBool(mdm.KeyRosenpassPermissive, func(v bool) { config.RosenpassPermissive = v })
if v, ok := policy.GetInt(mdm.KeyWireguardPort); ok {
// REG_DWORD is 32-bit; the valid UDP port range is 1-65535. Values
// outside that range are rejected (not clamped) and the previous value
// is kept, so the engine never binds to an unusable port if the admin
// pushes garbage.
if v >= 1 && v <= 65535 {
config.WgPort = int(v)
logApplied(mdm.KeyWireguardPort, v)
} else {
log.Warnf("MDM wireguard port %d out of range [1,65535]; keeping previous value", v)
}
}
}
// parseURL parses and validates the URL for the named service. The URL
// must use the http or https scheme; if no port is present, ":443" is
// appended for https or ":80" for http. The serviceName parameter is
// used to contextualise error messages. On success returns the parsed
// *url.URL; on failure returns a non-nil error.
func parseURL(serviceName, serviceURL string) (*url.URL, error) {
parsedMgmtURL, err := url.ParseRequestURI(serviceURL)
if err != nil {


@@ -0,0 +1,152 @@
package profilemanager
import (
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/mdm"
)
// withMDMPolicy temporarily overrides the package-level loadMDMPolicy hook so
// apply() observes the supplied Policy. The original loader is restored at
// test cleanup.
func withMDMPolicy(t *testing.T, policy *mdm.Policy) {
t.Helper()
prev := loadMDMPolicy
loadMDMPolicy = func() *mdm.Policy { return policy }
t.Cleanup(func() { loadMDMPolicy = prev })
}
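
The withMDMPolicy helper above relies on a package-level function variable that tests can swap out and later restore. A minimal sketch of the same pattern outside a test follows; `loadSource`, `overrideLoader` and `describe` are hypothetical names chosen for illustration, and the returned restore function plays the role that t.Cleanup plays in the test helper.

```go
package main

import "fmt"

// loadSource is a hypothetical package-level hook. Production code calls
// it through the variable, so tests can substitute a fake.
var loadSource = func() string { return "production" }

// overrideLoader swaps the hook and returns a function that puts the
// previous loader back.
func overrideLoader(fake func() string) (restore func()) {
	prev := loadSource
	loadSource = fake
	return func() { loadSource = prev }
}

// describe stands in for any code path that reads through the hook.
func describe() string {
	return "source=" + loadSource()
}

func main() {
	restore := overrideLoader(func() string { return "fake" })
	fmt.Println(describe())
	restore()
	fmt.Println(describe())
}
```

Because the hook is shared package state, tests that override it should not run in parallel with each other unless access is synchronised.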
func TestApply_MDMEmpty_NoEnforcement(t *testing.T) {
withMDMPolicy(t, mdm.NewPolicy(nil))
cfg, err := UpdateOrCreateConfig(ConfigInput{
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
})
require.NoError(t, err)
require.NotNil(t, cfg)
assert.True(t, cfg.Policy().IsEmpty(), "no MDM source ⇒ empty Policy")
assert.False(t, cfg.Policy().HasKey(mdm.KeyManagementURL))
assert.Empty(t, cfg.Policy().ManagedKeys())
// Default management URL still resolves.
assert.Equal(t, DefaultManagementURL, cfg.ManagementURL.String())
}
func TestApply_MDMOnly_OverridesDefaults(t *testing.T) {
const mdmURL = "https://corp.mdm.example.com:443"
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
mdm.KeyManagementURL: mdmURL,
mdm.KeyDisableClientRoutes: true,
mdm.KeyBlockInbound: true,
}))
cfg, err := UpdateOrCreateConfig(ConfigInput{
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
})
require.NoError(t, err)
require.NotNil(t, cfg)
assert.Equal(t, mdmURL, cfg.ManagementURL.String())
assert.True(t, cfg.DisableClientRoutes)
assert.True(t, cfg.BlockInbound)
assert.True(t, cfg.Policy().HasKey(mdm.KeyManagementURL))
assert.True(t, cfg.Policy().HasKey(mdm.KeyDisableClientRoutes))
assert.True(t, cfg.Policy().HasKey(mdm.KeyBlockInbound))
assert.False(t, cfg.Policy().HasKey(mdm.KeyAllowServerSSH))
}
func TestApply_MDMBeatsCLIInput(t *testing.T) {
const mdmURL = "https://mdm.example.com:443"
const cliURL = "https://cli.example.com:443"
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
mdm.KeyManagementURL: mdmURL,
}))
cfg, err := UpdateOrCreateConfig(ConfigInput{
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
ManagementURL: cliURL,
})
require.NoError(t, err)
require.NotNil(t, cfg)
// MDM wins over CLI-supplied management URL.
assert.Equal(t, mdmURL, cfg.ManagementURL.String())
assert.True(t, cfg.Policy().HasKey(mdm.KeyManagementURL))
}
func TestApply_MDMInvalidURL_KeepsPreviousValue(t *testing.T) {
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
mdm.KeyManagementURL: "not-a-url",
}))
cfg, err := UpdateOrCreateConfig(ConfigInput{
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
})
require.NoError(t, err)
require.NotNil(t, cfg)
// Invalid MDM URL is logged and skipped: default URL stays in place
// to keep the client functional.
assert.Equal(t, DefaultManagementURL, cfg.ManagementURL.String())
// But the key is still considered MDM-managed (admin intent is to
// enforce, daemon rejects user writes to this field — phase-1 scaffolding
// reflects this by keeping Policy.HasKey true even on parse failure).
assert.True(t, cfg.Policy().HasKey(mdm.KeyManagementURL))
}
func TestApply_MDMBoolKeysOverrideOnDiskValue(t *testing.T) {
tmp := filepath.Join(t.TempDir(), "config.json")
// Seed without MDM.
withMDMPolicy(t, mdm.NewPolicy(nil))
_, err := UpdateOrCreateConfig(ConfigInput{
ConfigPath: tmp,
DisableClientRoutes: boolPtr(false),
RosenpassEnabled: boolPtr(false),
})
require.NoError(t, err)
// Now enable MDM enforcement for these keys.
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
mdm.KeyDisableClientRoutes: true,
mdm.KeyRosenpassEnabled: true,
}))
cfg, err := UpdateOrCreateConfig(ConfigInput{ConfigPath: tmp})
require.NoError(t, err)
require.NotNil(t, cfg)
assert.True(t, cfg.DisableClientRoutes, "MDM override should flip on-disk false to true")
assert.True(t, cfg.RosenpassEnabled)
assert.True(t, cfg.Policy().HasKey(mdm.KeyDisableClientRoutes))
assert.True(t, cfg.Policy().HasKey(mdm.KeyRosenpassEnabled))
}
func TestApply_MDMPreSharedKeyRedactionSentinelRejected(t *testing.T) {
const maskSentinel = "**********"
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
mdm.KeyPreSharedKey: maskSentinel,
}))
cfg, err := UpdateOrCreateConfig(ConfigInput{
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
})
require.NoError(t, err)
require.NotNil(t, cfg)
// Mask sentinel must not be persisted as the actual PSK.
assert.NotEqual(t, maskSentinel, cfg.PreSharedKey)
// Key still marked managed so user writes are still rejected.
assert.True(t, cfg.Policy().HasKey(mdm.KeyPreSharedKey))
}
func boolPtr(b bool) *bool { return &b }
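
The tests above exercise the rule that MDM is the last override layer and that an invalid MDM value leaves the previously resolved value in place. The sketch below illustrates that ordering with a simplified model; the `settings` and `layer` types, the `resolve` function and the example URLs and default port are assumptions made for illustration and do not appear in the diff.

```go
package main

import "fmt"

// settings is a hypothetical, trimmed-down stand-in for the resolved config.
type settings struct {
	ManagementURL string
	WgPort        int
}

// layer carries optional overrides; a nil field means the layer is silent
// on that setting.
type layer struct {
	ManagementURL *string
	WgPort        *int64
}

// resolve applies layers in order, so later layers win. Ports outside
// 1-65535 are skipped and the previously resolved value is kept.
func resolve(defaults settings, layers ...layer) settings {
	out := defaults
	for _, l := range layers {
		if l.ManagementURL != nil {
			out.ManagementURL = *l.ManagementURL
		}
		if l.WgPort != nil && *l.WgPort >= 1 && *l.WgPort <= 65535 {
			out.WgPort = int(*l.WgPort)
		}
	}
	return out
}

func strPtr(s string) *string { return &s }
func int64Ptr(n int64) *int64 { return &n }

func main() {
	defaults := settings{ManagementURL: "https://default.example.com:443", WgPort: 51820}
	cli := layer{ManagementURL: strPtr("https://cli.example.com:443")}
	mdm := layer{ManagementURL: strPtr("https://mdm.example.com:443"), WgPort: int64Ptr(70000)}

	// MDM is passed last, so its URL wins; its port is out of range and ignored.
	fmt.Printf("%+v\n", resolve(defaults, cli, mdm))
}
```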


@@ -74,14 +74,6 @@ func New(filePath string) *Manager {
}
}
// FilePath returns the path of the underlying state file.
func (m *Manager) FilePath() string {
if m == nil {
return ""
}
return m.filePath
}
// Start starts the state manager periodic save routine
func (m *Manager) Start() {
if m == nil {


@@ -0,0 +1,50 @@
//go:build windows || darwin
package mdm
import "strings"
// allKeys is the set of recognised MDM keys. Unknown keys in a managed
// configuration are ignored but logged. Lives in this build-tagged file
// (windows || darwin) because only desktop loaders need the
// canonicalisation table that consumes it; including it unconditionally
// would trigger the `unused` golangci-lint check on platforms that
// don't import canonical_loaders.go.
var allKeys = []string{
KeyManagementURL,
KeyDisableUpdateSettings,
KeyDisableProfiles,
KeyDisableNetworks,
KeyDisableClientRoutes,
KeyDisableServerRoutes,
KeyBlockInbound,
KeyDisableMetricsCollection,
KeyAllowServerSSH,
KeyDisableAutoConnect,
KeyPreSharedKey,
KeyRosenpassEnabled,
KeyRosenpassPermissive,
KeyWireguardPort,
KeySplitTunnelMode,
KeySplitTunnelApps,
}
// canonicalKey maps the lowercase form of a managed-config value name to
// its canonical mdm.Key* form. Admins commonly write PascalCase value
// names in ADMX / Group Policy ("ManagementURL"); the iOS/AppConfig and
// macOS plist conventions are camelCase ("managementURL"); both must
// resolve to the same Policy lookup.
//
// Lives in a desktop-loader-only file (build tag `windows || darwin`)
// because no other build path consumes it. Linux / FreeBSD / mobile
// builds don't ship a platform loader that reads arbitrary-case key
// names, so they don't need the canonicalisation table — and including
// the var unconditionally would trigger the `unused` golangci-lint
// check on those platforms.
var canonicalKey = func() map[string]string {
m := make(map[string]string, len(allKeys))
for _, k := range allKeys {
m[strings.ToLower(k)] = k
}
return m
}()
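
The canonicalisation table above can be exercised in isolation. This is a small sketch under stated assumptions: `knownKeys` is a shortened, hypothetical stand-in for allKeys, and `canonicalize` is an illustrative helper that is not part of the diff.

```go
package main

import (
	"fmt"
	"strings"
)

// knownKeys is a hypothetical, shortened stand-in for the allKeys list.
var knownKeys = []string{"managementURL", "blockInbound", "wireguardPort"}

// canonicalByLower maps the lowercase spelling of each key to its
// canonical form, built once at package initialisation.
var canonicalByLower = func() map[string]string {
	m := make(map[string]string, len(knownKeys))
	for _, k := range knownKeys {
		m[strings.ToLower(k)] = k
	}
	return m
}()

// canonicalize returns the canonical key for a name written in any casing
// and reports whether the name is recognised.
func canonicalize(name string) (string, bool) {
	k, ok := canonicalByLower[strings.ToLower(name)]
	return k, ok
}

func main() {
	for _, name := range []string{"ManagementURL", "managementurl", "BLOCKINBOUND", "unknownKey"} {
		k, ok := canonicalize(name)
		fmt.Printf("%q -> %q (known=%v)\n", name, k, ok)
	}
}
```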

client/mdm/policy.go

@@ -0,0 +1,247 @@
// Package mdm reads MDM-managed configuration from platform-native sources
// (plist on macOS, registry on Windows, UserDefaults on iOS,
// RestrictionsManager on Android). The returned Policy is consumed by
// profilemanager.Config.apply() as the highest-priority override layer.
//
// An empty Policy (no source present, or source present with zero keys)
// means no MDM enforcement is active and the client behaves as if the
// feature did not exist.
package mdm
import (
"sort"
"strconv"
log "github.com/sirupsen/logrus"
)
// Well-known policy keys. Names mirror the corresponding ConfigInput Go field
// names (lowerCamelCase) so the daemon can map a Policy key directly to a
// configuration field.
const (
KeyManagementURL = "managementURL"
KeyDisableUpdateSettings = "disableUpdateSettings"
KeyDisableProfiles = "disableProfiles"
KeyDisableNetworks = "disableNetworks"
KeyDisableClientRoutes = "disableClientRoutes"
KeyDisableServerRoutes = "disableServerRoutes"
KeyBlockInbound = "blockInbound"
KeyDisableMetricsCollection = "disableMetricsCollection"
KeyAllowServerSSH = "allowServerSSH"
KeyDisableAutoConnect = "disableAutoConnect"
KeyPreSharedKey = "preSharedKey"
KeyRosenpassEnabled = "rosenpassEnabled"
KeyRosenpassPermissive = "rosenpassPermissive"
KeyWireguardPort = "wireguardPort"
// Split tunnel is modeled as a single conceptual policy with two
// registry/plist values. KeySplitTunnelMode is the discriminator
// ("allow" or "disallow"); KeySplitTunnelApps is a comma-separated
// list of package names. The values are mutually exclusive by
// construction — only one mode can be set at a time.
KeySplitTunnelMode = "splitTunnelMode"
KeySplitTunnelApps = "splitTunnelApps"
)
// Split-tunnel mode literals (KeySplitTunnelMode values).
const (
SplitTunnelModeAllow = "allow"
SplitTunnelModeDisallow = "disallow"
)
// SecretKeys lists keys whose values must be redacted in logs.
var SecretKeys = map[string]struct{}{
KeyPreSharedKey: {},
}
// boolStringLiterals enumerates the textual boolean encodings the
// platform loaders may produce (Windows REG_SZ "true", iOS / Android
// managed-config booleans-as-strings, etc.). Lookup keeps GetBool flat
// (no nested switch on the string case).
var boolStringLiterals = map[string]bool{
"true": true,
"1": true,
"yes": true,
"false": false,
"0": false,
"no": false,
}
// Policy holds MDM-managed settings read from the platform source. A nil or
// empty Policy means no enforcement is active.
type Policy struct {
values map[string]any
}
// NewPolicy constructs a Policy from a key→value map. Pass nil or an
// empty map to construct an empty (no-enforcement) Policy. The returned
// *Policy is always non-nil.
func NewPolicy(values map[string]any) *Policy {
if values == nil {
values = map[string]any{}
}
return &Policy{values: values}
}
// LoadPolicy reads the platform-native MDM configuration. Returns an
// empty (but non-nil) Policy when no source is present, the source is
// empty, or the platform is unsupported.
//
// Diagnostic logging differentiates the three states:
// - source absent / unsupported platform: trace log only
// - source present, zero keys: info "MDM enrolled (no managed keys)"
// - source present, N keys: info "MDM enrolled with N managed keys: [...]"
func LoadPolicy() *Policy {
values, err := loadPlatformPolicy()
if err != nil {
log.Tracef("MDM policy load: %v", err)
return &Policy{values: map[string]any{}}
}
if values == nil {
return &Policy{values: map[string]any{}}
}
if len(values) == 0 {
log.Info("MDM enrolled (no managed keys)")
} else {
log.Infof("MDM enrolled with %d managed key(s): %v", len(values), sortedKeys(values))
}
return &Policy{values: values}
}
// IsEmpty reports whether the Policy has no managed keys.
func (p *Policy) IsEmpty() bool {
return p == nil || len(p.values) == 0
}
// HasKey reports whether the given key is MDM-managed.
func (p *Policy) HasKey(key string) bool {
if p == nil {
return false
}
_, ok := p.values[key]
return ok
}
// ManagedKeys returns the sorted list of managed key names. Returns an empty
// slice (not nil) on an empty Policy.
func (p *Policy) ManagedKeys() []string {
if p == nil {
return []string{}
}
return sortedKeys(p.values)
}
// GetString returns the managed value for key as a string, and whether a
// usable value was set. A non-string or empty-string value returns ("", false).
func (p *Policy) GetString(key string) (string, bool) {
if p == nil {
return "", false
}
v, ok := p.values[key]
if !ok {
return "", false
}
s, ok := v.(string)
if !ok || s == "" {
return "", false
}
return s, true
}
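// GetBool returns the managed value for key coerced to bool, and whether the
// key was set to a recognised boolean. Accepts native bool, int / int64
// (non-zero is true) and the string literals listed in boolStringLiterals
// ("true"/"false", "1"/"0", "yes"/"no"; matching is case-sensitive).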
func (p *Policy) GetBool(key string) (bool, bool) {
if p == nil {
return false, false
}
v, ok := p.values[key]
if !ok {
return false, false
}
switch t := v.(type) {
case bool:
return t, true
case string:
b, known := boolStringLiterals[t]
return b, known
case int:
return t != 0, true
case int64:
return t != 0, true
}
return false, false
}
// GetInt returns the managed value for key as int64, and whether the key
// was set. Accepts native int / int64 (as produced by the Windows registry
// loader for REG_DWORD/REG_QWORD), int32, uint64, float64 (truncated toward
// zero) and numeric strings (decimal).
func (p *Policy) GetInt(key string) (int64, bool) {
if p == nil {
return 0, false
}
v, ok := p.values[key]
if !ok {
return 0, false
}
switch t := v.(type) {
case int64:
return t, true
case int:
return int64(t), true
case int32:
return int64(t), true
case uint64:
return int64(t), true
case float64:
return int64(t), true
case string:
if n, err := strconv.ParseInt(t, 10, 64); err == nil {
return n, true
}
}
return 0, false
}
// GetStringSlice returns the managed value for key as []string, and whether
// the key was set. Accepts []string, []any (of strings), and a single string
// (treated as a one-element list).
func (p *Policy) GetStringSlice(key string) ([]string, bool) {
if p == nil {
return nil, false
}
v, ok := p.values[key]
if !ok {
return nil, false
}
switch t := v.(type) {
case []string:
return append([]string(nil), t...), true
case []any:
out := make([]string, 0, len(t))
for _, item := range t {
s, ok := item.(string)
if !ok {
return nil, false
}
out = append(out, s)
}
return out, true
case string:
return []string{t}, true
}
return nil, false
}
// sortedKeys returns the keys of m as a deterministic, lexicographically
// sorted slice. Used internally by Policy.ManagedKeys and LoadPolicy's
// diagnostic log line so callers see a stable key order across runs
// regardless of Go's randomised map iteration.
func sortedKeys(m map[string]any) []string {
out := make([]string, 0, len(m))
for k := range m {
out = append(out, k)
}
sort.Strings(out)
return out
}
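
The coercion rules of the typed accessors above can be tried in isolation. The following is a minimal sketch, not the package's code: `coerceBool` and `boolLiterals` are hypothetical names, and the literal table is an assumption standing in for `boolStringLiterals`, which is not shown in this diff.

```go
package main

import "fmt"

// boolLiterals is a hypothetical stand-in for the package's
// boolStringLiterals table; the real contents are not part of this diff.
var boolLiterals = map[string]bool{
	"true": true, "false": false,
	"1": true, "0": false,
	"yes": true, "no": false,
}

// coerceBool mirrors the shape of Policy.GetBool: the second result reports
// whether v could be interpreted as a boolean at all.
func coerceBool(v any) (value, ok bool) {
	switch t := v.(type) {
	case bool:
		return t, true
	case string:
		b, known := boolLiterals[t]
		return b, known
	case int:
		return t != 0, true
	case int64:
		return t != 0, true
	}
	// Any other type (float64, slices, ...) is reported as "not set".
	return false, false
}

func main() {
	for _, raw := range []any{true, "yes", int64(0), "maybe", 1.0} {
		v, ok := coerceBool(raw)
		fmt.Printf("%#v -> value=%v ok=%v\n", raw, v, ok)
	}
}
```

Returning `(false, false)` for unrecognised input lets a caller fall back to the non-MDM configuration instead of treating a malformed value as an enforced `false`.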


@@ -0,0 +1,90 @@
//go:build darwin && !ios
package mdm
import (
"errors"
"fmt"
"io/fs"
"os"
"strings"
log "github.com/sirupsen/logrus"
"howett.net/plist"
)
// policyPlistPath is the well-known location where macOS writes the
// device-level mandatory MDM payload for NetBird. The path is fixed by
// Apple convention: when an MDM provider (Jamf / Kandji / Mosyle /
// Intune for Mac / Workspace ONE) pushes a Configuration Profile that
// contains a com.apple.ManagedClient.preferences payload targeting the
// bundle id io.netbird.client, the OS materializes the payload here.
//
// Read-only — only the OS (root) is supposed to write this file. The
// loader sanity-checks the file mode and refuses to honour a world-
// writable plist, as a defense against tampered installs.
const policyPlistPath = "/Library/Managed Preferences/io.netbird.client.plist"
// loadPlatformPolicy reads the MDM-managed configuration from the macOS
// managed-preferences plist at policyPlistPath. Returns:
// - (nil, nil) when the plist is absent (device not MDM-enrolled for
// NetBird, or admin has not yet pushed a payload)
// - (map, nil) with N entries when N managed values are present
// (N may be 0 — empty plist still signals enrollment to the caller)
// - (nil, err) on permission / parse / safety errors (including
// refusal to read a world-writable plist)
//
// Top-level plist keys are canonicalised case-insensitively to the
// package's internal mdm.Key* names; unknown keys are logged and
// skipped so a stray entry in the payload does not block startup.
// Native plist value types map naturally onto the Policy accessor
// expectations (GetString / GetBool / GetInt / GetStringSlice).
func loadPlatformPolicy() (map[string]any, error) {
f, err := os.Open(policyPlistPath)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
// Not enrolled for NetBird. Caller treats nil as
// "no MDM source present".
//nolint:nilnil // (nil, nil) is the documented platform-absent sentinel; see LoadPolicy.
return nil, nil
}
return nil, fmt.Errorf("open %s: %w", policyPlistPath, err)
}
defer func() {
if closeErr := f.Close(); closeErr != nil {
log.Warnf("MDM close plist %s: %v", policyPlistPath, closeErr)
}
}()
info, err := f.Stat()
if err != nil {
return nil, fmt.Errorf("stat %s: %w", policyPlistPath, err)
}
// World-writable plist => tampered install. Refuse rather than
// honour potentially attacker-controlled policy values.
if info.Mode().Perm()&0o002 != 0 {
return nil, fmt.Errorf("refusing to read world-writable MDM source %s (mode %o)",
policyPlistPath, info.Mode().Perm())
}
raw := make(map[string]any)
if err := plist.NewDecoder(f).Decode(&raw); err != nil {
return nil, fmt.Errorf("decode plist %s: %w", policyPlistPath, err)
}
out := make(map[string]any, len(raw))
for name, val := range raw {
// macOS / AppConfig conventions both use camelCase for managed
// preferences keys; canonicalize to the mdm.Key* form so a key
// written as "ManagementURL" (PascalCase, rare on macOS but
// possible if the admin reused an ADMX-style name) still
// resolves.
canonical, known := canonicalKey[strings.ToLower(name)]
if !known {
log.Warnf("MDM ignoring unknown plist key %s: %s", policyPlistPath, name)
continue
}
out[canonical] = val
}
return out, nil
}
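
The world-writable guard in the macOS loader above comes down to a single permission-bit test. A minimal sketch, assuming a hypothetical helper name `worldWritable` (the loader itself inlines the check):

```go
package main

import (
	"fmt"
	"io/fs"
)

// worldWritable reports whether the "other" write bit (0o002) is set in the
// permission bits of mode. Hypothetical helper, for illustration only.
func worldWritable(mode fs.FileMode) bool {
	return mode.Perm()&0o002 != 0
}

func main() {
	for _, m := range []fs.FileMode{0o644, 0o600, 0o666, 0o646} {
		fmt.Printf("%04o worldWritable=%v\n", uint32(m.Perm()), worldWritable(m))
	}
}
```

Only the "other" write bit is inspected, so a group-writable file (for example mode 0o664) would still be accepted by a check of this form.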


@@ -0,0 +1,14 @@
//go:build ios || android
package mdm
// loadPlatformPolicy is unused on mobile: the native layer (Swift on iOS,
// Kotlin/Java on Android) reads the OS managed-config store and pushes the
// resulting dictionary in-process via a gomobile entry point that lands in
// Phase 5 / Phase 6. The stub keeps the package compilable for mobile
// builds and returns (nil, nil) — the platform-absent sentinel that
// LoadPolicy in policy.go treats as "no MDM source present".
func loadPlatformPolicy() (map[string]any, error) {
//nolint:nilnil // (nil, nil) is the documented platform-absent sentinel; see LoadPolicy.
return nil, nil
}


@@ -0,0 +1,14 @@
//go:build !windows && !darwin && !ios && !android
package mdm
// loadPlatformPolicy returns no policy on platforms without an MDM channel
// (Linux, FreeBSD). MDM enforcement is off and the client behaves as if
// the feature did not exist. Returns (nil, nil) — the platform-absent
// sentinel the caller (LoadPolicy in policy.go) treats as "no MDM
// source present"; an error here would just translate to the same
// outcome with an extra log line.
func loadPlatformPolicy() (map[string]any, error) {
//nolint:nilnil // (nil, nil) is the documented platform-absent sentinel; see LoadPolicy.
return nil, nil
}

160
client/mdm/policy_test.go Normal file

@@ -0,0 +1,160 @@
package mdm
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPolicy_NilSafe(t *testing.T) {
var p *Policy
assert.True(t, p.IsEmpty())
assert.False(t, p.HasKey(KeyManagementURL))
assert.Empty(t, p.ManagedKeys())
_, ok := p.GetString(KeyManagementURL)
assert.False(t, ok)
_, ok = p.GetBool(KeyDisableProfiles)
assert.False(t, ok)
_, ok = p.GetStringSlice(KeySplitTunnelApps)
assert.False(t, ok)
}
func TestPolicy_Empty(t *testing.T) {
p := NewPolicy(nil)
require.NotNil(t, p)
assert.True(t, p.IsEmpty())
assert.False(t, p.HasKey(KeyManagementURL))
assert.Empty(t, p.ManagedKeys())
}
func TestPolicy_HasKey(t *testing.T) {
p := NewPolicy(map[string]any{
KeyManagementURL: "https://corp.example.com",
KeyDisableProfiles: true,
})
assert.False(t, p.IsEmpty())
assert.True(t, p.HasKey(KeyManagementURL))
assert.True(t, p.HasKey(KeyDisableProfiles))
assert.False(t, p.HasKey(KeyPreSharedKey))
}
func TestPolicy_ManagedKeysSorted(t *testing.T) {
p := NewPolicy(map[string]any{
KeyDisableProfiles: true,
KeyManagementURL: "https://x",
KeyAllowServerSSH: false,
})
got := p.ManagedKeys()
assert.Equal(t, []string{KeyAllowServerSSH, KeyDisableProfiles, KeyManagementURL}, got)
}
func TestPolicy_GetString(t *testing.T) {
p := NewPolicy(map[string]any{
KeyManagementURL: "https://corp.example.com",
KeyDisableProfiles: true, // wrong type for GetString
KeyPreSharedKey: "", // empty rejected
})
v, ok := p.GetString(KeyManagementURL)
assert.True(t, ok)
assert.Equal(t, "https://corp.example.com", v)
_, ok = p.GetString(KeyDisableProfiles)
assert.False(t, ok, "non-string value must not be reported as string")
_, ok = p.GetString(KeyPreSharedKey)
assert.False(t, ok, "empty string treated as unset")
_, ok = p.GetString("nonexistent")
assert.False(t, ok)
}
func TestPolicy_GetBool(t *testing.T) {
cases := []struct {
name string
raw any
want bool
ok bool
}{
{"native true", true, true, true},
{"native false", false, false, true},
{"string true", "true", true, true},
{"string false", "false", false, true},
{"string 1", "1", true, true},
{"string 0", "0", false, true},
{"string yes", "yes", true, true},
{"string no", "no", false, true},
{"int nonzero", 1, true, true},
{"int zero", 0, false, true},
{"int64 nonzero", int64(2), true, true},
{"int64 zero", int64(0), false, true},
{"string garbage", "maybe", false, false},
{"float unsupported", 1.0, false, false},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
p := NewPolicy(map[string]any{KeyDisableProfiles: c.raw})
got, ok := p.GetBool(KeyDisableProfiles)
assert.Equal(t, c.ok, ok)
if c.ok {
assert.Equal(t, c.want, got)
}
})
}
_, ok := NewPolicy(nil).GetBool(KeyDisableProfiles)
assert.False(t, ok)
}
func TestPolicy_GetStringSlice(t *testing.T) {
t.Run("native string slice", func(t *testing.T) {
p := NewPolicy(map[string]any{
KeySplitTunnelApps: []string{"com.a", "com.b"},
})
got, ok := p.GetStringSlice(KeySplitTunnelApps)
assert.True(t, ok)
assert.Equal(t, []string{"com.a", "com.b"}, got)
})
t.Run("any slice of strings", func(t *testing.T) {
p := NewPolicy(map[string]any{
KeySplitTunnelApps: []any{"com.a", "com.b"},
})
got, ok := p.GetStringSlice(KeySplitTunnelApps)
assert.True(t, ok)
assert.Equal(t, []string{"com.a", "com.b"}, got)
})
t.Run("single string lifts to one-element slice", func(t *testing.T) {
p := NewPolicy(map[string]any{
KeySplitTunnelApps: "com.a",
})
got, ok := p.GetStringSlice(KeySplitTunnelApps)
assert.True(t, ok)
assert.Equal(t, []string{"com.a"}, got)
})
t.Run("mixed any slice rejected", func(t *testing.T) {
p := NewPolicy(map[string]any{
KeySplitTunnelApps: []any{"com.a", 1},
})
_, ok := p.GetStringSlice(KeySplitTunnelApps)
assert.False(t, ok)
})
t.Run("missing key", func(t *testing.T) {
p := NewPolicy(nil)
_, ok := p.GetStringSlice(KeySplitTunnelApps)
assert.False(t, ok)
})
}
func TestLoadPolicy_PlatformStubReturnsEmpty(t *testing.T) {
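// loadPlatformPolicy is a stub on Linux/FreeBSD and mobile, but a real
// loader on Windows and macOS, so the emptiness checks below assume the
// test host carries no NetBird MDM policy. LoadPolicy itself must degrade
// gracefully and never return nil.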
p := LoadPolicy()
require.NotNil(t, p)
assert.True(t, p.IsEmpty())
assert.Empty(t, p.ManagedKeys())
}


@@ -0,0 +1,108 @@
//go:build windows
package mdm
import (
"errors"
"fmt"
"strings"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows/registry"
)
// policyRegistryPath is the well-known MDM policy registry key for NetBird.
// Admins push values here through Group Policy, Intune ADMX ingestion, an
// Intune custom Registry CSP profile, or `reg add` during MSI deployment.
// Listed in the project's docs/mdm/netbird.admx schema.
const policyRegistryPath = `Software\Policies\NetBird`
// readRegistryValue reads a single value under policyRegistryPath and,
// on success, stores the type-coerced result in out[canonical]. Type
// coercion mirrors loadPlatformPolicy's documented mapping:
// - REG_SZ / REG_EXPAND_SZ -> string (REG_EXPAND_SZ is expanded by the API)
// - REG_DWORD / REG_QWORD -> int64
// - REG_MULTI_SZ -> []string
//
// Unsupported value types and per-value read failures are logged at
// warn level and skipped — one malformed value must not block the
// surrounding loop. Extracted from loadPlatformPolicy to keep that
// function's cognitive complexity in check.
func readRegistryValue(k registry.Key, name, canonical string, out map[string]any) {
_, valType, err := k.GetValue(name, nil)
if err != nil {
log.Warnf("MDM stat %s\\%s: %v", policyRegistryPath, name, err)
return
}
switch valType {
case registry.SZ, registry.EXPAND_SZ:
if v, _, err := k.GetStringValue(name); err == nil {
out[canonical] = v
} else {
log.Warnf("MDM read string %s\\%s: %v", policyRegistryPath, name, err)
}
case registry.DWORD, registry.QWORD:
if v, _, err := k.GetIntegerValue(name); err == nil {
// uint64 from the registry API; the Policy.GetBool / GetInt
// helpers consume int64. A QWORD above math.MaxInt64 would wrap
// negative in this conversion.
out[canonical] = int64(v)
} else {
log.Warnf("MDM read int %s\\%s: %v", policyRegistryPath, name, err)
}
case registry.MULTI_SZ:
if v, _, err := k.GetStringsValue(name); err == nil {
out[canonical] = v
} else {
log.Warnf("MDM read multi-string %s\\%s: %v", policyRegistryPath, name, err)
}
default:
log.Warnf("MDM ignoring unsupported registry value type %d at %s\\%s",
valType, policyRegistryPath, name)
}
}
// loadPlatformPolicy reads the MDM-managed configuration from the
// Windows registry under HKLM\Software\Policies\NetBird. Returns:
// - (nil, nil) when the key is absent (device not MDM-enrolled for NetBird)
// - (map, nil) with N entries when N managed values are set (N may be 0)
// - (nil, err) on open / enumerate registry errors
//
// Per-value type coercion + skip-on-error is delegated to
// readRegistryValue. Unknown value names are logged and skipped so a
// malformed deployment does not block startup.
func loadPlatformPolicy() (map[string]any, error) {
k, err := registry.OpenKey(registry.LOCAL_MACHINE, policyRegistryPath, registry.QUERY_VALUE)
if err != nil {
if errors.Is(err, registry.ErrNotExist) {
// Not enrolled. Caller treats nil as "no MDM source present".
//nolint:nilnil // (nil, nil) is the documented platform-absent sentinel; see LoadPolicy.
return nil, nil
}
return nil, fmt.Errorf("open %s: %w", policyRegistryPath, err)
}
defer func() {
if closeErr := k.Close(); closeErr != nil {
log.Warnf("MDM close registry key %s: %v", policyRegistryPath, closeErr)
}
}()
names, err := k.ReadValueNames(-1)
if err != nil {
return nil, fmt.Errorf("enumerate values of %s: %w", policyRegistryPath, err)
}
out := make(map[string]any, len(names))
for _, name := range names {
// Canonicalize the registry value name against the known MDM key
// set so Policy.HasKey lookups (which use the canonical names)
// succeed regardless of the casing used by the admin's ADMX or
// `reg add` command.
canonical, known := canonicalKey[strings.ToLower(name)]
if !known {
log.Warnf("MDM ignoring unknown registry value %s\\%s", policyRegistryPath, name)
continue
}
readRegistryValue(k, name, canonical, out)
}
return out, nil
}
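
Both desktop loaders resolve admin-supplied names through a lower-cased lookup table. The sketch below illustrates that step on a plain map. `knownKeys` is a hypothetical stand-in for the package's `canonicalKey` table (not shown in this diff); its two entries are assumptions based on the key names quoted in the proto comments.

```go
package main

import (
	"fmt"
	"sort"
	"strings"
)

// knownKeys maps lower-cased names to canonical key names. Hypothetical
// stand-in for canonicalKey; the real table is not part of this diff.
var knownKeys = map[string]string{
	"managementurl":       "managementURL",
	"disableclientroutes": "disableClientRoutes",
}

// canonicalize keeps recognised entries under their canonical name and
// returns the unrecognised names (sorted) so the caller can log them.
func canonicalize(raw map[string]any) (out map[string]any, unknown []string) {
	out = make(map[string]any, len(raw))
	for name, val := range raw {
		canonical, ok := knownKeys[strings.ToLower(name)]
		if !ok {
			unknown = append(unknown, name)
			continue
		}
		out[canonical] = val
	}
	sort.Strings(unknown)
	return out, unknown
}

func main() {
	out, unknown := canonicalize(map[string]any{
		"ManagementURL":       "https://corp.example.com",
		"DISABLECLIENTROUTES": true,
		"Bogus":               1,
	})
	fmt.Println(out, unknown)
}
```

One property worth noting about this approach: if a payload contains two names that differ only by case, whichever is iterated last wins, and Go's map iteration order is not stable.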

129
client/mdm/ticker.go Normal file

@@ -0,0 +1,129 @@
package mdm
import (
"context"
"reflect"
"sort"
"time"
log "github.com/sirupsen/logrus"
)
// DefaultReloadInterval is the production cadence at which the desktop daemon
// re-reads the OS-native MDM policy. Picked to balance responsiveness against
// registry/plist I/O overhead. Mobile builds rely on OS-side notifications
// instead and bypass the ticker mechanism entirely.
const DefaultReloadInterval = 1 * time.Minute
// policyLoader is the indirection through which the ticker reads the
// OS-native policy, both for the initial observation and on every tick.
// Production points it at LoadPolicy; tests in this package override it to
// feed a scripted sequence of policies without touching the real OS store.
var policyLoader = LoadPolicy
// Ticker periodically re-reads the OS-native MDM policy via LoadPolicy and
// invokes the onChange callback (supplied to Run) whenever the observed
// Policy diverges from the last observation (added / removed / changed
// keys). Launch with Run from a goroutine; cancel the supplied context
// to stop.
type Ticker struct {
interval time.Duration
prev *Policy
}
// NewTicker constructs a Ticker that will re-read the OS-native policy
// every reloadInterval once Run is called.
// The initial snapshot is populated by calling policyLoader at
// construction time so the first tick only fires
// onChange when the policy actually changed since boot — without
// this baseline the first tick would report every currently-managed
// key as "added" and trigger a spurious engine restart.
func NewTicker(reloadInterval time.Duration) *Ticker {
return &Ticker{
interval: reloadInterval,
prev: policyLoader(),
}
}
// Run blocks until ctx is cancelled, polling the OS-native policy store at
// the configured cadence and emitting log lines + onChange callback on
// every observed diff. onChange must be non-nil.
func (t *Ticker) Run(ctx context.Context, onChange func(prev, curr *Policy) error) {
tk := time.NewTicker(t.interval)
defer tk.Stop()
log.Infof("MDM policy reload ticker started (interval=%s)", t.interval)
for {
select {
case <-ctx.Done():
log.Info("MDM policy reload ticker stopped")
return
case <-tk.C:
curr := policyLoader()
if policiesEqual(t.prev, curr) {
continue
}
added, removed, changed := diffPolicies(t.prev, curr)
log.Infof("MDM policy changed: added=%v removed=%v changed=%v",
added, removed, changed)
prev := t.prev
if err := onChange(prev, curr); err != nil {
log.Errorf("MDM policy change handler failed (retrying in %s): %v", t.interval, err)
continue
}
t.prev = curr
}
}
}
// policiesEqual reports whether two Policy instances carry the same
// managed key set with identical values. Nil and empty policies
// compare equal; one-nil/one-non-empty compare not equal; otherwise
// the underlying values maps are compared with reflect.DeepEqual.
func policiesEqual(a, b *Policy) bool {
if a.IsEmpty() && b.IsEmpty() {
return true
}
if a == nil || b == nil {
return false
}
return reflect.DeepEqual(a.values, b.values)
}
// diffPolicies returns the keys added in curr, removed from prev, and
// whose values changed between prev and curr. Each slice is sorted
// lexicographically for stable log output; value differences are
// determined with reflect.DeepEqual.
func diffPolicies(prev, curr *Policy) (added, removed, changed []string) {
prevKVs := mapOf(prev)
currKVs := mapOf(curr)
for k := range currKVs {
if _, ok := prevKVs[k]; !ok {
added = append(added, k)
} else if !reflect.DeepEqual(prevKVs[k], currKVs[k]) {
changed = append(changed, k)
}
}
for k := range prevKVs {
if _, ok := currKVs[k]; !ok {
removed = append(removed, k)
}
}
sort.Strings(added)
sort.Strings(removed)
sort.Strings(changed)
return added, removed, changed
}
// mapOf returns a (possibly empty, never nil) copy of the underlying
// values map of a Policy so diffPolicies can range over prev and curr
// uniformly. Returns an empty map on nil p.
func mapOf(p *Policy) map[string]any {
if p == nil {
return map[string]any{}
}
out := make(map[string]any, len(p.values))
for k, v := range p.values {
out[k] = v
}
return out
}
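
The added / removed / changed classification used by the ticker can be tried on bare maps. This is a simplified sketch rather than the package's code: `diffMaps` is a hypothetical name, and the key strings in `main` are illustrative values, not confirmed `mdm.Key*` constants.

```go
package main

import (
	"fmt"
	"reflect"
	"sort"
)

// diffMaps classifies keys as added, removed or changed between two
// snapshots. Simplified stand-in for diffPolicies, operating on plain maps
// instead of *Policy.
func diffMaps(prev, curr map[string]any) (added, removed, changed []string) {
	for k, cv := range curr {
		pv, ok := prev[k]
		switch {
		case !ok:
			added = append(added, k)
		case !reflect.DeepEqual(pv, cv):
			changed = append(changed, k)
		}
	}
	for k := range prev {
		if _, ok := curr[k]; !ok {
			removed = append(removed, k)
		}
	}
	// Sort so log output is stable despite randomised map iteration.
	sort.Strings(added)
	sort.Strings(removed)
	sort.Strings(changed)
	return added, removed, changed
}

func main() {
	prev := map[string]any{"managementURL": "https://a.example.com", "blockInbound": true}
	curr := map[string]any{"managementURL": "https://b.example.com", "disableProfiles": true}
	added, removed, changed := diffMaps(prev, curr)
	fmt.Println("added:", added, "removed:", removed, "changed:", changed)
}
```

`reflect.DeepEqual` is used instead of `==` because policy values may be slices, which are not comparable with `==` when stored in an `any`.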

100
client/mdm/ticker_test.go Normal file

@@ -0,0 +1,100 @@
package mdm
import (
"context"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testReloadInterval shortens the ticker cadence under `go test`.
const testReloadInterval = 1 * time.Second
// withPolicyLoader overrides the package-level policyLoader for the duration
// of the test so the ticker observes a scripted policy instead of the real
// OS-native store. The original loader is restored on cleanup.
func withPolicyLoader(t *testing.T, fn func() *Policy) {
t.Helper()
prev := policyLoader
policyLoader = fn
t.Cleanup(func() { policyLoader = prev })
}
func TestTicker_FiresOnChangeWithDelta(t *testing.T) {
var mu sync.Mutex
current := NewPolicy(nil) // initial observation: empty (no enforcement)
withPolicyLoader(t, func() *Policy {
mu.Lock()
defer mu.Unlock()
return current
})
type change struct{ prev, curr *Policy }
changes := make(chan change, 1)
tk := NewTicker(testReloadInterval)
require.Equal(t, testReloadInterval, tk.interval)
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
tk.Run(ctx, func(prev, curr *Policy) error {
select {
case changes <- change{prev, curr}:
default:
}
return nil
})
close(done)
}()
// Stop Run and wait for it to exit before returning, so the policyLoader
// restore in t.Cleanup can't race the ticker goroutine still reading it.
defer func() { cancel(); <-done }()
// Flip the OS-observed policy from empty to one managed key. The next
// tick must detect the diff and invoke onChange.
mu.Lock()
current = NewPolicy(map[string]any{KeyManagementURL: "https://mdm.example.com:443"})
mu.Unlock()
select {
case c := <-changes:
assert.True(t, c.prev.IsEmpty(), "prev should be the initial empty policy")
assert.True(t, c.curr.HasKey(KeyManagementURL), "curr should carry the newly-pushed managed key")
case <-time.After(5 * time.Second):
t.Fatal("onChange not invoked within 5s; ticker should fire every 1s under test")
}
}
func TestTicker_NoCallbackWhenPolicyUnchanged(t *testing.T) {
withPolicyLoader(t, func() *Policy {
return NewPolicy(map[string]any{KeyBlockInbound: true})
})
fired := make(chan struct{}, 1)
tk := NewTicker(testReloadInterval)
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
tk.Run(ctx, func(_, _ *Policy) error {
select {
case fired <- struct{}{}:
default:
}
return nil
})
close(done)
}()
defer func() { cancel(); <-done }()
// Over ~2 ticks at the 1s test cadence the policy never changes, so the
// diff guard must suppress the callback entirely.
select {
case <-fired:
t.Fatal("onChange fired despite an unchanged policy")
case <-time.After(2500 * time.Millisecond):
}
}
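
The tests above wait on real timers. The compare-and-retry logic inside `Run` can also be reasoned about without a clock. The sketch below is a hypothetical reduction, not the package's code: `poller`, `snapshot`, `tick` and `errHandler` are illustrative names. It shows that the baseline only advances when the handler succeeds, so a failed handler is retried on the next poll.

```go
package main

import (
	"errors"
	"fmt"
	"reflect"
)

// snapshot stands in for *Policy in this sketch.
type snapshot map[string]any

// errHandler simulates a failing change handler (e.g. a failed restart).
var errHandler = errors.New("change handler failed")

// poller is a timer-free reduction of Ticker: one tick call per poll.
type poller struct {
	prev snapshot
	load func() snapshot
}

// tick performs one poll and reports whether onChange was invoked.
func (p *poller) tick(onChange func(prev, curr snapshot) error) bool {
	curr := p.load()
	// Nil and empty snapshots compare equal, as in policiesEqual.
	if len(p.prev) == 0 && len(curr) == 0 {
		return false
	}
	if reflect.DeepEqual(p.prev, curr) {
		return false
	}
	if err := onChange(p.prev, curr); err != nil {
		return true // baseline kept, so the next tick sees the same diff
	}
	p.prev = curr
	return true
}

func main() {
	current := snapshot{}
	p := &poller{prev: snapshot{}, load: func() snapshot { return current }}
	failOnce := true
	handler := func(prev, curr snapshot) error {
		if failOnce {
			failOnce = false
			return errHandler
		}
		return nil
	}
	fmt.Println("unchanged policy, fired:", p.tick(handler))
	current = snapshot{"managementURL": "https://mdm.example.com:443"}
	fmt.Println("handler fails, fired:", p.tick(handler))
	fmt.Println("retry succeeds, fired:", p.tick(handler))
	fmt.Println("baseline advanced, fired:", p.tick(handler))
}
```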

File diff suppressed because it is too large


@@ -119,14 +119,6 @@ service DaemonService {
// ExposeService exposes a local port via the NetBird reverse proxy
rpc ExposeService(ExposeServiceRequest) returns (stream ExposeServiceEvent) {}
// RespondApproval delivers the user's accept/deny decision for a
// pending user-approval prompt. The daemon pushes the prompt as a
// SystemEvent with category APPROVAL and metadata key "request_id";
// the UI calls this RPC with the same request_id to unblock whichever
// subsystem (VNC, SSH, ...) is waiting. The "kind" metadata key tells
// the UI which subsystem the prompt belongs to.
rpc RespondApproval(RespondApprovalRequest) returns (RespondApprovalResponse) {}
}
@@ -213,10 +205,6 @@ message LoginRequest {
optional bool disableSSHAuth = 38;
optional int32 sshJWTCacheTTL = 39;
optional bool disable_ipv6 = 40;
optional bool serverVNCAllowed = 41;
optional bool disableVNCApproval = 42;
}
message LoginResponse {
@@ -327,9 +315,12 @@ message GetConfigResponse {
bool disable_ipv6 = 27;
bool serverVNCAllowed = 28;
bool disableVNCApproval = 29;
// mDMManagedFields lists the names of configuration keys whose value is
// currently enforced by an MDM policy. Names match mdm.Key* constants
// (e.g. "managementURL", "disableClientRoutes"). UI/CLI clients should
// render the corresponding inputs as read-only and display a "managed
// by MDM" indicator.
repeated string mDMManagedFields = 28;
}
// PeerState contains the latest state of a peer
@@ -411,25 +402,6 @@ message SSHServerState {
repeated SSHSessionInfo sessions = 2;
}
// VNCSessionInfo contains information about an active VNC session
message VNCSessionInfo {
string remoteAddress = 1;
string mode = 2;
string username = 3;
// userID is the Noise-verified session identity (hashed user ID from
// the ACL session-key entry), empty when auth is disabled.
string userID = 4;
// initiator is the human-readable display name of the dashboard user
// who minted the SessionPubKey, when known.
string initiator = 5;
}
// VNCServerState contains the latest state of the VNC server
message VNCServerState {
bool enabled = 1;
repeated VNCSessionInfo sessions = 2;
}
// FullStatus contains the full state held by the Status instance
message FullStatus {
ManagementState managementState = 1;
@@ -444,7 +416,6 @@ message FullStatus {
bool lazyConnectionEnabled = 9;
SSHServerState sshServerState = 10;
VNCServerState vncServerState = 11;
}
// Networks
@@ -633,7 +604,6 @@ message SystemEvent {
AUTHENTICATION = 2;
CONNECTIVITY = 3;
SYSTEM = 4;
APPROVAL = 5;
}
string id = 1;
@@ -717,10 +687,6 @@ message SetConfigRequest {
optional bool disableSSHAuth = 33;
optional int32 sshJWTCacheTTL = 34;
optional bool disable_ipv6 = 35;
optional bool serverVNCAllowed = 36;
optional bool disableVNCApproval = 37;
}
message SetConfigResponse{}
@@ -774,6 +740,15 @@ message GetFeaturesResponse{
bool disable_networks = 3;
}
// MDMManagedFieldsViolation is attached as a gRPC error detail on a
// FailedPrecondition status returned from SetConfig (and similar mutating
// RPCs) when the caller tries to modify one or more MDM-enforced fields.
// The fields list contains the offending key names; the entire request is
// rejected (no partial apply).
message MDMManagedFieldsViolation {
repeated string fields = 1;
}
message TriggerUpdateRequest {}
message TriggerUpdateResponse {
@@ -915,18 +890,3 @@ message StartBundleCaptureRequest {
message StartBundleCaptureResponse {}
message StopBundleCaptureRequest {}
message StopBundleCaptureResponse {}
message RespondApprovalRequest {
// request_id matches the SystemEvent metadata key emitted by the daemon
// when a subsystem awaits user approval for an inbound connection.
string request_id = 1;
// accept is true if the user approved the request, false if they
// denied it. A missing or unknown request_id is treated as a no-op.
bool accept = 2;
// view_only signals that the user granted the connection but withheld
// input control. Only meaningful when accept is true; ignored when
// accept is false.
bool view_only = 3;
}
message RespondApprovalResponse {}
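
The `MDMManagedFieldsViolation` comment in the proto changes above describes an all-or-nothing rejection, and the commit notes say that values matching the policy are let through. A minimal sketch of that rule on plain maps follows. `applyConfig` is a hypothetical helper, not the daemon's implementation, and `"wgPort"` is an illustrative key name.

```go
package main

import (
	"fmt"
	"reflect"
	"sort"
)

// applyConfig copies requested into cfg unless the request tries to set a
// managed key to a value other than the enforced one. On conflict it returns
// the sorted offending keys and leaves cfg untouched (no partial apply).
func applyConfig(cfg, managed, requested map[string]any) (offending []string) {
	for k, want := range requested {
		if enforced, ok := managed[k]; ok && !reflect.DeepEqual(enforced, want) {
			offending = append(offending, k)
		}
	}
	if len(offending) > 0 {
		sort.Strings(offending)
		return offending
	}
	for k, v := range requested {
		cfg[k] = v
	}
	return nil
}

func main() {
	managed := map[string]any{"managementURL": "https://corp.example.com"}
	cfg := map[string]any{"managementURL": "https://corp.example.com"}

	// Conflicting value for a managed key: the whole request is rejected.
	fmt.Println(applyConfig(cfg, managed, map[string]any{
		"managementURL": "https://other.example.com",
		"wgPort":        51820,
	}), cfg)

	// A full snapshot that echoes the enforced value is accepted.
	fmt.Println(applyConfig(cfg, managed, map[string]any{
		"managementURL": "https://corp.example.com",
		"wgPort":        51820,
	}), cfg)
}
```

In the daemon the offending names would presumably travel in the `fields` list of the violation message; this sketch only returns them.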


@@ -58,7 +58,6 @@ const (
DaemonService_StopCPUProfile_FullMethodName = "/daemon.DaemonService/StopCPUProfile"
DaemonService_GetInstallerResult_FullMethodName = "/daemon.DaemonService/GetInstallerResult"
DaemonService_ExposeService_FullMethodName = "/daemon.DaemonService/ExposeService"
DaemonService_RespondApproval_FullMethodName = "/daemon.DaemonService/RespondApproval"
)
// DaemonServiceClient is the client API for DaemonService service.
@@ -135,13 +134,6 @@ type DaemonServiceClient interface {
GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error)
// ExposeService exposes a local port via the NetBird reverse proxy
ExposeService(ctx context.Context, in *ExposeServiceRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[ExposeServiceEvent], error)
// RespondApproval delivers the user's accept/deny decision for a
// pending user-approval prompt. The daemon pushes the prompt as a
// SystemEvent with category APPROVAL and metadata key "request_id";
// the UI calls this RPC with the same request_id to unblock whichever
// subsystem (VNC, SSH, ...) is waiting. The "kind" metadata key tells
// the UI which subsystem the prompt belongs to.
RespondApproval(ctx context.Context, in *RespondApprovalRequest, opts ...grpc.CallOption) (*RespondApprovalResponse, error)
}
type daemonServiceClient struct {
@@ -569,16 +561,6 @@ func (c *daemonServiceClient) ExposeService(ctx context.Context, in *ExposeServi
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type DaemonService_ExposeServiceClient = grpc.ServerStreamingClient[ExposeServiceEvent]
func (c *daemonServiceClient) RespondApproval(ctx context.Context, in *RespondApprovalRequest, opts ...grpc.CallOption) (*RespondApprovalResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(RespondApprovalResponse)
err := c.cc.Invoke(ctx, DaemonService_RespondApproval_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
// DaemonServiceServer is the server API for DaemonService service.
// All implementations must embed UnimplementedDaemonServiceServer
// for forward compatibility.
@@ -653,13 +635,6 @@ type DaemonServiceServer interface {
GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error)
// ExposeService exposes a local port via the NetBird reverse proxy
ExposeService(*ExposeServiceRequest, grpc.ServerStreamingServer[ExposeServiceEvent]) error
// RespondApproval delivers the user's accept/deny decision for a
// pending user-approval prompt. The daemon pushes the prompt as a
// SystemEvent with category APPROVAL and metadata key "request_id";
// the UI calls this RPC with the same request_id to unblock whichever
// subsystem (VNC, SSH, ...) is waiting. The "kind" metadata key tells
// the UI which subsystem the prompt belongs to.
RespondApproval(context.Context, *RespondApprovalRequest) (*RespondApprovalResponse, error)
mustEmbedUnimplementedDaemonServiceServer()
}
@@ -787,9 +762,6 @@ func (UnimplementedDaemonServiceServer) GetInstallerResult(context.Context, *Ins
func (UnimplementedDaemonServiceServer) ExposeService(*ExposeServiceRequest, grpc.ServerStreamingServer[ExposeServiceEvent]) error {
return status.Error(codes.Unimplemented, "method ExposeService not implemented")
}
func (UnimplementedDaemonServiceServer) RespondApproval(context.Context, *RespondApprovalRequest) (*RespondApprovalResponse, error) {
return nil, status.Error(codes.Unimplemented, "method RespondApproval not implemented")
}
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
func (UnimplementedDaemonServiceServer) testEmbeddedByValue() {}
@@ -1492,24 +1464,6 @@ func _DaemonService_ExposeService_Handler(srv interface{}, stream grpc.ServerStr
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type DaemonService_ExposeServiceServer = grpc.ServerStreamingServer[ExposeServiceEvent]
func _DaemonService_RespondApproval_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RespondApprovalRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).RespondApproval(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: DaemonService_RespondApproval_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).RespondApproval(ctx, req.(*RespondApprovalRequest))
}
return interceptor(ctx, in, info, handler)
}
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
@@ -1661,10 +1615,6 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
MethodName: "GetInstallerResult",
Handler: _DaemonService_GetInstallerResult_Handler,
},
{
MethodName: "RespondApproval",
Handler: _DaemonService_RespondApproval_Handler,
},
},
Streams: []grpc.StreamDesc{
{


@@ -111,7 +111,7 @@ func (s *Server) StartCapture(req *proto.StartCaptureRequest, stream proto.Daemo
return status.Errorf(codes.Internal, "create capture session: %v", err)
}
engine, err := s.claimCapture(sess, func() { pw.Close() })
engine, err := s.claimCapture(sess)
if err != nil {
sess.Stop()
pw.Close()
@@ -190,7 +190,10 @@ func (s *Server) StartBundleCapture(_ context.Context, req *proto.StartBundleCap
s.stopBundleCaptureLocked()
s.cleanupBundleCapture()
s.evictActiveCaptureLocked()
if s.activeCapture != nil {
return nil, status.Error(codes.FailedPrecondition, "another capture is already running")
}
engine, err := s.getCaptureEngineLocked()
if err != nil {
@@ -301,58 +304,29 @@ func (s *Server) cleanupBundleCapture() {
s.bundleCapture = nil
}
// claimCapture reserves the engine's capture slot for sess. If another
// capture is already running it is evicted: a previous streaming session
// whose gRPC client died and never freed the slot stays stuck otherwise,
// and a bundle capture is just informational state.
func (s *Server) claimCapture(sess *capture.Session, cancel func()) (*internal.Engine, error) {
// claimCapture reserves the engine's capture slot for sess. Returns
// FailedPrecondition if another capture is already active.
func (s *Server) claimCapture(sess *capture.Session) (*internal.Engine, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
s.evictActiveCaptureLocked()
if s.activeCapture != nil {
return nil, status.Error(codes.FailedPrecondition, "another capture is already running")
}
engine, err := s.getCaptureEngineLocked()
if err != nil {
return nil, err
}
s.activeCapture = sess
s.activeCaptureCancel = cancel
return engine, nil
}
// evictActiveCaptureLocked tears down whatever capture currently owns
// the engine slot so a fresh claim can succeed. Caller must hold mutex.
func (s *Server) evictActiveCaptureLocked() {
if s.activeCapture == nil {
return
}
if s.bundleCapture != nil && s.bundleCapture.sess == s.activeCapture {
log.Infof("evicting running bundle capture to start a new capture")
s.stopBundleCaptureLocked()
return
}
log.Infof("evicting previous streaming capture to start a new one")
prev := s.activeCapture
cancel := s.activeCaptureCancel
if engine, err := s.getCaptureEngineLocked(); err == nil {
if err := engine.SetCapture(nil); err != nil {
log.Debugf("clear previous capture: %v", err)
}
}
s.activeCapture = nil
s.activeCaptureCancel = nil
prev.Stop()
if cancel != nil {
cancel()
}
}
// releaseCapture clears the active-capture owner if it still matches sess.
func (s *Server) releaseCapture(sess *capture.Session) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.activeCapture == sess {
s.activeCapture = nil
s.activeCaptureCancel = nil
}
}
@@ -367,7 +341,6 @@ func (s *Server) clearCaptureIfOwner(sess *capture.Session, engine *internal.Eng
log.Debugf("clear capture: %v", err)
}
s.activeCapture = nil
s.activeCaptureCancel = nil
}
func (s *Server) getCaptureEngineLocked() (*internal.Engine, error) {

client/server/mdm.go Normal file

@@ -0,0 +1,419 @@
package server
import (
"context"
"fmt"
"time"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/mdm"
"github.com/netbirdio/netbird/client/proto"
)
// preSharedKeyRedactedSentinel is the value GetConfig returns in place
// of an actual PSK, so a UI that round-trips the field back to the
// daemon (via SetConfig / Login) can be distinguished from a deliberate
// override. Any incoming PSK that equals this sentinel is treated as
// a no-op echo, never as a conflict with the policy.
const preSharedKeyRedactedSentinel = "**********"
// loadMDMPolicy is the indirection used by server handlers to read the
// active MDM policy. Tests override this to inject a fake policy.
var loadMDMPolicy = mdm.LoadPolicy
// conflictCheck is a value-aware comparison between a single field in
// the incoming request and the corresponding MDM-enforced value. check
// returns true when the request leaves the field unset or sets it to
// the value the policy enforces, and false when the values diverge or
// the policy holds no usable value for the key. resolveConflicts only
// runs check for keys the policy actually contains.
type conflictCheck struct {
key string
check func(*mdm.Policy) (match bool)
}
// onMDMPolicyChange is invoked by the MDM reload ticker every time the
// OS-native managed-config store reports a diff vs the last observation.
//
// Restart sequence:
// 1. Cancel the active engine context (terminates connectWithRetryRuns).
// 2. Wait briefly for that goroutine to exit (giveUpChan is closed on exit).
// 3. Re-resolve Config from disk + MDM policy (Config.apply re-runs
// applyMDMPolicy with the freshly loaded Policy).
// 4. Spawn a fresh connectWithRetryRuns with the new context and config.
// 5. Broadcast a SystemEvent so any GUI / CLI subscriber (SubscribeEvents
// RPC) can refresh its cached config view without polling.
//
// The callback runs in the ticker's own goroutine. Ticker has already
// logged the per-key diff before invoking this hook.
func (s *Server) onMDMPolicyChange(_, _ *mdm.Policy) error {
log.Warn("MDM policy changed; restarting engine to apply new configuration")
// Hold s.mutex for the entire restart sequence (cancel + quiescence
// wait + re-spawn). Any concurrent Up/Down/Status arriving while
// MDM is restarting blocks on the Lock until we are done — they
// then observe the post-restart state coherently. This is safe
// because the connectWithRetryRuns goroutine no longer acquires
// s.mutex in its defer (intent vs. goroutine-alive concerns are
// fully separated; see the connectionGoroutineRunning helper).
s.mutex.Lock()
defer s.mutex.Unlock()
if !s.clientRunning {
// The client is not running, so there's no engine to restart.
return nil
}
if s.actCancel != nil {
s.actCancel()
}
// Wait for previous connectWithRetryRuns to exit so we don't end up
// with two goroutines fighting over the same status recorder + engine.
// The teardown engages a fan-out of engine goroutines (peer workers,
// signal handler, route manager, ...). close(clientGiveUpChan)
// happens in the function-scope defer of connectWithRetryRuns, on
// every exit path (ctx cancel, backoff exhausted, panic) — see the
// defer in server.go.
if s.clientGiveUpChan != nil {
select {
case <-s.clientGiveUpChan:
case <-time.After(10 * time.Second):
return fmt.Errorf("failed to restart the engine due to timeout")
}
}
if err := s.restartEngineForMDMLocked(); err != nil {
log.Errorf("MDM restart failed: %v", err)
return err
}
// publishConfigChangedEvent has already fired inside
// restartEngineForMDMLocked with source="mdm". Emit an MDM-specific
// user-visible toast so the operator knows their IT policy was
// applied (UserMessage != "" triggers the GUI notifier).
s.statusRecorder.PublishEvent(
proto.SystemEvent_INFO,
proto.SystemEvent_SYSTEM,
"MDM policy applied",
"NetBird configuration was updated by your IT policy.",
map[string]string{"source": "mdm", "type": "policy_applied"},
)
return nil
}
// publishConfigChangedEvent broadcasts a SystemEvent informing any active
// SubscribeEvents subscriber (typically the GUI tray) that the daemon's
// effective Config has been replaced and any cached client-side view
// should be refreshed. Callers pass a stable `source` label so the GUI
// can distinguish a startup spawn from a user-triggered Up or an
// MDM-driven restart. Reusing the SYSTEM category keeps the proto enum
// stable; metadata.type="config_changed" routes to the GUI's refresh
// handler. UserMessage is left empty so the system tray does not toast
// for every internal restart; the MDM path emits a separate
// "policy_applied" event (with UserMessage) for that purpose.
func (s *Server) publishConfigChangedEvent(source string) {
if s.statusRecorder == nil {
return
}
s.statusRecorder.PublishEvent(
proto.SystemEvent_INFO,
proto.SystemEvent_SYSTEM,
fmt.Sprintf("daemon config changed (source=%s)", source),
"",
map[string]string{
"source": source,
"type": "config_changed",
},
)
}
// restartEngineForMDMLocked re-resolves the active profile config
// (re-running applyMDMPolicy via Config.apply) and re-spawns
// connectWithRetryRuns. Mirrors the tail of Server.Start so a runtime
// MDM change behaves identically to a fresh boot under the new policy.
//
// MUST be called with s.mutex held — onMDMPolicyChange holds the lock
// for the entire restart sequence (cancel + quiescence wait + re-spawn)
// so concurrent Up/Down/Status RPCs observe a coherent post-restart
// state.
func (s *Server) restartEngineForMDMLocked() error {
activeProf, err := s.profileManager.GetActiveProfileState()
if err != nil {
return fmt.Errorf("get active profile state: %w", err)
}
config, _, err := s.getConfig(activeProf)
if err != nil {
return fmt.Errorf("get active profile config: %w", err)
}
s.config = config
s.statusRecorder.UpdateManagementAddress(config.ManagementURL.String())
s.statusRecorder.UpdateRosenpass(config.RosenpassEnabled, config.RosenpassPermissive)
s.statusRecorder.UpdateLazyConnection(config.LazyConnectionEnabled)
ctx, cancel := context.WithCancel(s.rootCtx)
s.actCancel = cancel
s.clientRunning = true
s.clientRunningChan = make(chan struct{})
s.clientGiveUpChan = make(chan struct{})
log.Info("MDM restart: spawning connectWithRetryRuns with re-resolved config")
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
s.publishConfigChangedEvent("mdm")
return nil
}
// conflictBool builds a conflictCheck for a boolean MDM key. If p is nil
// the field is treated as matching (no override requested); otherwise the
// check returns true only when the policy contains the key and its
// boolean value equals *p.
func conflictBool(key string, p *bool) conflictCheck {
return conflictCheck{
key: key,
check: func(pol *mdm.Policy) bool {
if p == nil {
return true // absent → match by definition
}
want, ok := pol.GetBool(key)
return ok && want == *p
},
}
}
// conflictString builds a conflictCheck for a string MDM key. An empty
// `got` is treated as "field not set" (no override requested); otherwise
// the check returns true only when the policy contains the key and its
// value equals got.
func conflictString(key, got string) conflictCheck {
return conflictCheck{
key: key,
check: func(pol *mdm.Policy) bool {
if got == "" {
return true
}
want, ok := pol.GetString(key)
return ok && want == got
},
}
}
// conflictInt64 builds a conflictCheck for an integer MDM key. If p is
// nil the field is treated as matching; otherwise the check returns
// true only when the policy contains the key and its int value equals *p.
func conflictInt64(key string, p *int64) conflictCheck {
return conflictCheck{
key: key,
check: func(pol *mdm.Policy) bool {
if p == nil {
return true
}
want, ok := pol.GetInt(key)
return ok && want == *p
},
}
}
// resolveConflicts walks the per-field checks against the active MDM
// policy and returns the names of keys whose requested value diverges
// from the policy-enforced value. Keys not present in the policy are
// skipped silently (the gate fires only for keys the admin has
// actually pushed). Returns nil for an empty policy.
func resolveConflicts(policy *mdm.Policy, checks []conflictCheck) []string {
if policy.IsEmpty() {
return nil
}
var conflicts []string
for _, c := range checks {
if !policy.HasKey(c.key) {
continue
}
if !c.check(policy) {
conflicts = append(conflicts, c.key)
}
}
return conflicts
}
// mdmManagedFieldConflicts returns the names of MDM-managed keys whose
// requested value in the SetConfigRequest differs from the MDM-enforced
// value. A field set to the same value the policy already enforces is
// treated as a no-op echo (the GUI tray sends a full Config snapshot on
// every toggle, so most fields in a typical request match the policy
// exactly and must NOT be flagged as conflicts). The redacted PSK
// sentinel ("**********") returned by GetConfig is recognised and
// treated as no-op so the UI can safely round-trip it.
func mdmManagedFieldConflicts(msg *proto.SetConfigRequest, policy *mdm.Policy) []string {
if msg == nil {
return nil
}
// PSK round-trip echo: collapse the sentinel to empty so the
// shared check treats it as "field not set".
pskGot := ""
if msg.OptionalPreSharedKey != nil && *msg.OptionalPreSharedKey != preSharedKeyRedactedSentinel {
pskGot = *msg.OptionalPreSharedKey
}
return resolveConflicts(policy, []conflictCheck{
conflictString(mdm.KeyManagementURL, msg.ManagementUrl),
conflictString(mdm.KeyPreSharedKey, pskGot),
conflictBool(mdm.KeyRosenpassEnabled, msg.RosenpassEnabled),
conflictBool(mdm.KeyRosenpassPermissive, msg.RosenpassPermissive),
conflictBool(mdm.KeyDisableAutoConnect, msg.DisableAutoConnect),
conflictBool(mdm.KeyAllowServerSSH, msg.ServerSSHAllowed),
conflictBool(mdm.KeyDisableClientRoutes, msg.DisableClientRoutes),
conflictBool(mdm.KeyDisableServerRoutes, msg.DisableServerRoutes),
conflictBool(mdm.KeyBlockInbound, msg.BlockInbound),
conflictInt64(mdm.KeyWireguardPort, msg.WireguardPort),
})
}
// setConfigRequestHasConfigOverrides reports whether the SetConfigRequest
// carries ANY field that would actually mutate the persisted config.
// The CLI builds a SetConfigRequest unconditionally on every
// `netbird up` (see setupSetConfigReq in cmd/up.go) — a plain
// `netbird up` produces a request with every field at its zero value;
// the gate must skip such no-op invocations or it would always fire
// even when the user did not pass any --flag. Returns false on a nil
// msg; true when any management/admin URL, PSK, DNS/NAT list+clean
// flag, interface/port/MTU, or any optional bool/duration field is set.
func setConfigRequestHasConfigOverrides(msg *proto.SetConfigRequest) bool {
if msg == nil {
return false
}
return msg.ManagementUrl != "" ||
msg.AdminURL != "" ||
msg.OptionalPreSharedKey != nil ||
len(msg.CustomDNSAddress) > 0 ||
len(msg.NatExternalIPs) > 0 || msg.CleanNATExternalIPs ||
len(msg.ExtraIFaceBlacklist) > 0 ||
len(msg.DnsLabels) > 0 || msg.CleanDNSLabels ||
msg.DnsRouteInterval != nil ||
msg.RosenpassEnabled != nil ||
msg.RosenpassPermissive != nil ||
msg.InterfaceName != nil ||
msg.WireguardPort != nil ||
msg.Mtu != nil ||
msg.DisableAutoConnect != nil ||
msg.ServerSSHAllowed != nil ||
msg.NetworkMonitor != nil ||
msg.DisableClientRoutes != nil ||
msg.DisableServerRoutes != nil ||
msg.DisableDns != nil ||
msg.DisableFirewall != nil ||
msg.BlockLanAccess != nil ||
msg.DisableNotifications != nil ||
msg.LazyConnectionEnabled != nil ||
msg.BlockInbound != nil ||
msg.DisableIpv6 != nil ||
msg.EnableSSHRoot != nil ||
msg.EnableSSHSFTP != nil ||
msg.EnableSSHLocalPortForwarding != nil ||
msg.EnableSSHRemotePortForwarding != nil ||
msg.DisableSSHAuth != nil ||
msg.SshJWTCacheTTL != nil
}
// loginRequestHasConfigOverrides reports whether the LoginRequest
// carries ANY field that would mutate persisted daemon configuration
// (as opposed to pure-auth fields like setupKey, hostname, hint,
// profileName, username). Used by the Login handler to decide whether
// the `--disable-update-settings` / MDM gates must run: a re-auth that
// changes nothing about the configuration is always allowed.
func loginRequestHasConfigOverrides(msg *proto.LoginRequest) bool {
if msg == nil {
return false
}
return msg.ManagementUrl != "" ||
msg.AdminURL != "" ||
msg.PreSharedKey != "" || //nolint:staticcheck // SA1019: legacy proto field still accepted by Login
msg.OptionalPreSharedKey != nil ||
len(msg.CustomDNSAddress) > 0 ||
len(msg.NatExternalIPs) > 0 || msg.CleanNATExternalIPs ||
msg.RosenpassEnabled != nil ||
msg.InterfaceName != nil ||
msg.WireguardPort != nil ||
msg.DisableAutoConnect != nil ||
msg.ServerSSHAllowed != nil ||
msg.RosenpassPermissive != nil ||
len(msg.ExtraIFaceBlacklist) > 0 ||
msg.NetworkMonitor != nil ||
msg.DnsRouteInterval != nil ||
msg.DisableClientRoutes != nil ||
msg.DisableServerRoutes != nil ||
msg.DisableDns != nil ||
msg.DisableFirewall != nil ||
msg.BlockLanAccess != nil ||
msg.DisableNotifications != nil ||
len(msg.DnsLabels) > 0 || msg.CleanDNSLabels ||
msg.LazyConnectionEnabled != nil ||
msg.BlockInbound != nil
}
// loginRequestMDMConflicts mirrors mdmManagedFieldConflicts but for the
// LoginRequest surface. Same value-aware semantics: a field set to the
// MDM-enforced value is a no-op echo, not a conflict; only a divergent
// value is flagged. PSK has two proto fields — PreSharedKey (deprecated)
// and OptionalPreSharedKey (current); either route trips the gate if it
// diverges from the MDM-enforced PSK. OptionalPreSharedKey wins when
// both are set; the redaction sentinel ("**********") is accepted as
// a no-op echo.
func loginRequestMDMConflicts(msg *proto.LoginRequest, policy *mdm.Policy) []string {
if msg == nil {
return nil
}
// Collapse the two PSK fields + the redaction sentinel down to a
// single "got" string the shared check can compare against the
// policy: OptionalPreSharedKey wins if set; PreSharedKey (deprecated)
// is the fallback; sentinel echo is treated as "field not set".
pskGot := ""
if msg.OptionalPreSharedKey != nil {
pskGot = *msg.OptionalPreSharedKey
} else if msg.PreSharedKey != "" { //nolint:staticcheck // SA1019: legacy proto field still accepted by Login
pskGot = msg.PreSharedKey //nolint:staticcheck // SA1019
}
if pskGot == preSharedKeyRedactedSentinel {
pskGot = ""
}
return resolveConflicts(policy, []conflictCheck{
conflictString(mdm.KeyManagementURL, msg.ManagementUrl),
conflictString(mdm.KeyPreSharedKey, pskGot),
conflictBool(mdm.KeyRosenpassEnabled, msg.RosenpassEnabled),
conflictBool(mdm.KeyRosenpassPermissive, msg.RosenpassPermissive),
conflictBool(mdm.KeyDisableAutoConnect, msg.DisableAutoConnect),
conflictBool(mdm.KeyAllowServerSSH, msg.ServerSSHAllowed),
conflictBool(mdm.KeyDisableClientRoutes, msg.DisableClientRoutes),
conflictBool(mdm.KeyDisableServerRoutes, msg.DisableServerRoutes),
conflictBool(mdm.KeyBlockInbound, msg.BlockInbound),
conflictInt64(mdm.KeyWireguardPort, msg.WireguardPort),
})
}
// rejectMDMManagedFieldConflicts returns a FailedPrecondition gRPC error
// with an MDMManagedFieldsViolation detail when any of the requested
// fields tries to change an MDM-enforced value to something else, and
// nil otherwise. The whole request is rejected on any conflict; non-
// conflicting fields in the same request are not applied either (no
// partial apply).
func rejectMDMManagedFieldConflicts(conflicts []string) error {
if len(conflicts) == 0 {
return nil
}
log.Warnf("MDM rejected request: tried to modify %d managed key(s): %v",
len(conflicts), conflicts)
st := gstatus.New(
codes.FailedPrecondition,
fmt.Sprintf("fields managed by MDM cannot be modified: %v", conflicts),
)
detailed, err := st.WithDetails(&proto.MDMManagedFieldsViolation{Fields: conflicts})
if err != nil {
// Detail attachment is best-effort; fall back to the plain status
// so the caller still gets a usable FailedPrecondition.
return st.Err()
}
return detailed.Err()
}
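
The value-aware gate above can be illustrated with a reduced, self-contained sketch. It assumes a hypothetical `policy` map in place of `mdm.Policy` and handles string fields only; the key names and URLs are made up for the example.

```go
package main

import "fmt"

// policy is a hypothetical stand-in for mdm.Policy: key -> enforced value.
type policy map[string]string

// check pairs a managed key with the value requested for it.
// An empty requested value means "field not set".
type check struct {
	key       string
	requested string
}

// conflicts returns the keys whose requested value diverges from the policy.
func conflicts(p policy, checks []check) []string {
	var out []string
	for _, c := range checks {
		want, managed := p[c.key]
		if !managed || c.requested == "" {
			continue // unmanaged key or field not set: nothing to enforce
		}
		if c.requested != want {
			out = append(out, c.key)
		}
	}
	return out
}

func main() {
	p := policy{"managementURL": "https://mdm.example.com"}
	// Echoing the enforced value and touching an unmanaged key are both allowed.
	fmt.Println(conflicts(p, []check{
		{"managementURL", "https://mdm.example.com"},
		{"interfaceName", "wt1"},
	}))
	// A divergent value for a managed key is reported.
	fmt.Println(conflicts(p, []check{{"managementURL", "https://other.example.com"}}))
}
```

Treating an echo of the enforced value as a no-op is what lets a client that sends a full config snapshot keep working under a policy, while any divergent value still rejects the whole request.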


@@ -30,7 +30,7 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro
s.mutex.Lock()
defer s.mutex.Unlock()
if s.networksDisabled {
if s.checkNetworksDisabled() {
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
}
@@ -143,7 +143,7 @@ func (s *Server) SelectNetworks(_ context.Context, req *proto.SelectNetworksRequ
s.mutex.Lock()
defer s.mutex.Unlock()
if s.networksDisabled {
if s.checkNetworksDisabled() {
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
}
@@ -195,7 +195,7 @@ func (s *Server) DeselectNetworks(_ context.Context, req *proto.SelectNetworksRe
s.mutex.Lock()
defer s.mutex.Unlock()
if s.networksDisabled {
if s.checkNetworksDisabled() {
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
}


@@ -24,6 +24,7 @@ import (
"github.com/netbirdio/netbird/client/internal/expose"
"github.com/netbirdio/netbird/client/internal/profilemanager"
sleephandler "github.com/netbirdio/netbird/client/internal/sleep/handler"
"github.com/netbirdio/netbird/client/mdm"
"github.com/netbirdio/netbird/client/system"
mgm "github.com/netbirdio/netbird/shared/management/client"
"github.com/netbirdio/netbird/shared/management/domain"
@@ -71,7 +72,13 @@ type Server struct {
mutex sync.Mutex
config *profilemanager.Config
proto.UnimplementedDaemonServiceServer
clientRunning bool // protected by mutex
// clientRunning tracks "the daemon wants to be connected" — set true by
// Start / Up, cleared by Down / Logout. Persists across retry
// loops, signal disconnects, and ErrResetConnection cycles. NOT
// changed by connectWithRetryRuns goroutine exit — for that
// (goroutine-still-alive) check, see connectionGoroutineRunning() which
// derives from clientGiveUpChan close state. Protected by s.mutex.
clientRunning bool
clientRunningChan chan struct{}
clientGiveUpChan chan struct{} // closed when connectWithRetryRuns goroutine exits
@@ -93,15 +100,16 @@ type Server struct {
captureEnabled bool
bundleCapture *bundleCapture
// activeCapture is the session currently installed on the engine; guarded by s.mutex.
activeCapture *capture.Session
// activeCaptureCancel tears down the streaming pipe/cancel for the
// active streaming capture so eviction unblocks the StartCapture RPC
// handler. Nil for bundle captures (they own their own context).
activeCaptureCancel func()
networksDisabled bool
activeCapture *capture.Session
networksDisabled bool
sleepHandler *sleephandler.SleepHandler
// mdmTicker periodically re-reads the OS-native MDM policy and triggers
// an engine restart when the policy changes. Launched once by Start;
// stopped by the rootCtx cancellation.
mdmTicker *mdm.Ticker
updateManager *updater.Manager
jwtCache *jwtCache
@@ -159,6 +167,17 @@ func (s *Server) Start() error {
s.updateManager.CheckUpdateSuccess(s.rootCtx)
}
// MDM policy reload ticker: every minute the desktop daemon re-reads
// the OS-native managed-config store and, on diff vs the previous
// observation, cancels the active engine context and re-spawns
// connectWithRetryRuns with a re-resolved Config (re-running
// profilemanager.Config.apply, which applies the freshly-read MDM
// policy as the last layer) so the engine comes back with the new values.
if s.mdmTicker == nil {
s.mdmTicker = mdm.NewTicker(mdm.DefaultReloadInterval)
go s.mdmTicker.Run(s.rootCtx, s.onMDMPolicyChange)
}
// If the current state contains an error, return it.
// In all other cases we can continue only if the status is idle and the up command is
// not in progress and has not already established a connection.
@@ -217,17 +236,27 @@ func (s *Server) Start() error {
s.clientRunningChan = make(chan struct{})
s.clientGiveUpChan = make(chan struct{})
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
s.publishConfigChangedEvent("startup")
return nil
}
// connectWithRetryRuns runs the client connection with a backoff strategy, retrying the operation as an additional
// mechanism to keep the client connected even when the connection is lost.
// The retry is cancelled if the client receives a stop or down command, or if disable auto connect is configured.
//
// The goroutine's exit is signalled to the daemon via close(giveUpChan)
// — placed in the function-scope defer so every return path (panic,
// DisableAutoConnect early-exit, backoff exhausted, ctx cancel) closes
// it. Callers that need to observe "is the goroutine still alive?" use
// Server.connectionGoroutineRunning() which non-blockingly checks the close state
// of clientGiveUpChan. The defer does NOT touch s.mutex; the daemon's
// "intent" (clientRunning) is maintained by the RPC handlers, not by this
// goroutine.
func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}, giveUpChan chan struct{}) {
defer func() {
s.mutex.Lock()
s.clientRunning = false
s.mutex.Unlock()
if giveUpChan != nil {
close(giveUpChan)
}
}()
if s.config.DisableAutoConnect {
@@ -273,9 +302,26 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil
if err := backoff.Retry(runOperation, backOff); err != nil {
log.Errorf("operation failed: %v", err)
}
// giveUpChan is closed by the function-scope defer.
}
if giveUpChan != nil {
close(giveUpChan)
// connectionGoroutineRunning reports whether the connectWithRetryRuns goroutine is
// still running. Returns false when no goroutine has ever been started,
// or when the most recent one has already closed clientGiveUpChan on
// exit (whether due to ctx cancel, DisableAutoConnect single-shot
// completion, or backoff retry exhaustion).
//
// MUST be called with s.mutex held — accesses s.clientGiveUpChan which
// is written by Start/Up under the same lock.
func (s *Server) connectionGoroutineRunning() bool {
if s.clientGiveUpChan == nil {
return false
}
select {
case <-s.clientGiveUpChan:
return false
default:
return true
}
}
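
The liveness check above relies on a non-blocking receive from a channel that is closed on exit. A minimal standalone sketch of that idiom, with the hypothetical name `isRunning` and no locking (the real helper requires s.mutex to be held):

```go
package main

import "fmt"

// isRunning reports whether the goroutine that owns done is still alive.
// A nil channel means no goroutine was ever started.
func isRunning(done <-chan struct{}) bool {
	if done == nil {
		return false
	}
	select {
	case <-done:
		// A receive on a closed channel succeeds immediately.
		return false
	default:
		// Channel is open and empty: the owner has not exited yet.
		return true
	}
}

func main() {
	done := make(chan struct{})
	fmt.Println(isRunning(nil), isRunning(done))
	close(done)
	fmt.Println(isRunning(done))
}
```

The explicit nil check matters: a receive from a nil channel never proceeds, so without it the select would take the default branch and report a goroutine that never existed as running.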
@@ -308,54 +354,85 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
s.mutex.Lock()
defer s.mutex.Unlock()
if s.checkUpdateSettingsDisabled() {
return nil, gstatus.Errorf(codes.Unavailable, errUpdateSettingsDisabled)
// Skip the update-settings gate when the request carries no actual
// overrides: the CLI builds a SetConfigRequest unconditionally on
// every `netbird up` (setupSetConfigReq in cmd/up.go), so a plain
// `netbird up` would otherwise always trip the gate and surface a
// misleading "setConfig method is not available" warning, even when
// the user did not pass any config flag.
if setConfigRequestHasConfigOverrides(msg) {
if s.checkUpdateSettingsDisabled() {
return nil, gstatus.Errorf(codes.Unavailable, errUpdateSettingsDisabled)
}
}
// MDM gate: refuse the whole request if any of its fields diverges
// from a value enforced by the active MDM policy. The error carries an
// MDMManagedFieldsViolation detail listing the offending key names.
// Non-conflicting fields in the same request are not applied either.
policy := loadMDMPolicy()
if err := rejectMDMManagedFieldConflicts(mdmManagedFieldConflicts(msg, policy)); err != nil {
return nil, err
}
config, err := setConfigInputFromRequest(msg)
if err != nil {
return nil, err
}
if _, err := profilemanager.UpdateConfig(config); err != nil {
log.Errorf("failed to update profile config: %v", err)
return nil, fmt.Errorf("failed to update profile config: %w", err)
}
return &proto.SetConfigResponse{}, nil
}
// setConfigInputFromRequest translates a SetConfigRequest into the
// profilemanager.ConfigInput that profilemanager.UpdateConfig consumes.
// Pure mapping with no business logic beyond presence-aware copying of
// optional fields and the "empty / clean" semantics for the two slice
// fields (DNS labels, NAT external IPs). Extracted from SetConfig to
// keep the handler's cognitive complexity below the SonarQube
// threshold; the body is intentionally linear because each proto
// field is its own optional case. Returns the resolved ConfigInput
// and a non-nil error only when the active profile file path cannot
// be determined.
func setConfigInputFromRequest(msg *proto.SetConfigRequest) (profilemanager.ConfigInput, error) {
var config profilemanager.ConfigInput
profState := profilemanager.ActiveProfileState{
Name: msg.ProfileName,
Username: msg.Username,
}
profPath, err := profState.FilePath()
if err != nil {
log.Errorf("failed to get active profile file path: %v", err)
return nil, fmt.Errorf("failed to get active profile file path: %w", err)
return config, fmt.Errorf("failed to get active profile file path: %w", err)
}
var config profilemanager.ConfigInput
config.ConfigPath = profPath
if msg.ManagementUrl != "" {
config.ManagementURL = msg.ManagementUrl
}
if msg.AdminURL != "" {
config.AdminURL = msg.AdminURL
}
if msg.InterfaceName != nil {
config.InterfaceName = msg.InterfaceName
}
if msg.WireguardPort != nil {
wgPort := int(*msg.WireguardPort)
config.WireguardPort = &wgPort
}
if msg.OptionalPreSharedKey != nil {
if *msg.OptionalPreSharedKey != "" {
config.PreSharedKey = msg.OptionalPreSharedKey
}
if msg.OptionalPreSharedKey != nil && *msg.OptionalPreSharedKey != "" {
config.PreSharedKey = msg.OptionalPreSharedKey
}
if msg.CleanDNSLabels {
config.DNSLabels = domain.List{}
} else if msg.DnsLabels != nil {
dnsLabels := domain.FromPunycodeList(msg.DnsLabels)
config.DNSLabels = dnsLabels
config.DNSLabels = domain.FromPunycodeList(msg.DnsLabels)
}
if msg.CleanNATExternalIPs {
@@ -368,7 +445,6 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
if string(msg.CustomDNSAddress) == "empty" {
config.CustomDNSAddress = []byte{}
}
config.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist
if msg.DnsRouteInterval != nil {
@@ -380,8 +456,6 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
config.RosenpassPermissive = msg.RosenpassPermissive
config.DisableAutoConnect = msg.DisableAutoConnect
config.ServerSSHAllowed = msg.ServerSSHAllowed
config.ServerVNCAllowed = msg.ServerVNCAllowed
config.DisableVNCApproval = msg.DisableVNCApproval
config.NetworkMonitor = msg.NetworkMonitor
config.DisableClientRoutes = msg.DisableClientRoutes
config.DisableServerRoutes = msg.DisableServerRoutes
@@ -403,22 +477,31 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
ttl := int(*msg.SshJWTCacheTTL)
config.SSHJWTCacheTTL = &ttl
}
if msg.Mtu != nil {
mtu := uint16(*msg.Mtu)
config.MTU = &mtu
}
if _, err := profilemanager.UpdateConfig(config); err != nil {
log.Errorf("failed to update profile config: %v", err)
return nil, fmt.Errorf("failed to update profile config: %w", err)
}
return &proto.SetConfigResponse{}, nil
return config, nil
}
// Login uses setup key to prepare configuration for the daemon.
func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*proto.LoginResponse, error) {
// Config-override gates. LoginRequest carries the same surface as
// SetConfigRequest (managementUrl, PSK, ssh/rosenpass/port toggles,
// ...), so the same protections must apply. Without these the CLI
// command `netbird up --management-url=X` (which falls through to
// Login when SetConfig is rejected — see cmd/up.go) would silently
// bypass `--disable-update-settings` and any MDM policy.
if loginRequestHasConfigOverrides(msg) {
if s.checkUpdateSettingsDisabled() {
return nil, gstatus.Errorf(codes.Unavailable, errUpdateSettingsDisabled)
}
policy := loadMDMPolicy()
if err := rejectMDMManagedFieldConflicts(loginRequestMDMConflicts(msg, policy)); err != nil {
return nil, err
}
}
s.mutex.Lock()
if s.actCancel != nil {
s.actCancel()
@@ -658,7 +741,13 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
// Up starts engine work in the daemon.
func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpResponse, error) {
s.mutex.Lock()
if s.clientRunning {
// clientRunning is the daemon-intent flag (set by previous Up/Start, cleared
// by Down). connectionGoroutineRunning() reports whether the previous retry-loop
// goroutine is still trying. When intent is up AND goroutine is alive,
// the existing engine is on the job — just wait for it. When intent
// is up but the goroutine has given up (backoff exhausted) OR when
// intent is down, fall through to spawn a fresh retry loop.
if s.clientRunning && s.connectionGoroutineRunning() {
state := internal.CtxGetState(s.rootCtx)
status, err := state.Status()
if err != nil {
@@ -749,6 +838,7 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
s.clientGiveUpChan = make(chan struct{})
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
s.publishConfigChangedEvent("up_rpc")
s.mutex.Unlock()
return s.waitForUp(callerCtx)
@@ -877,6 +967,12 @@ func (s *Server) cleanupConnection() error {
return ErrServiceNotUp
}
// Daemon intent flips to "down" — all callers (Down RPC,
// Logout RPC handlers) tear down the connection because the user
// explicitly asked for it. MDM restart does NOT go through this
// path, so its clientRunning stays true.
s.clientRunning = false
// Capture the engine reference before cancelling the context.
// After actCancel(), the connectWithRetryRuns goroutine wakes up
// and sets connectClient.engine = nil, causing connectClient.Stop()
@@ -1080,10 +1176,14 @@ func (s *Server) Status(
msg *proto.StatusRequest,
) (*proto.StatusResponse, error) {
s.mutex.Lock()
clientRunning := s.clientRunning
// Only wait if the retry-loop goroutine is alive and making
// progress. clientRunning=true with connectionGoroutineRunning=false means the
// backoff has given up — there is nothing to wait for; let the
// caller observe the failed status directly.
alive := s.connectionGoroutineRunning()
s.mutex.Unlock()
if msg.WaitForReady != nil && *msg.WaitForReady && clientRunning {
if msg.WaitForReady != nil && *msg.WaitForReady && alive {
state := internal.CtxGetState(s.rootCtx)
status, err := state.Status()
if err != nil {
@@ -1142,7 +1242,6 @@ func (s *Server) Status(
pbFullStatus := fullStatus.ToProto()
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
pbFullStatus.SshServerState = s.getSSHServerState()
pbFullStatus.VncServerState = s.getVNCServerState()
statusResponse.FullStatus = pbFullStatus
}
@@ -1182,38 +1281,6 @@ func (s *Server) getSSHServerState() *proto.SSHServerState {
return sshServerState
}
// getVNCServerState retrieves the current VNC server state.
func (s *Server) getVNCServerState() *proto.VNCServerState {
s.mutex.Lock()
connectClient := s.connectClient
s.mutex.Unlock()
if connectClient == nil {
return nil
}
engine := connectClient.Engine()
if engine == nil {
return nil
}
enabled, sessions := engine.GetVNCServerStatus()
pbSessions := make([]*proto.VNCSessionInfo, 0, len(sessions))
for _, sess := range sessions {
pbSessions = append(pbSessions, &proto.VNCSessionInfo{
RemoteAddress: sess.RemoteAddress,
Mode: sess.Mode,
Username: sess.Username,
UserID: sess.UserID,
Initiator: sess.Initiator,
})
}
return &proto.VNCServerState{
Enabled: enabled,
Sessions: pbSessions,
}
}
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
func (s *Server) GetPeerSSHHostKey(
ctx context.Context,
@@ -1454,27 +1521,6 @@ func (s *Server) ExposeService(req *proto.ExposeServiceRequest, srv proto.Daemon
return nil
}
// RespondApproval relays the user's accept/deny decision for a pending
// approval prompt to the engine's broker. Unknown or already-resolved
// request_ids are silently no-op'd so a slow UI cannot deny a prompt the
// user already handled (or that already timed out).
func (s *Server) RespondApproval(_ context.Context, msg *proto.RespondApprovalRequest) (*proto.RespondApprovalResponse, error) {
s.mutex.Lock()
connectClient := s.connectClient
s.mutex.Unlock()
if connectClient == nil {
return nil, gstatus.Errorf(codes.FailedPrecondition, "client not initialized")
}
engine := connectClient.Engine()
if engine == nil {
return nil, gstatus.Errorf(codes.FailedPrecondition, "engine not running")
}
if !engine.RespondApproval(msg.GetRequestId(), msg.GetAccept(), msg.GetViewOnly()) {
log.Debugf("approval response for unknown request_id %s", msg.GetRequestId())
}
return &proto.RespondApprovalResponse{}, nil
}
func isUnixRunningDesktop() bool {
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
return false
@@ -1591,8 +1637,6 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
Mtu: int64(cfg.MTU),
DisableAutoConnect: cfg.DisableAutoConnect,
ServerSSHAllowed: *cfg.ServerSSHAllowed,
ServerVNCAllowed: cfg.ServerVNCAllowed != nil && *cfg.ServerVNCAllowed,
DisableVNCApproval: cfg.DisableVNCApproval != nil && *cfg.DisableVNCApproval,
RosenpassEnabled: cfg.RosenpassEnabled,
RosenpassPermissive: cfg.RosenpassPermissive,
LazyConnectionEnabled: cfg.LazyConnectionEnabled,
@@ -1610,6 +1654,7 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
EnableSSHRemotePortForwarding: enableSSHRemotePortForwarding,
DisableSSHAuth: disableSSHAuth,
SshJWTCacheTTL: sshJWTCacheTTL,
MDMManagedFields: cfg.Policy().ManagedKeys(),
}, nil
}
@@ -1708,7 +1753,7 @@ func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest)
features := &proto.GetFeaturesResponse{
DisableProfiles: s.checkProfilesDisabled(),
DisableUpdateSettings: s.checkUpdateSettingsDisabled(),
DisableNetworks: s.networksDisabled,
DisableNetworks: s.checkNetworksDisabled(),
}
return features, nil
@@ -1730,22 +1775,46 @@ func (s *Server) connect(ctx context.Context, config *profilemanager.Config, sta
return nil
}
// MDM authority: when the platform-native MDM source sets a kill switch
// key (regardless of true/false value), that value wins. The CLI flag
// supplied at service install time is the fallback used only when the
// MDM source is silent on the key. This honors the "MDM decides
// everything" semantic agreed for NET-1214 — an admin pushing
// disableX=false via MDM explicitly re-enables the feature even on a
// box installed with --disable-X.
func (s *Server) checkProfilesDisabled() bool {
// Check if the environment variable is set to disable profiles
if s.profilesDisabled {
return true
if s.config != nil {
if v, ok := s.config.Policy().GetBool(mdm.KeyDisableProfiles); ok {
return v
}
}
return s.profilesDisabled
}
return false
// checkNetworksDisabled reports whether the networks/exit-node feature
// is disabled on this daemon instance. Resolved MDM-first: when the
// active policy declares mdm.KeyDisableNetworks the policy value wins
// (regardless of true/false), so an admin can re-enable the feature
// via MDM even on a host that was installed with --disable-networks.
// Falls back to the s.networksDisabled CLI flag when the policy is
// silent on the key. Mirrors checkProfilesDisabled and
// checkUpdateSettingsDisabled.
func (s *Server) checkNetworksDisabled() bool {
if s.config != nil {
if v, ok := s.config.Policy().GetBool(mdm.KeyDisableNetworks); ok {
return v
}
}
return s.networksDisabled
}
func (s *Server) checkUpdateSettingsDisabled() bool {
// Check if the environment variable is set to disable profiles
if s.updateSettingsDisabled {
return true
if s.config != nil {
if v, ok := s.config.Policy().GetBool(mdm.KeyDisableUpdateSettings); ok {
return v
}
}
return false
return s.updateSettingsDisabled
}
func (s *Server) startUpdateManagerForGUI() {

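The MDM-first resolution used by checkProfilesDisabled, checkNetworksDisabled and checkUpdateSettingsDisabled can be sketched in isolation. This is a minimal illustration, not the daemon's code: the `policy` type, its map-backed `GetBool`, and the key name `disableNetworks` are assumptions made for the example; only the rule itself (a declared policy value wins whether true or false, the CLI flag is the fallback) comes from the diff above.

```go
package main

import "fmt"

// policy is a hypothetical stand-in for the MDM policy type. Only the
// GetBool lookup used in the diff is modelled here.
type policy map[string]any

// GetBool returns the value and true only when the key is present and
// holds a boolean. A missing key or a non-boolean value reports false.
func (p policy) GetBool(key string) (bool, bool) {
	v, ok := p[key]
	if !ok {
		return false, false
	}
	b, ok := v.(bool)
	return b, ok
}

// resolveDisabled applies the MDM-first rule: a declared policy value
// wins regardless of whether it is true or false; the CLI flag is used
// only when the policy is silent on the key.
func resolveDisabled(p policy, key string, cliFlag bool) bool {
	if v, ok := p.GetBool(key); ok {
		return v
	}
	return cliFlag
}

func main() {
	const key = "disableNetworks" // hypothetical key name

	// Policy explicitly re-enables a feature disabled at install time.
	fmt.Println(resolveDisabled(policy{key: false}, key, true))
	// Policy is silent: the CLI flag decides.
	fmt.Println(resolveDisabled(policy{}, key, true))
	// Policy disables a feature the CLI left enabled.
	fmt.Println(resolveDisabled(policy{key: true}, key, false))
}
```

Returning the presence bit separately from the value is what lets `false` act as an explicit override rather than as "unset".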

@@ -101,6 +101,7 @@ func TestCleanupConnection_ClearsConnectClient(t *testing.T) {
require.NoError(t, err)
assert.Nil(t, s.connectClient, "connectClient should be nil after cleanup")
assert.False(t, s.clientRunning, "clientRunning should be cleared after cleanup (intent = down)")
}
// TestCleanState_NilConnectClient validates that CleanState doesn't panic
@@ -144,17 +145,20 @@ func TestDownThenUp_StaleRunningChan(t *testing.T) {
_, cancel := context.WithCancel(context.Background())
s.actCancel = cancel
// Simulate Down(): cleanupConnection sets connectClient = nil
// Simulate Down(): cleanupConnection sets connectClient = nil and
// flips clientRunning to false (intent = down). The connectionGoroutineRunning state
// remains independent of intent — derived from clientGiveUpChan.
s.mutex.Lock()
err := s.cleanupConnection()
s.mutex.Unlock()
require.NoError(t, err)
// After cleanup: connectClient is nil, clientRunning still true
// (goroutine hasn't exited yet)
// After cleanup: connectClient is nil, clientRunning is false (intent
// cleared by cleanupConnection), connectionGoroutineRunning may still be true
// (goroutine teardown is independent of the intent flag).
s.mutex.Lock()
assert.Nil(t, s.connectClient, "connectClient should be nil after cleanup")
assert.True(t, s.clientRunning, "clientRunning still true until goroutine exits")
assert.False(t, s.clientRunning, "clientRunning should be cleared by cleanupConnection (intent = down)")
s.mutex.Unlock()
// waitForUp() returns immediately due to stale closed clientRunningChan

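The split between intent (clientRunning) and liveness (connectionGoroutineRunning) that these tests and the Status change rely on can be sketched as follows. The test comments say liveness is derived from clientGiveUpChan; the exact derivation is not shown in this diff, so the closed-channel check, the `server` struct and the `shouldWait` helper below are assumptions made for illustration.

```go
package main

import "fmt"

// server is a hypothetical, reduced model of the daemon state.
type server struct {
	clientRunning bool          // intent: the user asked for "up"
	giveUp        chan struct{} // assumed to be closed when the retry loop exits
}

// goroutineRunning reports liveness without blocking: a nil channel means
// the loop was never started, a closed channel means it has given up.
func (s *server) goroutineRunning() bool {
	if s.giveUp == nil {
		return false
	}
	select {
	case <-s.giveUp:
		return false
	default:
		return true
	}
}

// shouldWait mirrors the updated Status condition: wait only when the
// caller asked for it and the retry loop is still alive.
func shouldWait(s *server, waitForReady *bool) bool {
	return waitForReady != nil && *waitForReady && s.goroutineRunning()
}

func main() {
	wait := true
	gaveUp := make(chan struct{})
	close(gaveUp)

	alive := &server{clientRunning: true, giveUp: make(chan struct{})}
	dead := &server{clientRunning: true, giveUp: gaveUp}

	fmt.Println(shouldWait(alive, &wait)) // loop alive: waiting makes sense
	fmt.Println(shouldWait(dead, &wait))  // intent is up, but nothing to wait for
}
```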

@@ -0,0 +1,198 @@
package server
import (
"context"
"os/user"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/mdm"
"github.com/netbirdio/netbird/client/proto"
)
// withMDMPolicy temporarily overrides the server-package loadMDMPolicy hook
// so SetConfig observes the supplied Policy. Restores the original loader
// at test cleanup.
func withMDMPolicy(t *testing.T, policy *mdm.Policy) {
t.Helper()
prev := loadMDMPolicy
loadMDMPolicy = func() *mdm.Policy { return policy }
t.Cleanup(func() { loadMDMPolicy = prev })
}
// setupServerWithProfile mirrors the boilerplate of TestSetConfig_AllFieldsSaved:
// overrides profilemanager paths to a temp dir, seeds a profile, sets it
// active, and constructs a Server instance. Returns the constructed server
// plus context + profile name + username + cfgPath for the seeded profile.
func setupServerWithProfile(t *testing.T) (s *Server, ctx context.Context, profName, username, cfgPath string) {
t.Helper()
tempDir := t.TempDir()
origDefaultProfileDir := profilemanager.DefaultConfigPathDir
origDefaultConfigPath := profilemanager.DefaultConfigPath
origActiveProfileStatePath := profilemanager.ActiveProfileStatePath
profilemanager.ConfigDirOverride = tempDir
profilemanager.DefaultConfigPathDir = tempDir
profilemanager.ActiveProfileStatePath = tempDir + "/active_profile.json"
profilemanager.DefaultConfigPath = filepath.Join(tempDir, "default.json")
t.Cleanup(func() {
profilemanager.DefaultConfigPathDir = origDefaultProfileDir
profilemanager.ActiveProfileStatePath = origActiveProfileStatePath
profilemanager.DefaultConfigPath = origDefaultConfigPath
profilemanager.ConfigDirOverride = ""
})
currUser, err := user.Current()
require.NoError(t, err)
profName = "test-profile-mdm"
cfgPath = filepath.Join(tempDir, profName+".json")
_, err = profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: cfgPath,
ManagementURL: "https://api.netbird.io:443",
})
require.NoError(t, err)
pm := profilemanager.ServiceManager{}
require.NoError(t, pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: profName,
Username: currUser.Username,
}))
ctx = context.Background()
s = New(ctx, "console", "", false, false, false, false)
return s, ctx, profName, currUser.Username, cfgPath
}
// extractViolation pulls the MDMManagedFieldsViolation detail from a
// FailedPrecondition error. Fails the test if absent or malformed.
func extractViolation(t *testing.T, err error) *proto.MDMManagedFieldsViolation {
t.Helper()
require.Error(t, err)
st, ok := gstatus.FromError(err)
require.True(t, ok, "error must be a gRPC status: %v", err)
require.Equal(t, codes.FailedPrecondition, st.Code(), "expected FailedPrecondition, got %s", st.Code())
for _, d := range st.Details() {
if v, ok := d.(*proto.MDMManagedFieldsViolation); ok {
return v
}
}
t.Fatalf("MDMManagedFieldsViolation detail not found on status; details: %v", st.Details())
return nil
}
func TestSetConfig_MDMReject_SingleField(t *testing.T) {
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
mdm.KeyManagementURL: "https://mdm.example.com:443",
}))
s, ctx, profName, username, _ := setupServerWithProfile(t)
_, err := s.SetConfig(ctx, &proto.SetConfigRequest{
ProfileName: profName,
Username: username,
ManagementUrl: "https://user.tried.this.com:443",
})
v := extractViolation(t, err)
assert.Equal(t, []string{mdm.KeyManagementURL}, v.GetFields())
}
func TestSetConfig_MDMReject_MultipleFields(t *testing.T) {
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
mdm.KeyManagementURL: "https://mdm.example.com:443",
mdm.KeyBlockInbound: true,
mdm.KeyRosenpassEnabled: true,
}))
s, ctx, profName, username, _ := setupServerWithProfile(t)
blockInbound := false
rosenpassEnabled := false
_, err := s.SetConfig(ctx, &proto.SetConfigRequest{
ProfileName: profName,
Username: username,
ManagementUrl: "https://user.tried.this.com:443",
BlockInbound: &blockInbound,
RosenpassEnabled: &rosenpassEnabled,
})
v := extractViolation(t, err)
assert.ElementsMatch(t, []string{
mdm.KeyManagementURL,
mdm.KeyBlockInbound,
mdm.KeyRosenpassEnabled,
}, v.GetFields())
}
func TestSetConfig_MDMReject_AllOrNothing(t *testing.T) {
// MDM enforces ManagementURL only; user request touches both the
// enforced field AND a non-enforced field (RosenpassEnabled).
// The whole request must be rejected — non-conflicting fields are not
// applied either.
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
mdm.KeyManagementURL: "https://mdm.example.com:443",
}))
s, ctx, profName, username, cfgPath := setupServerWithProfile(t)
rosenpassEnabled := true
_, err := s.SetConfig(ctx, &proto.SetConfigRequest{
ProfileName: profName,
Username: username,
ManagementUrl: "https://user.tried.this.com:443",
RosenpassEnabled: &rosenpassEnabled,
})
v := extractViolation(t, err)
assert.Equal(t, []string{mdm.KeyManagementURL}, v.GetFields())
// Confirm RosenpassEnabled was NOT applied even though it was not
// in the conflict list: the request was rejected as a whole.
reloaded, err := profilemanager.GetConfig(cfgPath)
require.NoError(t, err)
assert.False(t, reloaded.RosenpassEnabled, "non-conflicting field must not be applied when request is rejected")
}
func TestSetConfig_MDMAllow_NonManagedFields(t *testing.T) {
// MDM enforces ManagementURL but the user only writes RosenpassEnabled.
// Request must succeed.
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
mdm.KeyManagementURL: "https://mdm.example.com:443",
}))
s, ctx, profName, username, _ := setupServerWithProfile(t)
rosenpassEnabled := true
resp, err := s.SetConfig(ctx, &proto.SetConfigRequest{
ProfileName: profName,
Username: username,
RosenpassEnabled: &rosenpassEnabled,
})
require.NoError(t, err)
require.NotNil(t, resp)
}
func TestSetConfig_MDMEmpty_NoEnforcement(t *testing.T) {
// No MDM policy active: any field can be written.
withMDMPolicy(t, mdm.NewPolicy(nil))
s, ctx, profName, username, _ := setupServerWithProfile(t)
resp, err := s.SetConfig(ctx, &proto.SetConfigRequest{
ProfileName: profName,
Username: username,
ManagementUrl: "https://user.changed.url.com:443",
})
require.NoError(t, err)
require.NotNil(t, resp)
}

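The all-or-nothing behaviour exercised by these tests can be sketched without the gRPC plumbing. This is an illustrative model only: the flat `map[string]any` representation, the `conflicts` and `apply` helpers, and the key names are hypothetical, and the rule that a request value equal to the policy value is not a conflict is an assumption of this sketch rather than something these tests verify.

```go
package main

import (
	"fmt"
	"sort"
)

// conflicts returns, in sorted order, the managed keys that the request
// tries to set to a value different from the policy value. Values are
// assumed to be comparable scalars (strings, bools, numbers).
func conflicts(policy, request map[string]any) []string {
	var out []string
	for k, want := range request {
		if managed, ok := policy[k]; ok && managed != want {
			out = append(out, k)
		}
	}
	sort.Strings(out)
	return out
}

// apply writes the request into cfg only when no managed key conflicts.
// On conflict nothing is written, including the non-managed fields.
func apply(cfg, policy, request map[string]any) ([]string, bool) {
	if c := conflicts(policy, request); len(c) > 0 {
		return c, false
	}
	for k, v := range request {
		cfg[k] = v
	}
	return nil, true
}

func main() {
	cfg := map[string]any{}
	policy := map[string]any{"managementURL": "https://mdm.example.com:443"}

	fields, ok := apply(cfg, policy, map[string]any{
		"managementURL":    "https://user.tried.this.com:443",
		"rosenpassEnabled": true,
	})
	fmt.Println(ok, fields, len(cfg))

	_, ok = apply(cfg, policy, map[string]any{"rosenpassEnabled": true})
	fmt.Println(ok, cfg["rosenpassEnabled"])
}
```

Collecting every conflicting key before rejecting lets the caller report all violations at once, which matches the multi-field assertion in TestSetConfig_MDMReject_MultipleFields.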

@@ -58,8 +58,6 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
rosenpassEnabled := true
rosenpassPermissive := true
serverSSHAllowed := true
serverVNCAllowed := true
disableVNCApproval := true
interfaceName := "utun100"
wireguardPort := int64(51820)
preSharedKey := "test-psk"
@@ -85,8 +83,6 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
RosenpassEnabled: &rosenpassEnabled,
RosenpassPermissive: &rosenpassPermissive,
ServerSSHAllowed: &serverSSHAllowed,
ServerVNCAllowed: &serverVNCAllowed,
DisableVNCApproval: &disableVNCApproval,
InterfaceName: &interfaceName,
WireguardPort: &wireguardPort,
OptionalPreSharedKey: &preSharedKey,
@@ -131,10 +127,6 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
require.Equal(t, rosenpassPermissive, cfg.RosenpassPermissive)
require.NotNil(t, cfg.ServerSSHAllowed)
require.Equal(t, serverSSHAllowed, *cfg.ServerSSHAllowed)
require.NotNil(t, cfg.ServerVNCAllowed)
require.Equal(t, serverVNCAllowed, *cfg.ServerVNCAllowed)
require.NotNil(t, cfg.DisableVNCApproval)
require.Equal(t, disableVNCApproval, *cfg.DisableVNCApproval)
require.Equal(t, interfaceName, cfg.WgIface)
require.Equal(t, int(wireguardPort), cfg.WgPort)
require.Equal(t, preSharedKey, cfg.PreSharedKey)
@@ -187,8 +179,6 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
"RosenpassEnabled": true,
"RosenpassPermissive": true,
"ServerSSHAllowed": true,
"ServerVNCAllowed": true,
"DisableVNCApproval": true,
"InterfaceName": true,
"WireguardPort": true,
"OptionalPreSharedKey": true,
@@ -250,8 +240,6 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) {
"enable-rosenpass": "RosenpassEnabled",
"rosenpass-permissive": "RosenpassPermissive",
"allow-server-ssh": "ServerSSHAllowed",
"allow-server-vnc": "ServerVNCAllowed",
"disable-vnc-approval": "DisableVNCApproval",
"interface-name": "InterfaceName",
"wireguard-port": "WireguardPort",
"preshared-key": "OptionalPreSharedKey",


@@ -1,4 +1,4 @@
package sessionauth
package auth
import (
"errors"
@@ -15,8 +15,6 @@ const (
DefaultUserIDClaim = "sub"
// Wildcard is a special user ID that matches all users
Wildcard = "*"
// sessionPubKeyLen is the size of an X25519 static public key in bytes.
sessionPubKeyLen = 32
)
var (
@@ -24,7 +22,6 @@ var (
ErrUserNotAuthorized = errors.New("user is not authorized to access this peer")
ErrNoMachineUserMapping = errors.New("no authorization mapping for OS user")
ErrUserNotMappedToOSUser = errors.New("user is not authorized to login as OS user")
ErrSessionKeyNotKnown = errors.New("session pubkey not registered")
)
// Authorizer handles SSH fine-grained access control authorization
@@ -38,17 +35,6 @@ type Authorizer struct {
// machineUsers maps OS login usernames to lists of authorized user indexes
machineUsers map[string][]uint32
// sessionPubKeys maps an X25519 static public key (as map-safe
// array) to the hashed user identity that key authenticates as.
// Populated from management's temporary-access flow; used by VNC to
// authenticate via the Noise_IK handshake.
sessionPubKeys map[[sessionPubKeyLen]byte]sshuserhash.UserIDHash
// sessionDisplayNames mirrors sessionPubKeys with the optional
// human-readable display name management associated with each
// session key. Used by the per-connection UI approval prompt; not
// consulted by any authorization decision.
sessionDisplayNames map[[sessionPubKeyLen]byte]string
// mu protects the list of users
mu sync.RWMutex
}
@@ -64,29 +50,13 @@ type Config struct {
// MachineUsers maps OS login usernames to indexes in AuthorizedUsers
// If a user wants to log in as a specific OS user, their index must be in the corresponding list
MachineUsers map[string][]uint32
// SessionPubKeys binds ephemeral X25519 static public keys to hashed
// user identities. Populated for VNC; ignored on the SSH side.
SessionPubKeys []SessionPubKey
}
// SessionPubKey is a single ephemeral-key entry: the 32-byte X25519
// static public key plus the hashed user identity it authenticates as,
// optionally plus a human-readable display name for the UI approval
// prompt to identify the requester.
type SessionPubKey struct {
PubKey []byte
UserIDHash sshuserhash.UserIDHash
DisplayName string
}
// NewAuthorizer creates a new SSH authorizer with empty configuration
func NewAuthorizer() *Authorizer {
a := &Authorizer{
userIDClaim: DefaultUserIDClaim,
machineUsers: make(map[string][]uint32),
sessionPubKeys: make(map[[sessionPubKeyLen]byte]sshuserhash.UserIDHash),
sessionDisplayNames: make(map[[sessionPubKeyLen]byte]string),
userIDClaim: DefaultUserIDClaim,
machineUsers: make(map[string][]uint32),
}
return a
@@ -102,8 +72,6 @@ func (a *Authorizer) Update(config *Config) {
a.userIDClaim = DefaultUserIDClaim
a.authorizedUsers = []sshuserhash.UserIDHash{}
a.machineUsers = make(map[string][]uint32)
a.sessionPubKeys = make(map[[sessionPubKeyLen]byte]sshuserhash.UserIDHash)
a.sessionDisplayNames = make(map[[sessionPubKeyLen]byte]string)
log.Info("SSH authorization cleared")
return
}
@@ -126,35 +94,8 @@ func (a *Authorizer) Update(config *Config) {
}
a.machineUsers = machineUsers
sessionPubKeys := make(map[[sessionPubKeyLen]byte]sshuserhash.UserIDHash, len(config.SessionPubKeys))
sessionDisplayNames := make(map[[sessionPubKeyLen]byte]string, len(config.SessionPubKeys))
conflicted := make(map[[sessionPubKeyLen]byte]struct{})
for _, e := range config.SessionPubKeys {
if len(e.PubKey) != sessionPubKeyLen {
continue
}
var key [sessionPubKeyLen]byte
copy(key[:], e.PubKey)
if _, bad := conflicted[key]; bad {
continue
}
if existing, ok := sessionPubKeys[key]; ok && existing != e.UserIDHash {
log.Warnf("SSH auth: session pubkey bound to conflicting user hashes; dropping binding")
delete(sessionPubKeys, key)
delete(sessionDisplayNames, key)
conflicted[key] = struct{}{}
continue
}
sessionPubKeys[key] = e.UserIDHash
if e.DisplayName != "" {
sessionDisplayNames[key] = e.DisplayName
}
}
a.sessionPubKeys = sessionPubKeys
a.sessionDisplayNames = sessionDisplayNames
log.Debugf("SSH auth: updated with %d authorized users, %d machine user mappings, %d session pubkeys",
len(config.AuthorizedUsers), len(machineUsers), len(sessionPubKeys))
log.Debugf("SSH auth: updated with %d authorized users, %d machine user mappings",
len(config.AuthorizedUsers), len(machineUsers))
}
// Authorize validates whether a user is authorized to log in as the specified OS user.
@@ -214,54 +155,6 @@ func (a *Authorizer) GetUserIDClaim() string {
return a.userIDClaim
}
// LookupSessionKey resolves a Noise-verified static public key to the
// hashed user identity registered with it. Fails closed when the key is
// unknown.
func (a *Authorizer) LookupSessionKey(pubKey []byte) (sshuserhash.UserIDHash, error) {
var zero sshuserhash.UserIDHash
if len(pubKey) != sessionPubKeyLen {
return zero, fmt.Errorf("session pubkey wrong length: %d", len(pubKey))
}
var key [sessionPubKeyLen]byte
copy(key[:], pubKey)
a.mu.RLock()
hash, ok := a.sessionPubKeys[key]
a.mu.RUnlock()
if !ok {
return zero, ErrSessionKeyNotKnown
}
return hash, nil
}
// LookupSessionDisplayName returns the human-readable display name
// management associated with a session pubkey, or empty string when none
// is recorded. Never returns an error: a missing/unknown key reports as
// "" and the caller falls back to other identifiers.
func (a *Authorizer) LookupSessionDisplayName(pubKey []byte) string {
if len(pubKey) != sessionPubKeyLen {
return ""
}
var key [sessionPubKeyLen]byte
copy(key[:], pubKey)
a.mu.RLock()
name := a.sessionDisplayNames[key]
a.mu.RUnlock()
return name
}
// AuthorizeOSUserBySessionKey resolves the OS-user mapping for a session
// key. Mirrors Authorize but skips the JWT-hash step since the key has
// already been verified and the user identity hash is in hand.
func (a *Authorizer) AuthorizeOSUserBySessionKey(userIDHash sshuserhash.UserIDHash, osUsername string) (string, error) {
a.mu.RLock()
defer a.mu.RUnlock()
userIndex, found := a.findUserIndex(userIDHash)
if !found {
return "", fmt.Errorf("session user (hash: %s) not in authorized list for OS user %q: %w", userIDHash, osUsername, ErrUserNotAuthorized)
}
return a.checkMachineUserMapping("session", osUsername, userIndex)
}
// findUserIndex finds the index of a hashed user ID in the authorized users list
// Returns the index and true if found, 0 and false if not found
func (a *Authorizer) findUserIndex(hashedUserID sshuserhash.UserIDHash) (int, bool) {

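The session-key indexing removed from Authorizer.Update in this file follows a pattern that can be shown standalone: byte slices are copied into fixed-size arrays so they can be used as map keys, and a key bound to two different identities is dropped for good. The sketch below is a reduced model; the `binding` type, the plain-string `User` field (standing in for the hashed user identity) and `buildIndex` are hypothetical names.

```go
package main

import "fmt"

const keyLen = 32 // size of an X25519 public key in bytes

// binding is a hypothetical, reduced form of a session key entry.
type binding struct {
	PubKey []byte
	User   string // stands in for the hashed user identity
}

// buildIndex maps each well-formed key to its user. Entries with the wrong
// length are skipped. A key seen with two different users is removed and
// remembered as conflicted, so later entries cannot re-add it.
func buildIndex(entries []binding) map[[keyLen]byte]string {
	index := make(map[[keyLen]byte]string, len(entries))
	conflicted := make(map[[keyLen]byte]struct{})
	for _, e := range entries {
		if len(e.PubKey) != keyLen {
			continue
		}
		var key [keyLen]byte // arrays are comparable, slices are not
		copy(key[:], e.PubKey)
		if _, bad := conflicted[key]; bad {
			continue
		}
		if existing, ok := index[key]; ok && existing != e.User {
			delete(index, key)
			conflicted[key] = struct{}{}
			continue
		}
		index[key] = e.User
	}
	return index
}

func main() {
	k := make([]byte, keyLen)
	idx := buildIndex([]binding{
		{PubKey: k, User: "alice"},
		{PubKey: k, User: "bob"},   // conflict: binding is dropped
		{PubKey: k, User: "alice"}, // stays dropped
	})
	fmt.Println(len(idx))
}
```

Tracking conflicted keys separately is what makes the result independent of entry order after the first conflict.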

@@ -1,7 +1,6 @@
package sessionauth
package auth
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
@@ -611,61 +610,3 @@ func TestAuthorizer_Wildcard_WithPartialIndexes_AllowsAllUsers(t *testing.T) {
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotAuthorized, "unauthorized user should be denied")
}
func TestAuthorizer_LookupSessionKey_Valid(t *testing.T) {
pub := bytesRepeat(0x11, sessionPubKeyLen)
userHash, err := sshauth.HashUserID("alice")
require.NoError(t, err)
a := NewAuthorizer()
a.Update(&Config{
AuthorizedUsers: []sshauth.UserIDHash{userHash},
MachineUsers: map[string][]uint32{Wildcard: {0}},
SessionPubKeys: []SessionPubKey{{PubKey: pub, UserIDHash: userHash}},
})
got, err := a.LookupSessionKey(pub)
require.NoError(t, err)
assert.Equal(t, userHash, got)
if _, err := a.AuthorizeOSUserBySessionKey(got, "alice"); err != nil {
t.Fatalf("AuthorizeOSUserBySessionKey: %v", err)
}
}
func TestAuthorizer_LookupSessionKey_UnknownPub(t *testing.T) {
a := NewAuthorizer()
a.Update(&Config{})
_, err := a.LookupSessionKey(bytesRepeat(0x22, sessionPubKeyLen))
require.ErrorIs(t, err, ErrSessionKeyNotKnown)
}
func TestAuthorizer_LookupSessionKey_WrongLength(t *testing.T) {
a := NewAuthorizer()
_, err := a.LookupSessionKey([]byte("short"))
require.Error(t, err)
}
func TestAuthorizer_LookupSessionKey_UpdateClears(t *testing.T) {
pub := bytesRepeat(0x33, sessionPubKeyLen)
userHash, err := sshauth.HashUserID("alice")
require.NoError(t, err)
a := NewAuthorizer()
a.Update(&Config{SessionPubKeys: []SessionPubKey{{PubKey: pub, UserIDHash: userHash}}})
if _, err := a.LookupSessionKey(pub); err != nil {
t.Fatalf("setup lookup: %v", err)
}
a.Update(&Config{})
if _, err := a.LookupSessionKey(pub); !errors.Is(err, ErrSessionKeyNotKnown) {
t.Fatalf("expected ErrSessionKeyNotKnown, got %v", err)
}
}
func bytesRepeat(b byte, n int) []byte {
out := make([]byte, n)
for i := range out {
out[i] = b
}
return out
}


@@ -28,10 +28,10 @@ import (
"github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
"github.com/netbirdio/netbird/client/ssh/server"
"github.com/netbirdio/netbird/client/ssh/testutil"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
sshauth "github.com/netbirdio/netbird/shared/sessionauth"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)


@@ -23,11 +23,11 @@ import (
"github.com/stretchr/testify/require"
nbssh "github.com/netbirdio/netbird/client/ssh"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
"github.com/netbirdio/netbird/client/ssh/client"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/client/ssh/testutil"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
sshauth "github.com/netbirdio/netbird/shared/sessionauth"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)


@@ -23,10 +23,10 @@ import (
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/auth/jwt"
sshauth "github.com/netbirdio/netbird/shared/sessionauth"
"github.com/netbirdio/netbird/util/netrelay"
"github.com/netbirdio/netbird/version"
)
@@ -197,14 +197,6 @@ type Config struct {
// HostKey is the SSH server host key in PEM format
HostKeyPEM []byte
// NetstackNet, when non-nil, makes the SSH server listen via the
// supplied userspace network stack instead of an OS socket.
NetstackNet *netstack.Net
// NetworkValidation, when non-zero, restricts inbound connections to
// peers inside the NetBird overlay defined by this WireGuard address.
NetworkValidation wgaddr.Address
}
// SessionInfo contains information about an active SSH session
@@ -216,15 +208,12 @@ type SessionInfo struct {
PortForwards []string
}
// New creates an SSH server instance from the supplied Config. Fields are
// read once at construction; mutating Config afterwards has no effect.
// JWT == nil disables JWT authentication.
// New creates an SSH server instance from the provided Config (host key and optional JWT configuration).
// If config.JWT is nil, JWT authentication is disabled.
func New(config *Config) *Server {
s := &Server{
mu: sync.RWMutex{},
hostKeyPEM: config.HostKeyPEM,
netstackNet: config.NetstackNet,
wgAddress: config.NetworkValidation,
sessions: make(map[sessionKey]*sessionState),
pendingAuthJWT: make(map[authKey]string),
remoteForwardListeners: make(map[forwardKey]net.Listener),
@@ -445,6 +434,20 @@ func (s *Server) buildSessionInfo(state *sessionState) SessionInfo {
return info
}
// SetNetstackNet sets the netstack network for userspace networking
func (s *Server) SetNetstackNet(net *netstack.Net) {
s.mu.Lock()
defer s.mu.Unlock()
s.netstackNet = net
}
// SetNetworkValidation configures network-based connection filtering
func (s *Server) SetNetworkValidation(addr wgaddr.Address) {
s.mu.Lock()
defer s.mu.Unlock()
s.wgAddress = addr
}
// UpdateSSHAuth updates the SSH fine-grained access control configuration
// This should be called when network map updates include new SSH auth configuration
func (s *Server) UpdateSSHAuth(config *sshauth.Config) {


@@ -131,19 +131,6 @@ type SSHServerStateOutput struct {
Sessions []SSHSessionOutput `json:"sessions" yaml:"sessions"`
}
type VNCSessionOutput struct {
RemoteAddress string `json:"remoteAddress" yaml:"remoteAddress"`
Mode string `json:"mode" yaml:"mode"`
Username string `json:"username,omitempty" yaml:"username,omitempty"`
UserID string `json:"userID,omitempty" yaml:"userID,omitempty"`
Initiator string `json:"initiator,omitempty" yaml:"initiator,omitempty"`
}
type VNCServerStateOutput struct {
Enabled bool `json:"enabled" yaml:"enabled"`
Sessions []VNCSessionOutput `json:"sessions" yaml:"sessions"`
}
type OutputOverview struct {
Peers PeersStateOutput `json:"peers" yaml:"peers"`
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
@@ -167,7 +154,6 @@ type OutputOverview struct {
LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"`
ProfileName string `json:"profileName" yaml:"profileName"`
SSHServerState SSHServerStateOutput `json:"sshServer" yaml:"sshServer"`
VNCServerState VNCServerStateOutput `json:"vncServer" yaml:"vncServer"`
}
// ConvertToStatusOutputOverview converts protobuf status to the output overview.
@@ -188,7 +174,6 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, opts ConvertO
relayOverview := mapRelays(pbFullStatus.GetRelays())
sshServerOverview := mapSSHServer(pbFullStatus.GetSshServerState())
vncServerOverview := mapVNCServer(pbFullStatus.GetVncServerState())
peersOverview := mapPeers(pbFullStatus.GetPeers(), opts.StatusFilter, opts.PrefixNamesFilter, opts.PrefixNamesFilterMap, opts.IPsFilter, opts.ConnectionTypeFilter)
overview := OutputOverview{
@@ -214,7 +199,6 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, opts ConvertO
LazyConnectionEnabled: pbFullStatus.GetLazyConnectionEnabled(),
ProfileName: opts.ProfileName,
SSHServerState: sshServerOverview,
VNCServerState: vncServerOverview,
}
if opts.Anonymize {
@@ -289,26 +273,6 @@ func mapSSHServer(sshServerState *proto.SSHServerState) SSHServerStateOutput {
}
}
func mapVNCServer(state *proto.VNCServerState) VNCServerStateOutput {
if state == nil {
return VNCServerStateOutput{Sessions: []VNCSessionOutput{}}
}
sessions := make([]VNCSessionOutput, 0, len(state.GetSessions()))
for _, sess := range state.GetSessions() {
sessions = append(sessions, VNCSessionOutput{
RemoteAddress: sess.GetRemoteAddress(),
Mode: sess.GetMode(),
Username: sess.GetUsername(),
UserID: sess.GetUserID(),
Initiator: sess.GetInitiator(),
})
}
return VNCServerStateOutput{
Enabled: state.GetEnabled(),
Sessions: sessions,
}
}
func mapPeers(
peers []*proto.PeerState,
statusFilter string,
@@ -571,26 +535,6 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
}
}
vncServerStatus := "Disabled"
if o.VNCServerState.Enabled {
vncSessionCount := len(o.VNCServerState.Sessions)
if vncSessionCount > 0 {
sessionWord := "session"
if vncSessionCount > 1 {
sessionWord = "sessions"
}
vncServerStatus = fmt.Sprintf("Enabled (%d active %s)", vncSessionCount, sessionWord)
} else {
vncServerStatus = "Enabled"
}
if showSSHSessions && vncSessionCount > 0 {
for _, sess := range o.VNCServerState.Sessions {
vncServerStatus += "\n " + formatVNCSessionLine(sess)
}
}
}
peersCountString := fmt.Sprintf("%d/%d Connected", o.Peers.Connected, o.Peers.Total)
var forwardingRulesString string
@@ -637,7 +581,6 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
"Quantum resistance: %s\n"+
"Lazy connection: %s\n"+
"SSH Server: %s\n"+
"VNC Server: %s\n"+
"Networks: %s\n"+
"%s"+
"Peers count: %s\n",
@@ -657,7 +600,6 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
rosenpassEnabledStatus,
lazyConnectionEnabledStatus,
sshServerStatus,
vncServerStatus,
networks,
forwardingRulesString,
peersCountString,
@@ -1017,26 +959,6 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *PeerStateDetailOutput) {
}
}
// formatVNCSessionLine renders a single VNC session row for the detailed
// status output. The leading slot identifies the initiator (display name
// when known, hashed UserID otherwise); the post-arrow slot is the OS
// user the session targets and is omitted in attach mode where the
// destination is the current console user (unknown to the daemon).
func formatVNCSessionLine(sess VNCSessionOutput) string {
who := sess.Initiator
if who == "" {
who = sess.UserID
}
prefix := sess.RemoteAddress
if who != "" {
prefix = fmt.Sprintf("%s@%s", who, sess.RemoteAddress)
}
if sess.Username != "" {
return fmt.Sprintf("[%s -> %s] mode=%s", prefix, sess.Username, sess.Mode)
}
return fmt.Sprintf("[%s] mode=%s", prefix, sess.Mode)
}
func anonymizeOverview(a *anonymize.Anonymizer, overview *OutputOverview) {
for i, peer := range overview.Peers.Details {
peer := peer
@@ -1057,19 +979,6 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *OutputOverview) {
overview.Relays.Details[i] = detail
}
anonymizeNSServerGroups(a, overview)
for i, route := range overview.Networks {
overview.Networks[i] = a.AnonymizeRoute(route)
}
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
anonymizeEvents(a, overview)
anonymizeServerSessions(a, overview)
}
func anonymizeNSServerGroups(a *anonymize.Anonymizer, overview *OutputOverview) {
for i, nsGroup := range overview.NSServerGroups {
for j, domain := range nsGroup.Domains {
overview.NSServerGroups[i].Domains[j] = a.AnonymizeDomain(domain)
@@ -1081,9 +990,13 @@ func anonymizeNSServerGroups(a *anonymize.Anonymizer, overview *OutputOverview)
}
}
}
}
func anonymizeEvents(a *anonymize.Anonymizer, overview *OutputOverview) {
for i, route := range overview.Networks {
overview.Networks[i] = a.AnonymizeRoute(route)
}
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
for i, event := range overview.Events {
overview.Events[i].Message = a.AnonymizeString(event.Message)
overview.Events[i].UserMessage = a.AnonymizeString(event.UserMessage)
@@ -1092,24 +1005,13 @@ func anonymizeEvents(a *anonymize.Anonymizer, overview *OutputOverview) {
event.Metadata[k] = a.AnonymizeString(v)
}
}
}
func anonymizeRemoteAddress(a *anonymize.Anonymizer, addr string) string {
if host, port, err := net.SplitHostPort(addr); err == nil {
return fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port)
}
return a.AnonymizeIPString(addr)
}
func anonymizeServerSessions(a *anonymize.Anonymizer, overview *OutputOverview) {
for i, session := range overview.SSHServerState.Sessions {
overview.SSHServerState.Sessions[i].RemoteAddress = anonymizeRemoteAddress(a, session.RemoteAddress)
if host, port, err := net.SplitHostPort(session.RemoteAddress); err == nil {
overview.SSHServerState.Sessions[i].RemoteAddress = fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port)
} else {
overview.SSHServerState.Sessions[i].RemoteAddress = a.AnonymizeIPString(session.RemoteAddress)
}
overview.SSHServerState.Sessions[i].Command = a.AnonymizeString(session.Command)
}
for i, sess := range overview.VNCServerState.Sessions {
overview.VNCServerState.Sessions[i].RemoteAddress = anonymizeRemoteAddress(a, sess.RemoteAddress)
overview.VNCServerState.Sessions[i].Username = a.AnonymizeString(sess.Username)
overview.VNCServerState.Sessions[i].UserID = a.AnonymizeString(sess.UserID)
overview.VNCServerState.Sessions[i].Initiator = a.AnonymizeString(sess.Initiator)
}
}
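
The address handling in the hunk above can be sketched in isolation. This is a minimal sketch, not the project's code: `mask` is a hypothetical stand-in for the anonymizer's `AnonymizeIPString`, and the sample addresses are invented. It shows the port being kept when the input parses as `host:port`, and the whole string being masked when it does not.

```go
package main

import (
	"fmt"
	"net"
)

// anonymizeRemoteAddress masks only the host part of a "host:port" string
// and keeps the port readable. Input without a port is masked as a whole.
// mask is an assumed placeholder for the real anonymizer method.
func anonymizeRemoteAddress(mask func(string) string, addr string) string {
	if host, port, err := net.SplitHostPort(addr); err == nil {
		return fmt.Sprintf("%s:%s", mask(host), port)
	}
	return mask(addr)
}

// fixedMask is a trivial illustrative mask that hides every host the same way.
func fixedMask(string) string { return "x.x.x.x" }

func main() {
	fmt.Println(anonymizeRemoteAddress(fixedMask, "10.0.0.5:22"))
	fmt.Println(anonymizeRemoteAddress(fixedMask, "10.0.0.5"))
}
```

Sharing one helper between the SSH and VNC session loops keeps the two call sites from drifting apart, which is the difference between the helper form and the inlined if/else form visible in the hunk.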


@@ -242,10 +242,6 @@ var overview = OutputOverview{
Enabled: false,
Sessions: []SSHSessionOutput{},
},
VNCServerState: VNCServerStateOutput{
Enabled: false,
Sessions: []VNCSessionOutput{},
},
}
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
@@ -411,10 +407,6 @@ func TestParsingToJSON(t *testing.T) {
"sshServer":{
"enabled":false,
"sessions":[]
},
"vncServer":{
"enabled":false,
"sessions":[]
}
}`
// @formatter:on
@@ -525,9 +517,6 @@ profileName: ""
sshServer:
enabled: false
sessions: []
vncServer:
enabled: false
sessions: []
`
assert.Equal(t, expectedYAML, yaml)
@@ -598,7 +587,6 @@ Wireguard port: %d
Quantum resistance: false
Lazy connection: false
SSH Server: Disabled
VNC Server: Disabled
Networks: 10.10.0.0/24
Peers count: 2/2 Connected
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion, overview.WgPort)
@@ -625,7 +613,6 @@ Wireguard port: 51820
Quantum resistance: false
Lazy connection: false
SSH Server: Disabled
VNC Server: Disabled
Networks: 10.10.0.0/24
Peers count: 2/2 Connected
`


@@ -62,7 +62,6 @@ type Info struct {
RosenpassEnabled bool
RosenpassPermissive bool
ServerSSHAllowed bool
ServerVNCAllowed bool
DisableClientRoutes bool
DisableServerRoutes bool
@@ -84,7 +83,6 @@ type Info struct {
func (i *Info) SetFlags(
rosenpassEnabled, rosenpassPermissive bool,
serverSSHAllowed *bool,
serverVNCAllowed *bool,
disableClientRoutes, disableServerRoutes,
disableDNS, disableFirewall, blockLANAccess, blockInbound, disableIPv6, lazyConnectionEnabled bool,
enableSSHRoot, enableSSHSFTP, enableSSHLocalPortForwarding, enableSSHRemotePortForwarding *bool,
@@ -95,9 +93,6 @@ func (i *Info) SetFlags(
if serverSSHAllowed != nil {
i.ServerSSHAllowed = *serverSSHAllowed
}
if serverVNCAllowed != nil {
i.ServerVNCAllowed = *serverVNCAllowed
}
i.DisableClientRoutes = disableClientRoutes
i.DisableServerRoutes = disableServerRoutes


@@ -1,259 +0,0 @@
//go:build !(linux && 386)
package main
import (
"context"
"errors"
"fmt"
"os"
"os/exec"
"strings"
"time"
"fyne.io/fyne/v2"
"fyne.io/fyne/v2/container"
"fyne.io/fyne/v2/widget"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/approval"
"github.com/netbirdio/netbird/client/proto"
)
// Approval metadata that is remote-peer or dashboard controlled is passed to
// the forked netbird-ui via environment variables rather than argv, so it is
// not exposed to other local users through ps.
const (
envApprovalInitiator = "NB_APPROVAL_INITIATOR"
envApprovalPeerName = "NB_APPROVAL_PEER_NAME"
envApprovalSourceIP = "NB_APPROVAL_SOURCE_IP"
envApprovalUsername = "NB_APPROVAL_USERNAME"
envApprovalKeyFingerprint = "NB_APPROVAL_KEY_FINGERPRINT"
envApprovalSubject = "NB_APPROVAL_SUBJECT"
)
// handleApprovalEvent forks a netbird-ui child process to render the
// dialog on its own fyne main loop. Top-level windows opened from a
// background goroutine of the tray process don't render reliably on
// Linux/GTK, so the rest of the UI (settings, login URL, update) uses
// the same fork pattern.
func (s *serviceClient) handleApprovalEvent(ev *proto.SystemEvent) {
if ev == nil || ev.Category != proto.SystemEvent_APPROVAL {
return
}
requestID := ev.Metadata["request_id"]
if requestID == "" {
log.Warnf("approval event missing request_id: %v", ev.Metadata)
return
}
// Only the request id, kind, and deadline stay on argv: they are
// daemon-issued and non-sensitive. The remote-influenced fields go
// through the child's environment.
args := []string{
"--approval-request-id=" + requestID,
"--approval-kind=" + ev.Metadata["kind"],
"--approval-expires-at=" + ev.Metadata["expires_at"],
}
env := append(os.Environ(),
envApprovalInitiator+"="+ev.Metadata["initiator"],
envApprovalPeerName+"="+ev.Metadata["peer_name"],
envApprovalSourceIP+"="+ev.Metadata["source_ip"],
envApprovalUsername+"="+ev.Metadata["username"],
envApprovalKeyFingerprint+"="+ev.Metadata["peer_pubkey"],
envApprovalSubject+"="+ev.UserMessage,
)
go s.runApprovalCommand(s.ctx, env, args)
}
// runApprovalCommand forks netbird-ui to render the approval dialog,
// inheriting the parent environment plus the approval-specific variables. It
// mirrors runSelfCommand but sets cmd.Env so the sensitive metadata never
// appears on the child's argv.
func (s *serviceClient) runApprovalCommand(ctx context.Context, env, args []string) {
proc, err := os.Executable()
if err != nil {
log.Errorf("get executable path: %v", err)
return
}
cmdArgs := append([]string{"--approval=true", "--daemon-addr=" + s.addr}, args...)
cmd := exec.CommandContext(ctx, proc, cmdArgs...)
cmd.Env = env
if out := s.attachOutput(cmd); out != nil {
defer func() {
if err := out.Close(); err != nil {
log.Errorf("close log file %s: %v", s.logFile, err)
}
}()
}
log.Printf("running approval command: %s", cmd.String())
if err := cmd.Run(); err != nil {
var exitErr *exec.ExitError
if errors.As(err, &exitErr) {
log.Printf("approval command failed with exit code %d", exitErr.ExitCode())
}
}
}
// showApprovalUI runs the dialog on the forked process's fyne main loop
// and forwards the user's decision to the daemon via RespondApproval.
func (s *serviceClient) showApprovalUI(req approvalRequest) {
w := s.app.NewWindow(approvalTitle(req.kind))
w.Resize(fyne.NewSize(480, 260))
w.CenterOnScreen()
w.RequestFocus()
var rows []string
if req.initiator != "" {
// The display name comes from the management dashboard and is
// not cryptographically asserted by the connecting client. The
// key fingerprint that follows IS: it's the Noise_IK static
// public key the client just proved possession of. Show both
// so the user can sanity-check that "Alice" is really the
// Alice they trust.
rows = append(rows, "From user: "+req.initiator)
}
if fp := approval.ShortKeyFingerprint(req.keyFingerprint); fp != "" {
rows = append(rows, "Key fp: "+fp)
}
if req.peerName != "" {
rows = append(rows, "Via peer: "+req.peerName)
}
if req.sourceIP != "" && req.sourceIP != req.peerName {
rows = append(rows, "Source IP: "+req.sourceIP)
}
if req.username != "" {
rows = append(rows, "OS user: "+req.username)
}
if len(rows) == 0 {
rows = []string{"Remote: " + req.displayPeer()}
}
body := strings.Join(rows, "\n")
bodyLabel := widget.NewLabel(body)
bodyLabel.Wrapping = fyne.TextWrapWord
countdown := widget.NewLabel("")
deadline := req.deadline()
updateCountdown := func() {
remaining := time.Until(deadline).Round(time.Second)
if remaining < 0 {
remaining = 0
}
countdown.SetText(fmt.Sprintf("Auto-deny in %s", remaining))
}
updateCountdown()
type outcome struct {
accept bool
viewOnly bool
}
decided := make(chan outcome, 1)
decide := func(o outcome) {
select {
case decided <- o:
default:
}
}
allow := widget.NewButton("Allow", func() { decide(outcome{accept: true}) })
allow.Importance = widget.HighImportance
allowView := widget.NewButton("Allow (view only)", func() { decide(outcome{accept: true, viewOnly: true}) })
deny := widget.NewButton("Deny", func() { decide(outcome{accept: false}) })
header := widget.NewLabelWithStyle(req.subject, fyne.TextAlignLeading, fyne.TextStyle{Bold: true})
buttonRow := container.NewGridWithColumns(3, allow, allowView, deny)
info := container.NewVBox(header, widget.NewSeparator(), bodyLabel, widget.NewSeparator(), countdown)
w.SetContent(container.NewPadded(container.NewBorder(nil, buttonRow, nil, nil, info)))
w.SetCloseIntercept(func() { decide(outcome{}) })
go func() {
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
for range ticker.C {
if time.Until(deadline) <= 0 {
decide(outcome{})
return
}
fyne.Do(updateCountdown)
}
}()
go func() {
o := <-decided
s.sendApprovalResponse(req.requestID, o.accept, o.viewOnly)
fyne.Do(func() {
w.Close()
s.app.Quit()
})
}()
w.Show()
}
func (s *serviceClient) sendApprovalResponse(requestID string, accept, viewOnly bool) {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
log.Warnf("approval response: get daemon client: %v", err)
return
}
ctx, cancel := context.WithTimeout(s.ctx, defaultFailTimeout)
defer cancel()
if _, err := conn.RespondApproval(ctx, &proto.RespondApprovalRequest{
RequestId: requestID,
Accept: accept,
ViewOnly: viewOnly,
}); err != nil {
log.Warnf("approval response: %v", err)
}
}
// approvalRequest is the parsed --approval-* CLI args that the forked
// dialog process consumes.
type approvalRequest struct {
requestID string
kind string
initiator string
peerName string
sourceIP string
username string
subject string
expiresAt string
keyFingerprint string
}
func (r approvalRequest) displayPeer() string {
switch {
case r.initiator != "":
return r.initiator
case r.peerName != "":
return r.peerName
case r.sourceIP != "":
return r.sourceIP
default:
return "unknown peer"
}
}
// deadline returns the wall-clock auto-deny moment. Falls back to a short
// local window when the daemon's expires_at is missing/unparsable, so a
// stale value never leaves the dialog open indefinitely.
func (r approvalRequest) deadline() time.Time {
if t, err := time.Parse(time.RFC3339, r.expiresAt); err == nil {
return t
}
return time.Now().Add(13 * time.Second)
}
func approvalTitle(kind string) string {
switch kind {
case "vnc":
return "Allow VNC Connection?"
case "ssh":
return "Allow SSH Connection?"
default:
return "Allow Incoming Connection?"
}
}


@@ -38,6 +38,7 @@ import (
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/mdm"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/ui/desktop"
"github.com/netbirdio/netbird/client/ui/event"
@@ -56,8 +57,22 @@ const (
const (
censoredPreSharedKey = "**********"
maxSSHJWTCacheTTL = 86_400 // 24 hours in seconds
// mdmFieldSuffix is appended to plain-text Entry widgets in the
// advanced Settings window when the underlying field is enforced
// by MDM, so the user sees the lock indicator inline next to the
// value. Stripped before any read site that feeds the value back
// into a SetConfig request (saveSettings / parseNumericSettings).
mdmFieldSuffix = " (MDM)"
)
// main is the entry point for the UI tray/client binary. Parses CLI
// flags, initialises logging, builds the Fyne application and tray
// icons, and constructs the service client (which may open a
// requested UI window). When a window-mode flag is set the Fyne event
// loop runs and main returns; otherwise main enforces single-instance
// behaviour (signalling an existing instance to show its window when
// present), sets up signal handling + default fonts, and runs the
// system tray loop.
func main() {
flags := parseFlags()
@@ -97,25 +112,13 @@ func main() {
showQuickActions: flags.showQuickActions,
showUpdate: flags.showUpdate,
showUpdateVersion: flags.showUpdateVersion,
showApproval: flags.showApproval,
approvalRequest: approvalRequest{
requestID: flags.approvalRequestID,
kind: flags.approvalKind,
initiator: os.Getenv(envApprovalInitiator),
peerName: os.Getenv(envApprovalPeerName),
sourceIP: os.Getenv(envApprovalSourceIP),
username: os.Getenv(envApprovalUsername),
subject: os.Getenv(envApprovalSubject),
expiresAt: flags.approvalExpiresAt,
keyFingerprint: os.Getenv(envApprovalKeyFingerprint),
},
})
// Watch for theme/settings changes to update the icon.
go watchSettingsChanges(a, client)
// Run in window mode if any UI flag was set.
if flags.showSettings || flags.showNetworks || flags.showDebug || flags.showLoginURL || flags.showProfiles || flags.showQuickActions || flags.showUpdate || flags.showApproval {
if flags.showSettings || flags.showNetworks || flags.showDebug || flags.showLoginURL || flags.showProfiles || flags.showQuickActions || flags.showUpdate {
a.Run()
return
}
@@ -152,11 +155,6 @@ type cliFlags struct {
saveLogsInFile bool
showUpdate bool
showUpdateVersion string
showApproval bool
approvalRequestID string
approvalKind string
approvalExpiresAt string
}
// parseFlags reads and returns all needed command-line flags.
@@ -178,10 +176,6 @@ func parseFlags() *cliFlags {
flag.BoolVar(&flags.showLoginURL, "login-url", false, "show login URL in a popup window")
flag.BoolVar(&flags.showUpdate, "update", false, "show update progress window")
flag.StringVar(&flags.showUpdateVersion, "update-version", "", "version to update to")
flag.BoolVar(&flags.showApproval, "approval", false, "show inbound-connection approval prompt window")
flag.StringVar(&flags.approvalRequestID, "approval-request-id", "", "approval prompt: daemon-issued request id")
flag.StringVar(&flags.approvalKind, "approval-kind", "", "approval prompt: subsystem kind (vnc, ssh, ...)")
flag.StringVar(&flags.approvalExpiresAt, "approval-expires-at", "", "approval prompt: RFC3339 deadline at which the daemon auto-denies")
flag.Parse()
return &flags
}
@@ -270,7 +264,6 @@ type serviceClient struct {
mQuit *systray.MenuItem
mNetworks *systray.MenuItem
mAllowSSH *systray.MenuItem
mAllowVNC *systray.MenuItem
mAutoConnect *systray.MenuItem
mEnableRosenpass *systray.MenuItem
mLazyConnEnabled *systray.MenuItem
@@ -309,8 +302,6 @@ type serviceClient struct {
sEnableSSHRemotePortForward *widget.Check
sDisableSSHAuth *widget.Check
iSSHJWTCacheTTL *widget.Entry
sServerVNCAllowed *widget.Check
sDisableVNCApproval *widget.Check
// observable settings over corresponding iMngURL and iPreSharedKey values.
managementURL string
@@ -332,8 +323,6 @@ type serviceClient struct {
enableSSHRemotePortForward bool
disableSSHAuth bool
sshJWTCacheTTL int
serverVNCAllowed bool
disableVNCApproval bool
connected bool
daemonVersion string
@@ -341,9 +330,13 @@ type serviceClient struct {
isUpdateIconActive bool
isEnforcedUpdate bool
lastNotifiedVersion string
settingsEnabled bool
profilesEnabled bool
networksEnabled bool
// networksMenuEnabled caches the last applied enabled-state of the
// mNetworks + mExitNode submenu items. Combines the networks feature
// flag (!features.DisableNetworks) AND s.connected: both must hold for
// the menus to be active.
// Zero value (false) matches the Disable() call at AddMenuItem time.
networksMenuEnabled bool
showNetworks bool
wNetworks fyne.Window
wProfiles fyne.Window
@@ -362,6 +355,13 @@ type serviceClient struct {
updateContextCancel context.CancelFunc
connectCancel context.CancelFunc
// mdmManagedFields caches the names of MDM-enforced policy keys
// surfaced by the daemon in GetConfigResponse. Each refresh of
// daemon config (loadSettings, getSrvConfig, config_changed event)
// updates this set and re-applies the lock/badge to the affected
// menu items and settings-form widgets.
mdmManagedFields map[string]bool
}
type menuHandler struct {
@@ -381,8 +381,6 @@ type newServiceClientArgs struct {
showQuickActions bool
showUpdate bool
showUpdateVersion string
showApproval bool
approvalRequest approvalRequest
}
// newServiceClient instance constructor
@@ -423,8 +421,6 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
s.showQuickActionsUI()
case args.showUpdate:
s.showUpdateProgress(ctx, args.showUpdateVersion)
case args.showApproval:
s.showApprovalUI(args.approvalRequest)
}
return s
@@ -471,15 +467,12 @@ func (s *serviceClient) updateIcon() {
}
func (s *serviceClient) showSettingsUI() {
// Check if update settings are disabled by daemon
features, err := s.getFeatures()
if err != nil {
log.Errorf("failed to get features from daemon: %v", err)
// Continue with default behavior if features can't be retrieved
} else if features != nil && features.DisableUpdateSettings {
log.Warn("Update settings are disabled by daemon")
return
}
// DisableUpdateSettings no longer gates the window from opening:
// the daemon blocks every actual mutation at SetConfig / Login,
// so the window is safe to show as a read-only view. The previous
// early-return also blocked Advanced Settings whenever update
// editing was off, which conflated two distinct kill switches
// (see comment in checkAndUpdateFeatures).
// add settings window UI elements.
s.wSettings = s.app.NewWindow("NetBird Settings")
@@ -508,8 +501,6 @@ func (s *serviceClient) showSettingsUI() {
s.sEnableSSHRemotePortForward = widget.NewCheck("Enable SSH Remote Port Forwarding", nil)
s.sDisableSSHAuth = widget.NewCheck("Disable SSH Authentication", nil)
s.iSSHJWTCacheTTL = widget.NewEntry()
s.sServerVNCAllowed = widget.NewCheck("Allow embedded VNC server on this peer", nil)
s.sDisableVNCApproval = widget.NewCheck("Skip per-connection approval prompt for VNC", nil)
s.wSettings.SetContent(s.getSettingsForm())
s.wSettings.Resize(fyne.NewSize(600, 400))
@@ -564,7 +555,7 @@ func (s *serviceClient) saveSettings() {
return
}
iMngURL := strings.TrimSpace(s.iMngURL.Text)
iMngURL := strings.TrimSpace(strings.TrimSuffix(s.iMngURL.Text, mdmFieldSuffix))
if s.hasSettingsChanged(iMngURL, port, mtu) {
if err := s.applySettingsChanges(iMngURL, port, mtu); err != nil {
@@ -586,7 +577,7 @@ func (s *serviceClient) validateSettings() error {
}
func (s *serviceClient) parseNumericSettings() (int64, int64, error) {
port, err := strconv.ParseInt(s.iInterfacePort.Text, 10, 64)
port, err := strconv.ParseInt(strings.TrimSpace(strings.TrimSuffix(s.iInterfacePort.Text, mdmFieldSuffix)), 10, 64)
if err != nil {
return 0, 0, errors.New("invalid interface port")
}
@@ -622,8 +613,7 @@ func (s *serviceClient) hasSettingsChanged(iMngURL string, port, mtu int64) bool
s.disableServerRoutes != s.sDisableServerRoutes.Checked ||
s.disableIPv6 != s.sDisableIPv6.Checked ||
s.blockLANAccess != s.sBlockLANAccess.Checked ||
s.hasSSHChanges() ||
s.hasVNCChanges()
s.hasSSHChanges()
}
func (s *serviceClient) applySettingsChanges(iMngURL string, port, mtu int64) error {
@@ -682,8 +672,6 @@ func (s *serviceClient) buildSetConfigRequest(iMngURL string, port, mtu int64) (
req.EnableSSHLocalPortForwarding = &s.sEnableSSHLocalPortForward.Checked
req.EnableSSHRemotePortForwarding = &s.sEnableSSHRemotePortForward.Checked
req.DisableSSHAuth = &s.sDisableSSHAuth.Checked
req.ServerVNCAllowed = &s.sServerVNCAllowed.Checked
req.DisableVNCApproval = &s.sDisableVNCApproval.Checked
sshJWTCacheTTLText := strings.TrimSpace(s.iSSHJWTCacheTTL.Text)
if sshJWTCacheTTLText != "" {
@@ -698,7 +686,15 @@ func (s *serviceClient) buildSetConfigRequest(iMngURL string, port, mtu int64) (
req.SshJWTCacheTTL = &sshJWTCacheTTL32
}
if s.iPreSharedKey.Text != censoredPreSharedKey {
// Only attach the PSK when the user actually typed something:
// - "" means the field was left untouched (we deliberately render
// an empty Text + placeholder hint to avoid leaking the daemon's
// "**********" redaction through the password reveal toggle);
// sending an empty pointer would tell the daemon to clear / overwrite
// the on-disk or MDM-enforced PSK, which then trips the MDM
// conflict gate when PSK is policy-managed.
// - "**********" is the redacted echo (legacy non-MDM path); also a no-op.
if s.iPreSharedKey.Text != "" && s.iPreSharedKey.Text != censoredPreSharedKey {
req.OptionalPreSharedKey = &s.iPreSharedKey.Text
}
@@ -744,12 +740,10 @@ func (s *serviceClient) getSettingsForm() fyne.CanvasObject {
connectionForm := s.getConnectionForm()
networkForm := s.getNetworkForm()
sshForm := s.getSSHForm()
vncForm := s.getVNCForm()
tabs := container.NewAppTabs(
container.NewTabItem("Connection", connectionForm),
container.NewTabItem("Network", networkForm),
container.NewTabItem("SSH", sshForm),
container.NewTabItem("VNC", vncForm),
)
saveButton := widget.NewButtonWithIcon("Save", theme.ConfirmIcon(), s.saveSettings)
saveButton.Importance = widget.HighImportance
@@ -790,15 +784,6 @@ func (s *serviceClient) getSSHForm() *widget.Form {
}
}
func (s *serviceClient) getVNCForm() *widget.Form {
return &widget.Form{
Items: []*widget.FormItem{
{Text: "Allow VNC Server", Widget: s.sServerVNCAllowed},
{Text: "Disable Connection Approval Prompt", Widget: s.sDisableVNCApproval},
},
}
}
func (s *serviceClient) hasSSHChanges() bool {
currentSSHJWTCacheTTL := s.sshJWTCacheTTL
if text := strings.TrimSpace(s.iSSHJWTCacheTTL.Text); text != "" {
@@ -817,11 +802,6 @@ func (s *serviceClient) hasSSHChanges() bool {
s.sshJWTCacheTTL != currentSSHJWTCacheTTL
}
func (s *serviceClient) hasVNCChanges() bool {
return s.serverVNCAllowed != s.sServerVNCAllowed.Checked ||
s.disableVNCApproval != s.sDisableVNCApproval.Checked
}
func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginResponse, error) {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
@@ -1087,6 +1067,13 @@ func (s *serviceClient) onTrayReady() {
}
s.mProfile = newProfileMenu(*newProfileMenuArgs)
// Seed the transition cache to match the actual default menu
// state (visible / enabled). Without this, the first
// checkAndUpdateFeatures tick that observes DisableProfiles=true
// is a no-op (cache zero-value == desired-false) and the menu
// never gets hidden — symptom: MDM enforces the kill switch but
// the profile menu stays clickable.
s.profilesEnabled = true
systray.AddSeparator()
s.mUp = systray.AddMenuItem("Connect", "Connect")
@@ -1096,7 +1083,6 @@ func (s *serviceClient) onTrayReady() {
s.mSettings = systray.AddMenuItem("Settings", disabledMenuDescr)
s.mAllowSSH = s.mSettings.AddSubMenuItemCheckbox("Allow SSH", allowSSHMenuDescr, false)
s.mAllowVNC = s.mSettings.AddSubMenuItemCheckbox("Allow VNC", allowVNCMenuDescr, false)
s.mAutoConnect = s.mSettings.AddSubMenuItemCheckbox("Connect on Startup", autoConnectMenuDescr, false)
s.mEnableRosenpass = s.mSettings.AddSubMenuItemCheckbox("Enable Quantum-Resistance", quantumResistanceMenuDescr, false)
s.mLazyConnEnabled = s.mSettings.AddSubMenuItemCheckbox("Enable Lazy Connections", lazyConnMenuDescr, false)
@@ -1107,18 +1093,18 @@ func (s *serviceClient) onTrayReady() {
s.mCreateDebugBundle = s.mSettings.AddSubMenuItem("Create Debug Bundle", debugBundleMenuDescr)
s.loadSettings()
// Disable settings menu if update settings are disabled by daemon
// Disable profile menu if profiles are disabled by daemon.
// DisableUpdateSettings is enforced at the daemon's SetConfig /
// Login gates, not by hiding the UI — so the Settings menu (and
// its Advanced Settings submenu, which has its own kill switch)
// stays visible and the user can still inspect current values.
features, err := s.getFeatures()
if err != nil {
log.Errorf("failed to get features from daemon: %v", err)
// Continue with default behavior if features can't be retrieved
} else {
if features != nil && features.DisableUpdateSettings {
s.setSettingsEnabled(false)
}
if features != nil && features.DisableProfiles {
s.mProfile.setEnabled(false)
}
} else if features != nil && features.DisableProfiles {
s.mProfile.setEnabled(false)
s.profilesEnabled = false
}
s.exitNodeMu.Lock()
@@ -1152,13 +1138,20 @@ func (s *serviceClient) onTrayReady() {
// update exit node menu in case service is already connected
go s.updateExitNodes()
// Features (DisableProfiles, DisableUpdateSettings, DisableNetworks,
// ...) only change in two ways: at service install time (CLI flag,
// static) and at MDM ticker diff time. The daemon already publishes
// a SystemEvent{type=config_changed} on every MDM-driven engine
// restart, so the UI no longer needs to poll GetFeatures every 2 s.
// A single fetch at startup covers the static CLI-flag case; the
// event handler below covers MDM transitions. updateStatus stays in
// the 2 s loop because connection / peer state genuinely change
// continuously and have no event yet.
s.checkAndUpdateFeatures()
go func() {
s.getSrvConfig()
time.Sleep(100 * time.Millisecond) // To prevent race condition caused by systray not being fully initialized and ignoring setIcon
for {
// Check features before status so menus respect disable flags before being enabled
s.checkAndUpdateFeatures()
err := s.updateStatus()
if err != nil {
log.Errorf("error while updating status: %v", err)
@@ -1170,7 +1163,6 @@ func (s *serviceClient) onTrayReady() {
s.eventManager = event.NewManager(s.notifier, s.addr)
s.eventManager.SetNotificationsEnabled(s.mNotifications.Checked())
s.eventManager.AddHandler(s.handleApprovalEvent)
s.eventManager.AddHandler(func(event *proto.SystemEvent) {
if event.Category == proto.SystemEvent_SYSTEM {
s.updateExitNodes()
@@ -1203,6 +1195,23 @@ func (s *serviceClient) onTrayReady() {
s.onUpdateAvailable(newVersion, enforced)
}
})
s.eventManager.AddHandler(func(event *proto.SystemEvent) {
// Daemon emits a config_changed event after every engine spawn
// (Server.Start, Server.Up, MDM ticker restart). Re-sync the
// tray submenu checkboxes from the fresh daemon-side config so
// the user does not have to restart the tray to see CLI- or
// MDM-driven changes.
if event.Category == proto.SystemEvent_SYSTEM && event.Metadata["type"] == "config_changed" {
log.Infof("config_changed event received (source=%s); refreshing settings + features", event.Metadata["source"])
s.loadSettings()
// MDM-driven feature kill switches (DisableProfiles /
// DisableUpdateSettings / DisableNetworks) ride the same
// config_changed signal because the daemon re-applies its
// MDM policy on every engine spawn. Pull them in here so
// the UI is up to date without a periodic GetFeatures poll.
s.checkAndUpdateFeatures()
}
})
go s.eventManager.Start(s.ctx)
go s.eventHandler.listen(s.ctx)
@@ -1266,18 +1275,6 @@ func (s *serviceClient) getSrvClient(timeout time.Duration) (proto.DaemonService
return s.conn, nil
}
// setSettingsEnabled enables or disables the settings menu based on the provided state
func (s *serviceClient) setSettingsEnabled(enabled bool) {
if s.mSettings != nil {
if enabled {
s.mSettings.Enable()
} else {
s.mSettings.Hide()
s.mSettings.SetTooltip("Settings are disabled by daemon")
}
}
}
// checkAndUpdateFeatures checks the current features and updates the UI accordingly
func (s *serviceClient) checkAndUpdateFeatures() {
features, err := s.getFeatures()
@@ -1289,12 +1286,11 @@ func (s *serviceClient) checkAndUpdateFeatures() {
s.updateIndicationLock.Lock()
defer s.updateIndicationLock.Unlock()
// Update settings menu based on current features
settingsEnabled := features == nil || !features.DisableUpdateSettings
if s.settingsEnabled != settingsEnabled {
s.settingsEnabled = settingsEnabled
s.setSettingsEnabled(settingsEnabled)
}
// DisableUpdateSettings is enforced server-side by the daemon gates
// on SetConfig + Login: any attempt to mutate config from UI or
// CLI is rejected at that layer. The UI deliberately keeps the
// Settings menu visible so the user can still inspect current
// values — read-only by virtue of the daemon refusing edits.
// Update profile menu based on current features
if s.mProfile != nil {
@@ -1305,14 +1301,23 @@ func (s *serviceClient) checkAndUpdateFeatures() {
}
}
// Update networks and exit node menus based on current features
// Update networks and exit node menus based on current features.
// `networksEnabled` is the bare feature flag (read elsewhere, e.g. at
// connection-status transitions). `networksMenuEnabled` is the
// transition-cached state actually applied to the menu items —
// it folds in the connection state so a Connected client with the
// kill switch off shows the menus active, and only flips on diff.
s.networksEnabled = features == nil || !features.DisableNetworks
if s.networksEnabled && s.connected {
s.mNetworks.Enable()
s.mExitNode.Enable()
} else {
s.mNetworks.Disable()
s.mExitNode.Disable()
desiredNetworksMenu := s.networksEnabled && s.connected
if desiredNetworksMenu != s.networksMenuEnabled {
s.networksMenuEnabled = desiredNetworksMenu
if desiredNetworksMenu {
s.mNetworks.Enable()
s.mExitNode.Enable()
} else {
s.mNetworks.Disable()
s.mExitNode.Disable()
}
}
}
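
The transition-cache pattern used for `networksMenuEnabled` (and seeded for `profilesEnabled` in `onTrayReady`) can be shown with a toy type. This is an illustrative sketch: `menu`, `apply`, and the `calls` counter are hypothetical names, with `calls` standing in for the real `Enable()`/`Disable()` toolkit calls.

```go
package main

import "fmt"

// menu caches the last applied state so the toolkit is only touched
// when the desired state actually changes.
type menu struct {
	enabled bool // last applied state; the zero value means "disabled"
	calls   int  // number of times the (simulated) toolkit call ran
}

// apply issues the simulated Enable/Disable call only on a state change.
func (m *menu) apply(desired bool) {
	if desired == m.enabled {
		return // no transition, nothing to do
	}
	m.enabled = desired
	m.calls++
}

func main() {
	// Cache starts at false. If the real widget starts out enabled, the
	// first apply(false) is skipped and the widget stays enabled.
	unseeded := &menu{}
	unseeded.apply(false)
	fmt.Println("unseeded calls:", unseeded.calls)

	// Seeding the cache to the widget's real default fixes that.
	seeded := &menu{enabled: true}
	seeded.apply(false)
	fmt.Println("seeded calls:", seeded.calls)
}
```

The unseeded case is the failure mode the `s.profilesEnabled = true` comment describes: the cache has to start at whatever state the widget really has, otherwise the first disable is silently dropped.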
@@ -1406,16 +1411,17 @@ func (s *serviceClient) getSrvConfig() {
if cfg.SSHJWTCacheTTL != nil {
s.sshJWTCacheTTL = *cfg.SSHJWTCacheTTL
}
if cfg.ServerVNCAllowed != nil {
s.serverVNCAllowed = *cfg.ServerVNCAllowed
}
if cfg.DisableVNCApproval != nil {
s.disableVNCApproval = *cfg.DisableVNCApproval
}
if s.showAdvancedSettings {
s.iMngURL.SetText(s.managementURL)
s.iPreSharedKey.SetText(cfg.PreSharedKey)
// PSK is rendered with an empty Text and a hint via the
// placeholder so the eye toggle never reveals literal asterisks
// (the daemon returns the "**********" sentinel — writing that
// into a PasswordEntry would surface the literal sentinel when
// the user unmasks the field). The placeholder communicates the
// configured / MDM-managed state without exposing any value.
s.iPreSharedKey.SetText("")
s.iPreSharedKey.SetPlaceHolder(preSharedKeyPlaceholder(srvCfg))
s.iInterfaceName.SetText(cfg.WgIface)
s.iInterfacePort.SetText(strconv.Itoa(cfg.WgPort))
if cfg.MTU != 0 {
@@ -1425,7 +1431,15 @@ func (s *serviceClient) getSrvConfig() {
s.iMTU.SetPlaceHolder(strconv.Itoa(int(iface.DefaultMTU)))
}
s.sRosenpassPermissive.SetChecked(cfg.RosenpassPermissive)
if !cfg.RosenpassEnabled {
// Re-baseline the enabled state on every refresh: when Rosenpass
// is on the checkbox is editable, when it's off the field is
// inert. Without an explicit Enable() here the control stays
// stuck disabled after a previous refresh (or an MDM unlock) had
// turned it off — applyMDMLocksToSettingsForm below adds the
// MDM lock on top of this baseline.
if cfg.RosenpassEnabled {
s.sRosenpassPermissive.Enable()
} else {
s.sRosenpassPermissive.Disable()
}
s.sNetworkMonitor.SetChecked(*cfg.NetworkMonitor)
@@ -1452,14 +1466,15 @@ func (s *serviceClient) getSrvConfig() {
if cfg.SSHJWTCacheTTL != nil {
s.iSSHJWTCacheTTL.SetText(strconv.Itoa(*cfg.SSHJWTCacheTTL))
}
if cfg.ServerVNCAllowed != nil {
s.sServerVNCAllowed.SetChecked(*cfg.ServerVNCAllowed)
}
if cfg.DisableVNCApproval != nil {
s.sDisableVNCApproval.SetChecked(*cfg.DisableVNCApproval)
}
}
// MDM locks must run before the mNotifications-nil early return:
// the Settings window is rendered by a separate UI process launched
// with --settings (see handleAdvancedSettingsClick), and that child
// process does NOT run onReady — so its mNotifications is nil and
// the early return below skipped the lock pass entirely.
s.applyMDMLocks(srvCfg.MDMManagedFields)
if s.mNotifications == nil {
return
}
@@ -1517,8 +1532,6 @@ func protoConfigToConfig(cfg *proto.GetConfigResponse) *profilemanager.Config {
config.DisableAutoConnect = cfg.DisableAutoConnect
config.ServerSSHAllowed = &cfg.ServerSSHAllowed
config.ServerVNCAllowed = &cfg.ServerVNCAllowed
config.DisableVNCApproval = &cfg.DisableVNCApproval
config.RosenpassEnabled = cfg.RosenpassEnabled
config.RosenpassPermissive = cfg.RosenpassPermissive
config.DisableNotifications = &cfg.DisableNotifications
@@ -1614,12 +1627,6 @@ func (s *serviceClient) loadSettings() {
s.mAllowSSH.Uncheck()
}
if cfg.ServerVNCAllowed {
s.mAllowVNC.Check()
} else {
s.mAllowVNC.Uncheck()
}
if cfg.DisableAutoConnect {
s.mAutoConnect.Uncheck()
} else {
@@ -1652,6 +1659,129 @@ func (s *serviceClient) loadSettings() {
if s.eventManager != nil {
s.eventManager.SetNotificationsEnabled(s.mNotifications.Checked())
}
s.applyMDMLocks(cfg.MDMManagedFields)
}
// applyMDMLocks disables and badges any tray submenu item or settings-
// form widget whose underlying field is enforced by the active MDM
// policy. Called from loadSettings (submenu refresh) and from
// getSrvConfig (settings-window refresh). Locked items keep their value
// already set by the surrounding refresh code — this routine only
// flips the enabled state and the title suffix, never the value.
func (s *serviceClient) applyMDMLocks(managed []string) {
set := make(map[string]bool, len(managed))
for _, k := range managed {
set[k] = true
}
s.mdmManagedFields = set
if len(managed) > 0 {
log.Infof("MDM-managed UI fields: %v", managed)
}
type submenuTarget struct {
item *systray.MenuItem
title string
key string
}
for _, t := range []submenuTarget{
{s.mAllowSSH, "Allow SSH", mdm.KeyAllowServerSSH},
{s.mAutoConnect, "Connect on Startup", mdm.KeyDisableAutoConnect},
{s.mEnableRosenpass, "Enable Quantum-Resistance", mdm.KeyRosenpassEnabled},
{s.mBlockInbound, "Block Inbound Connections", mdm.KeyBlockInbound},
} {
if t.item == nil {
continue
}
if set[t.key] {
t.item.SetTitle(t.title + " (MDM)")
t.item.Disable()
} else {
t.item.SetTitle(t.title)
t.item.Enable()
}
}
s.applyMDMLocksToSettingsForm(set)
}
// preSharedKeyPlaceholder returns the hint string shown in the PSK
// Entry's placeholder slot. The placeholder is the only signal the
// user gets that a PSK is configured, because the entry's Text is
// forced to empty to keep the password reveal toggle from leaking
// the daemon-returned "**********" redaction sentinel. Returns "" if
// no PSK is present, "MDM-managed" if the key is enforced by MDM,
// and "configured" otherwise.
func preSharedKeyPlaceholder(cfg *proto.GetConfigResponse) string {
if cfg == nil || cfg.PreSharedKey == "" {
return ""
}
for _, k := range cfg.MDMManagedFields {
if k == mdm.KeyPreSharedKey {
return "MDM-managed"
}
}
return "configured"
}
// applyMDMLocksToSettingsForm disables the per-field input widgets in
// the advanced Settings window when the corresponding MDM key is set.
// For plain-text entries (Management URL, Interface Port) the visible
// value is suffixed with " (MDM)" so the user sees the lock indicator
// inline; for the password entry the suffix is skipped (a password
// widget renders every char as a dot and the indicator would not be
// readable). The widgets are created lazily by showSettingsUI, so
// guard each ref against nil.
func (s *serviceClient) applyMDMLocksToSettingsForm(set map[string]bool) {
type entryTarget struct {
entry *widget.Entry
key string
inlineTag bool
}
for _, t := range []entryTarget{
{s.iMngURL, mdm.KeyManagementURL, true},
{s.iPreSharedKey, mdm.KeyPreSharedKey, false},
{s.iInterfacePort, mdm.KeyWireguardPort, true},
} {
if t.entry == nil {
continue
}
if set[t.key] {
if t.inlineTag && t.entry.Text != "" && !strings.HasSuffix(t.entry.Text, mdmFieldSuffix) {
t.entry.SetText(t.entry.Text + mdmFieldSuffix)
}
t.entry.Disable()
} else {
if t.inlineTag {
t.entry.SetText(strings.TrimSuffix(t.entry.Text, mdmFieldSuffix))
}
t.entry.Enable()
}
}
type checkTarget struct {
check *widget.Check
key string
}
for _, t := range []checkTarget{
{s.sDisableClientRoutes, mdm.KeyDisableClientRoutes},
{s.sDisableServerRoutes, mdm.KeyDisableServerRoutes},
} {
if t.check == nil {
continue
}
if set[t.key] {
t.check.Disable()
} else {
t.check.Enable()
}
}
if s.sRosenpassPermissive != nil && set[mdm.KeyRosenpassPermissive] {
// MDM lock layered on top of the Rosenpass-on/off baseline
// applied by getSrvConfig. No Enable() branch here: when the
// MDM key is removed, the next getSrvConfig refresh re-baselines
// the control on cfg.RosenpassEnabled and brings it back if
// Rosenpass is on.
s.sRosenpassPermissive.Disable()
}
}
// updateConfig updates the configuration parameters
@@ -1659,7 +1789,6 @@ func (s *serviceClient) loadSettings() {
func (s *serviceClient) updateConfig() error {
disableAutoStart := !s.mAutoConnect.Checked()
sshAllowed := s.mAllowSSH.Checked()
vncAllowed := s.mAllowVNC.Checked()
rosenpassEnabled := s.mEnableRosenpass.Checked()
lazyConnectionEnabled := s.mLazyConnEnabled.Checked()
blockInbound := s.mBlockInbound.Checked()
@@ -1688,7 +1817,6 @@ func (s *serviceClient) updateConfig() error {
Username: currUser.Username,
DisableAutoConnect: &disableAutoStart,
ServerSSHAllowed: &sshAllowed,
ServerVNCAllowed: &vncAllowed,
RosenpassEnabled: &rosenpassEnabled,
LazyConnectionEnabled: &lazyConnectionEnabled,
BlockInbound: &blockInbound,
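The inline lock indicator used by applyMDMLocksToSettingsForm can be looked at in isolation. The following is a minimal sketch, not the code from this change: it assumes the suffix is " (MDM)" (as the comment on applyMDMLocksToSettingsForm describes), and the helper names managedSet and tagText as well as the key string "WireguardPort" are hypothetical. It shows why the HasSuffix/TrimSuffix pair keeps repeated refreshes from stacking the suffix.

```go
package main

import (
	"fmt"
	"strings"
)

// mdmFieldSuffix mirrors the " (MDM)" marker described in the diff's
// comments; the real constant's definition is not shown there.
const mdmFieldSuffix = " (MDM)"

// managedSet turns the daemon's list of managed keys into a lookup set,
// the same shape applyMDMLocks builds before walking the widgets.
func managedSet(managed []string) map[string]bool {
	set := make(map[string]bool, len(managed))
	for _, k := range managed {
		set[k] = true
	}
	return set
}

// tagText returns the text an inline-tagged entry should display.
// Locked, non-empty values get the suffix exactly once; unlocked values
// have it stripped. Applying it twice gives the same result as once.
func tagText(text string, locked bool) string {
	if !locked {
		return strings.TrimSuffix(text, mdmFieldSuffix)
	}
	if text == "" || strings.HasSuffix(text, mdmFieldSuffix) {
		return text
	}
	return text + mdmFieldSuffix
}

func main() {
	// "WireguardPort" is a hypothetical key name used only for this sketch.
	set := managedSet([]string{"WireguardPort"})
	text := "51820"
	for i := 0; i < 2; i++ { // simulate two refreshes in a row
		text = tagText(text, set["WireguardPort"])
	}
	fmt.Printf("locked:   %q\n", text)
	fmt.Printf("unlocked: %q\n", tagText(text, false))
}
```

One consequence of this design is that a value the user legitimately ends with " (MDM)" would be trimmed on unlock; for a URL or a port number that is unlikely to matter.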


@@ -2,7 +2,6 @@ package main
const (
allowSSHMenuDescr = "Allow SSH connections"
allowVNCMenuDescr = "Allow embedded VNC server"
autoConnectMenuDescr = "Connect automatically when the service starts"
quantumResistanceMenuDescr = "Enable post-quantum security via Rosenpass"
lazyConnMenuDescr = "[Experimental] Enable lazy connections"


@@ -112,7 +112,7 @@ func (e *Manager) handleEvent(event *proto.SystemEvent) {
handlers := slices.Clone(e.handlers)
e.mu.Unlock()
if event.UserMessage != "" && (enabled || event.Severity == proto.SystemEvent_CRITICAL) && !isV6DefaultRoutePartner(event) && event.Category != proto.SystemEvent_APPROVAL {
if event.UserMessage != "" && (enabled || event.Severity == proto.SystemEvent_CRITICAL) && !isV6DefaultRoutePartner(event) {
title := e.getEventTitle(event)
body := event.UserMessage
id := event.Metadata["id"]


@@ -39,8 +39,6 @@ func (h *eventHandler) listen(ctx context.Context) {
h.handleDisconnectClick()
case <-h.client.mAllowSSH.ClickedCh:
h.handleAllowSSHClick()
case <-h.client.mAllowVNC.ClickedCh:
h.handleAllowVNCClick()
case <-h.client.mAutoConnect.ClickedCh:
h.handleAutoConnectClick()
case <-h.client.mEnableRosenpass.ClickedCh:
@@ -136,15 +134,6 @@ func (h *eventHandler) handleAllowSSHClick() {
}
func (h *eventHandler) handleAllowVNCClick() {
h.toggleCheckbox(h.client.mAllowVNC)
if err := h.updateConfigWithErr(); err != nil {
h.toggleCheckbox(h.client.mAllowVNC) // revert checkbox state on error
log.Errorf("failed to update config: %v", err)
h.client.notifier.Send("Error", "Failed to update VNC settings")
}
}
func (h *eventHandler) handleAutoConnectClick() {
h.toggleCheckbox(h.client.mAutoConnect)
if err := h.updateConfigWithErr(); err != nil {


@@ -666,16 +666,48 @@ func (p *profileMenu) clear(profiles []Profile) {
}
}
// setEnabled enables or disables the profile menu based on the provided state
// setEnabled greys out (Disable) the profile menu and every existing
// sub-item when the daemon reports the kill switch active, so the user
// sees the menu but cannot enter "Manage Profiles" or switch profile.
// Previously this used Hide() on the parent, but Fyne's systray on
// Windows does not propagate Hide() to a parent that already has
// children — the submenu kept popping up and accepting clicks. Disable
// is the reliable visual lock.
func (p *profileMenu) setEnabled(enabled bool) {
if p.profileMenuItem != nil {
if enabled {
p.profileMenuItem.Enable()
p.profileMenuItem.SetTooltip("")
} else {
p.profileMenuItem.Hide()
p.profileMenuItem.SetTooltip("Profiles are disabled by daemon")
if p.profileMenuItem == nil {
return
}
p.mu.Lock()
defer p.mu.Unlock()
if enabled {
p.profileMenuItem.Enable()
p.profileMenuItem.SetTooltip("")
} else {
p.profileMenuItem.Disable()
p.profileMenuItem.SetTooltip("Profiles are disabled by daemon")
}
apply := func(item *systray.MenuItem) {
if item == nil {
return
}
if enabled {
item.Enable()
} else {
item.Disable()
}
}
for _, sub := range p.profileSubItems {
if sub != nil {
apply(sub.MenuItem)
}
}
if p.manageProfilesSubItem != nil {
apply(p.manageProfilesSubItem.MenuItem)
}
if p.logoutSubItem != nil {
apply(p.logoutSubItem.MenuItem)
}
}


@@ -1,31 +0,0 @@
// Package vnc holds shared constants for the NetBird embedded VNC stack
// so non-server consumers (CLI capture, debug tooling) can refer to the
// well-known ports without depending on internal engine packages.
package vnc
// External and internal listen ports for the embedded VNC server.
// ExternalPort is what dashboard / browser clients see; the daemon
// DNATs it to InternalPort, where the in-process VNC server actually
// listens. Both flow over the WireGuard interface. AgentLegacyPort is
// the TCP port the per-session agent used before it switched to Unix
// sockets; kept here so packet captures from older builds still get
// tagged, and so any future on-wire agent variant has a reserved port.
const (
ExternalPort uint16 = 5900
InternalPort uint16 = 25900
AgentLegacyPort uint16 = 15900
)
// WellKnownPorts is the unordered set of ports a packet capture should
// treat as carrying NetBird VNC traffic.
var WellKnownPorts = [...]uint16{ExternalPort, InternalPort, AgentLegacyPort}
// IsWellKnownPort reports whether port matches any of WellKnownPorts.
func IsWellKnownPort(port uint16) bool {
for _, p := range WellKnownPorts {
if port == p {
return true
}
}
return false
}


@@ -1,434 +0,0 @@
//go:build darwin && !ios
package server
import (
"bytes"
"context"
"errors"
"fmt"
"net"
"os"
"os/exec"
"strconv"
"sync"
"syscall"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/configs"
)
// darwinAgentManager spawns a per-user VNC agent on demand and keeps it
// alive across multiple client connections within the same console-user
// session. A new agent is spawned the first time a client connects, or
// whenever the console user changes underneath us.
//
// Lifecycle is lazy by design: a daemon that never receives a VNC
// connection never spawns anything. The trade-off versus an eager spawn
// (the Windows model) is that the first VNC client pays the launchctl
// asuser + listen-readiness wait, ~hundreds of milliseconds in practice.
// That cost only repeats on user switch.
type darwinAgentManager struct {
mu sync.Mutex
authToken string
socketPath string
uid uint32
running bool
}
func newDarwinAgentManager(ctx context.Context) *darwinAgentManager {
m := &darwinAgentManager{}
go m.watchConsoleUser(ctx)
return m
}
// agentSocketName is the file name inside the per-uid socket directory
// the agent binds. The directory itself is created and chowned by the
// daemon (see prepareAgentSocketDir) so a non-root local user cannot
// pre-create or symlink the path before the agent listens.
const agentSocketName = "agent.sock"
// watchConsoleUser kills the cached agent whenever the console user
// changes (logout, fast user switch, login window). Without it the daemon
// keeps proxying to an agent whose TCC grant and WindowServer access
// belong to a user who is no longer at the screen, so the new user only
// ever sees the locked-screen wallpaper. Killing the agent breaks the
// Unix-socket connection that the daemon proxies into, the client
// disconnects, and the next reconnect runs Resolve against the new console uid.
func (m *darwinAgentManager) watchConsoleUser(ctx context.Context) {
t := time.NewTicker(2 * time.Second)
defer t.Stop()
for {
select {
case <-ctx.Done():
return
case <-t.C:
uid, err := consoleUserID()
m.mu.Lock()
if !m.running {
m.mu.Unlock()
continue
}
if err != nil || uid != m.uid {
prev := m.uid
m.killLocked()
m.mu.Unlock()
if err != nil {
log.Infof("console user gone (was uid=%d): %v; agent stopped", prev, err)
} else {
log.Infof("console user changed %d -> %d; agent stopped, will respawn on next connect", prev, uid)
}
continue
}
m.mu.Unlock()
}
}
}
// Resolve spawns or respawns the per-user agent process as needed and
// returns its Unix-socket path, shared token, and the uid the agent was
// spawned under (so the daemon can validate peer credentials before
// dispatching the token). Each call is serialized so concurrent VNC
// clients share the same agent.
func (m *darwinAgentManager) Resolve(ctx context.Context) (string, string, uint32, error) {
consoleUID, err := consoleUserID()
if err != nil {
return "", "", 0, fmt.Errorf("no console user: %w", err)
}
m.mu.Lock()
defer m.mu.Unlock()
if m.running && m.uid == consoleUID && vncAgentRunning() {
return m.socketPath, m.authToken, m.uid, nil
}
m.killLocked()
// Reap stray agents so the new token is the only accepted one.
killAllVNCAgents()
socketDir, err := prepareAgentSocketDir(consoleUID)
if err != nil {
return "", "", 0, fmt.Errorf("prepare agent socket dir: %w", err)
}
socketPath := socketDir + "/" + agentSocketName
if err := os.Remove(socketPath); err != nil && !errors.Is(err, os.ErrNotExist) {
log.Debugf("clear stale agent socket %s: %v", socketPath, err)
}
token, err := generateAuthToken()
if err != nil {
return "", "", 0, fmt.Errorf("generate agent auth token: %w", err)
}
if err := spawnAgentForUser(consoleUID, socketPath, token); err != nil {
return "", "", 0, err
}
if err := waitForAgent(ctx, socketPath, 5*time.Second); err != nil {
killAllVNCAgents()
return "", "", 0, fmt.Errorf("agent did not start listening: %w", err)
}
m.authToken = token
m.socketPath = socketPath
m.uid = consoleUID
m.running = true
log.Infof("spawned VNC agent for console uid=%d on %s", consoleUID, socketPath)
return socketPath, token, consoleUID, nil
}
// prepareAgentSocketDir creates a per-uid subdirectory under the netbird
// runtime directory where the agent will bind its Unix socket. The leaf is
// owned by uid with mode 0700, so only the target user and root can write
// there. The parent is created root-owned with mode 0755 if missing.
// Symlinks at the per-uid level are refused (replaced with a fresh
// directory) so a low-priv user cannot redirect the chown that follows.
func prepareAgentSocketDir(uid uint32) (string, error) {
parent := configs.RuntimeDir
if err := ensureAgentSocketParent(parent); err != nil {
return "", err
}
subdir := fmt.Sprintf("%s/vnc-%d", parent, uid)
if err := purgeStaleAgentSubdir(subdir, uid); err != nil {
return "", err
}
if err := os.Mkdir(subdir, 0o700); err != nil && !errors.Is(err, os.ErrExist) {
return "", fmt.Errorf("mkdir %s: %w", subdir, err)
}
if err := os.Chmod(subdir, 0o700); err != nil {
return "", fmt.Errorf("chmod %s: %w", subdir, err)
}
if err := os.Chown(subdir, int(uid), -1); err != nil {
return "", fmt.Errorf("chown %s -> uid %d: %w", subdir, uid, err)
}
return subdir, nil
}
// ensureAgentSocketParent verifies the runtime parent dir exists, is not a
// symlink, and is owned by root.
func ensureAgentSocketParent(parent string) error {
if parent == "" {
return fmt.Errorf("no runtime directory configured for this platform")
}
if err := os.MkdirAll(parent, 0o755); err != nil {
return fmt.Errorf("mkdir %s: %w", parent, err)
}
info, err := os.Lstat(parent)
if err != nil {
return fmt.Errorf("lstat %s: %w", parent, err)
}
if info.Mode()&os.ModeSymlink != 0 {
return fmt.Errorf("%s is a symlink", parent)
}
if st, ok := info.Sys().(*syscall.Stat_t); ok && st.Uid != 0 {
return fmt.Errorf("%s not owned by root (uid=%d)", parent, st.Uid)
}
return nil
}
// purgeStaleAgentSubdir removes a leftover subdir unless it is a real dir
// owned by uid with mode 0700. Lstat (not Stat) so a symlink is detected.
func purgeStaleAgentSubdir(subdir string, uid uint32) error {
info, err := os.Lstat(subdir)
if errors.Is(err, os.ErrNotExist) {
return nil
}
if err != nil {
return fmt.Errorf("lstat %s: %w", subdir, err)
}
if agentSubdirOK(info, uid) {
return nil
}
if err := os.RemoveAll(subdir); err != nil {
return fmt.Errorf("remove stale %s: %w", subdir, err)
}
return nil
}
func agentSubdirOK(info os.FileInfo, uid uint32) bool {
if info.Mode()&os.ModeSymlink != 0 || !info.IsDir() {
return false
}
st, ok := info.Sys().(*syscall.Stat_t)
if !ok {
return false
}
return st.Uid == uid && info.Mode().Perm() == 0o700
}
// stop terminates the spawned agent, if any. Intended for daemon shutdown.
func (m *darwinAgentManager) stop() {
m.mu.Lock()
defer m.mu.Unlock()
m.killLocked()
}
func (m *darwinAgentManager) killLocked() {
if !m.running {
return
}
killAllVNCAgents()
if m.socketPath != "" {
if err := os.Remove(m.socketPath); err != nil && !errors.Is(err, os.ErrNotExist) {
log.Debugf("remove agent socket %s: %v", m.socketPath, err)
}
}
m.running = false
m.authToken = ""
m.socketPath = ""
m.uid = 0
}
// consoleUserID returns the uid of the user currently sitting at the
// console (the one whose Aqua session is active). Returns
// errNoConsoleUser when nobody is logged in: at the login window
// /dev/console is owned by root.
func consoleUserID() (uint32, error) {
info, err := os.Stat("/dev/console")
if err != nil {
return 0, fmt.Errorf("stat /dev/console: %w", err)
}
st, ok := info.Sys().(*syscall.Stat_t)
if !ok {
return 0, fmt.Errorf("/dev/console stat has unexpected type")
}
if st.Uid == 0 {
return 0, errNoConsoleUser
}
return st.Uid, nil
}
// spawnAgentForUser uses launchctl asuser to start a netbird vnc-agent
// process inside the target user's launchd bootstrap namespace. That is
// the only spawn mode on macOS that gives the child access to the user's
// WindowServer. The agent's stderr is relogged into the daemon log so
// startup failures are not silently lost when the readiness check times
// out.
func spawnAgentForUser(uid uint32, socketPath, token string) error {
exe, err := os.Executable()
if err != nil {
return fmt.Errorf("resolve own executable: %w", err)
}
cmd := exec.Command(
"/bin/launchctl", "asuser", strconv.FormatUint(uint64(uid), 10),
exe, vncAgentSubcommand,
"--socket", socketPath,
// Drop privs inside the agent: launchctl asuser preserves the
// daemon's uid (root), so without this the capture/input/
// encoder paths would run as root for the lifetime of the
// session. validateAgentPeer on the daemon side also relies on
// the agent's effective uid matching consoleUID.
"--target-uid", strconv.FormatUint(uint64(uid), 10),
)
cmd.Env = append(os.Environ(), agentTokenEnvVar+"="+token)
stderr, err := cmd.StderrPipe()
if err != nil {
return fmt.Errorf("agent stderr pipe: %w", err)
}
if err := cmd.Start(); err != nil {
return fmt.Errorf("launchctl asuser: %w", err)
}
go func() {
defer stderr.Close()
relogAgentStream(stderr)
}()
go func() { _ = cmd.Wait() }()
return nil
}
// waitForAgent dials the agent's Unix socket until it answers. Used to
// gate proxy attempts until the spawned process has finished its Start.
func waitForAgent(ctx context.Context, socketPath string, wait time.Duration) error {
var d net.Dialer
deadline := time.Now().Add(wait)
for time.Now().Before(deadline) {
if ctx.Err() != nil {
return ctx.Err()
}
dialCtx, cancel := context.WithTimeout(ctx, 200*time.Millisecond)
c, err := d.DialContext(dialCtx, "unix", socketPath)
cancel()
if err == nil {
_ = c.Close()
return nil
}
time.Sleep(100 * time.Millisecond)
}
return fmt.Errorf("timeout dialing %s", socketPath)
}
// vncAgentRunning reports whether any vnc-agent process exists on the
// system. There is at most one agent per machine, so any match is "the"
// agent.
func vncAgentRunning() bool {
pids, err := vncAgentPIDs()
if err != nil {
log.Debugf("scan for vnc-agent: %v", err)
return false
}
return len(pids) > 0
}
// killAllVNCAgents sends SIGTERM to every vnc-agent process found by
// vncAgentPIDs, waits briefly for them to exit, and escalates to SIGKILL
// for any that remain. We enumerate kern.proc.all rather than
// kern.proc.uid because launchctl asuser preserves the caller's uid
// (root) on the spawned child, so a uid-scoped filter would never match.
func killAllVNCAgents() {
pids, err := vncAgentPIDs()
if err != nil {
log.Debugf("scan for vnc-agent: %v", err)
return
}
for _, pid := range pids {
_ = syscall.Kill(pid, syscall.SIGTERM)
}
if len(pids) == 0 {
return
}
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
remaining, _ := vncAgentPIDs()
if len(remaining) == 0 {
return
}
time.Sleep(100 * time.Millisecond)
}
leftover, _ := vncAgentPIDs()
for _, pid := range leftover {
_ = syscall.Kill(pid, syscall.SIGKILL)
}
}
// vncAgentPIDs returns the pids of vnc-agent subprocesses spawned from
// this binary. Matches exactly on argv[0] == our own executable path
// AND argv[1] == "vnc-agent" so unrelated processes that happen to have
// the same name elsewhere in argv are not targeted. Skips pid 0 and 1
// defensively.
func vncAgentPIDs() ([]int, error) {
procs, err := unix.SysctlKinfoProcSlice("kern.proc.all")
if err != nil {
return nil, fmt.Errorf("sysctl kern.proc.all: %w", err)
}
ownExe, err := os.Executable()
if err != nil {
return nil, fmt.Errorf("resolve own executable: %w", err)
}
var out []int
for i := range procs {
pid := int(procs[i].Proc.P_pid)
if pid <= 1 {
continue
}
argv, err := procArgv(pid)
if err != nil || !argvIsVNCAgent(argv, ownExe) {
continue
}
out = append(out, pid)
}
return out, nil
}
// procArgv reads the kernel's stored argv for pid via the kern.procargs2
// sysctl. Format: 4-byte argc, then argv[0..argc) each NUL-terminated,
// then envp, then padding. We only need argv so we stop after argc.
func procArgv(pid int) ([]string, error) {
raw, err := unix.SysctlRaw("kern.procargs2", pid)
if err != nil {
return nil, err
}
if len(raw) < 4 {
return nil, fmt.Errorf("procargs2 truncated")
}
argc := int(raw[0]) | int(raw[1])<<8 | int(raw[2])<<16 | int(raw[3])<<24
body := raw[4:]
// Skip the executable path (NUL-terminated) and any zero padding that
// follows before argv[0].
end := bytes.IndexByte(body, 0)
if end < 0 {
return nil, fmt.Errorf("procargs2 path unterminated")
}
body = body[end+1:]
for len(body) > 0 && body[0] == 0 {
body = body[1:]
}
args := make([]string, 0, argc)
for i := 0; i < argc; i++ {
end := bytes.IndexByte(body, 0)
if end < 0 {
break
}
args = append(args, string(body[:end]))
body = body[end+1:]
}
return args, nil
}
// argvIsVNCAgent reports whether argv belongs to a vnc-agent subprocess
// spawned from our binary. Requires argv[0] to match ownExe exactly and
// argv[1] to be the vnc-agent subcommand. Matches the spawn shape in
// spawnAgentForUser and rejects anything else.
func argvIsVNCAgent(argv []string, ownExe string) bool {
if len(argv) < 2 || ownExe == "" {
return false
}
return argv[0] == ownExe && argv[1] == vncAgentSubcommand
}
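The kern.procargs2 layout that procArgv walks (4-byte argc, executable path, zero padding, then argv) can be exercised without a sysctl call. The sketch below is an illustration under stated assumptions rather than the removed code: parseProcArgs2 is a hypothetical name, the buffer is synthetic, the path /opt/nb/netbird is made up, and argc is read as little-endian, matching the byte-shift decoding in procArgv.

```go
package main

import (
	"bytes"
	"encoding/binary"
	"errors"
	"fmt"
)

// parseProcArgs2 extracts argv from a buffer laid out like the
// kern.procargs2 sysctl result: argc (4 bytes, little-endian), the
// NUL-terminated executable path, optional NUL padding, then argc
// NUL-terminated arguments. Anything after argv (envp) is ignored.
func parseProcArgs2(raw []byte) ([]string, error) {
	if len(raw) < 4 {
		return nil, errors.New("procargs2 truncated")
	}
	argc := int(binary.LittleEndian.Uint32(raw[:4]))
	body := raw[4:]

	// Skip the executable path and the zero padding that follows it.
	end := bytes.IndexByte(body, 0)
	if end < 0 {
		return nil, errors.New("procargs2 path unterminated")
	}
	body = bytes.TrimLeft(body[end+1:], "\x00")

	var args []string
	for i := 0; i < argc; i++ {
		end := bytes.IndexByte(body, 0)
		if end < 0 {
			break // buffer ended before argc arguments were read
		}
		args = append(args, string(body[:end]))
		body = body[end+1:]
	}
	return args, nil
}

func main() {
	// Synthetic buffer: argc=2, exec path, two bytes of padding,
	// argv[0], argv[1], then one environment entry that is not argv.
	raw := []byte{2, 0, 0, 0}
	raw = append(raw, "/opt/nb/netbird\x00\x00\x00"...)
	raw = append(raw, "/opt/nb/netbird\x00vnc-agent\x00HOME=/tmp\x00"...)

	args, err := parseProcArgs2(raw)
	fmt.Println(args, err)
}
```

Stopping after argc entries is what keeps environment strings such as HOME=/tmp out of the result, so a match on argv[1] cannot be satisfied by an environment variable.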


@@ -1,305 +0,0 @@
//go:build darwin || windows
package server
import (
"bufio"
"bytes"
"context"
crand "crypto/rand"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"time"
log "github.com/sirupsen/logrus"
)
// errNoConsoleUser is the sentinel returned by sessionAgent.Resolve when
// the platform has no interactive user to attach a capture agent to (the
// macOS loginwindow state). Mapped to a distinct RFB reject code so the
// browser can show a meaningful message.
var errNoConsoleUser = errors.New("no user logged into console")
// sessionAgent abstracts the per-platform manager that spawns and tracks
// the user-session VNC agent. Resolve returns the agent's Unix-socket
// path, the shared per-spawn token, and the uid the agent was spawned
// under (used to validate peer credentials before the daemon hands the
// token to whoever is on the other end of the socket). Resolve may spawn
// the agent lazily.
type sessionAgent interface {
Resolve(ctx context.Context) (socketPath, token string, peerUID uint32, err error)
}
// prefixConn replays already-consumed header bytes ahead of the proxy
// stream by swapping in a different Reader on the same underlying Conn.
type prefixConn struct {
io.Reader
net.Conn
}
func (p *prefixConn) Read(b []byte) (int, error) { return p.Reader.Read(b) }
// handleServiceConnection runs the connection-header handshake (source
// check, Noise_IK auth) on conn, resolves the right per-session agent
// via sa, and proxies to it. Every accepted connection emits exactly one
// outcome line on the daemon log.
func (s *Server) handleServiceConnection(conn net.Conn, sa sessionAgent) {
start := time.Now()
connLog := s.log.WithField("remote", conn.RemoteAddr().String())
if !s.isAllowedSource(conn.RemoteAddr()) {
connLog.Info("VNC connection rejected: source not allowed")
_ = conn.Close()
return
}
var headerBuf bytes.Buffer
tee := io.TeeReader(conn, &headerBuf)
teeConn := &prefixConn{Reader: tee, Conn: conn}
header, err := s.readConnectionHeader(teeConn)
if err != nil {
connLog.Infof("VNC connection rejected: header read failed: %v", err)
_ = conn.Close()
return
}
authedLog, sessionUserID, ok := s.authorizeSession(conn, header, connLog)
if !ok {
authedLog.Info("VNC connection rejected: auth failed")
return
}
if err := s.registerConnAuth(conn, header); err != nil {
rejectConnection(conn, codeMessage(RejectCodeAuthForbidden, err.Error()))
authedLog.Warnf("VNC connection rejected: %v", err)
return
}
decision, err := s.gateApproval(conn, header)
if err != nil {
authedLog.Infof("VNC connection rejected: %v", err)
return
}
if decision.ViewOnly {
authedLog.Info("VNC connection approved by user (view-only)")
} else if s.requireApproval {
authedLog.Info("VNC connection approved by user")
}
socketPath, token, peerUID, err := sa.Resolve(s.ctx)
if err != nil {
code := RejectCodeCapturerError
if errors.Is(err, errNoConsoleUser) {
code = RejectCodeNoConsoleUser
}
rejectConnection(conn, codeMessage(code, err.Error()))
authedLog.Warnf("VNC connection rejected: agent unavailable: %v", err)
return
}
var initiator string
if s.authorizer != nil {
initiator = s.authorizer.LookupSessionDisplayName(header.clientStatic)
}
sessionID := s.addSession(ActiveSessionInfo{
RemoteAddress: conn.RemoteAddr().String(),
Mode: modeString(header.mode),
Username: header.username,
UserID: sessionUserID,
Initiator: initiator,
}, conn)
defer s.removeSession(sessionID)
replayConn := &prefixConn{
Reader: io.MultiReader(&headerBuf, conn),
Conn: conn,
}
if err := proxyToAgent(s.ctx, replayConn, socketPath, token, peerUID, decision.ViewOnly, authedLog); err != nil {
rejectConnection(conn, codeMessage(RejectCodeCapturerError, err.Error()))
authedLog.Warnf("VNC connection rejected: agent unreachable: %v", err)
return
}
authedLog.Infof("VNC connection closed (%dms)", time.Since(start).Milliseconds())
}
const (
// agentTokenLen is the size of the random per-spawn token in bytes.
agentTokenLen = 32
// agentTokenEnvVar names the environment variable the daemon uses to
// hand the per-spawn token to the agent child. Out-of-band channels
// like this keep the secret out of the command line, where listings
// such as `ps` or Windows tasklist would expose it.
agentTokenEnvVar = "NB_VNC_AGENT_TOKEN" // #nosec G101 -- env var name, not a credential
// vncAgentSubcommand is the CLI subcommand the daemon invokes to start
// the per-session agent process. Must match cmd.vncAgentCmd.Use in
// client/cmd/vnc_agent.go.
vncAgentSubcommand = "vnc-agent"
)
// generateAuthToken returns a fresh hex-encoded random token for one
// daemon→agent session. The daemon hands this to the spawned agent
// out-of-band (via the agentTokenEnvVar environment variable) and sends
// it as the preamble of every connection it opens to the agent, where
// the agent verifies it.
func generateAuthToken() (string, error) {
b := make([]byte, agentTokenLen)
if _, err := crand.Read(b); err != nil {
return "", fmt.Errorf("read random: %w", err)
}
return hex.EncodeToString(b), nil
}
// proxyToAgent dials the per-session agent's Unix socket, validates the
// peer's kernel-asserted uid (so the daemon never hands its per-spawn
// token to an impostor that won the listen race), writes the raw token
// bytes plus a single view-only flag byte, then copies bytes both ways
// until either side closes. The token + flag prefix must precede any RFB
// byte so the agent's verifyAgentToken can run first. Returns nil once a
// stream is established; the caller is responsible for sending an
// RFB-level rejection on error so the client sees a reason instead of a
// bare timeout. authedLog receives one audit line per dispatched
// preamble so an operator can correlate daemon→agent traffic with the
// remote session that triggered it.
func proxyToAgent(ctx context.Context, client net.Conn, socketPath, authToken string, peerUID uint32, viewOnly bool, authedLog *log.Entry) error {
tokenBytes, err := hex.DecodeString(authToken)
if err != nil {
return fmt.Errorf("invalid auth token: %w", err)
}
if len(tokenBytes) != agentTokenLen {
return fmt.Errorf("invalid auth token length %d, want %d", len(tokenBytes), agentTokenLen)
}
agentConn, err := dialAgentWithRetry(ctx, socketPath)
if err != nil {
return fmt.Errorf("dial agent at %s: %w", socketPath, err)
}
if err := validateAgentPeer(agentConn, peerUID); err != nil {
_ = agentConn.Close()
return fmt.Errorf("agent peer validation failed: %w", err)
}
preamble := make([]byte, len(tokenBytes)+1)
copy(preamble, tokenBytes)
if viewOnly {
preamble[len(tokenBytes)] = 1
}
if _, err := agentConn.Write(preamble); err != nil {
_ = agentConn.Close()
return fmt.Errorf("send auth preamble to agent: %w", err)
}
// Audit: one line per successfully-dispatched daemon→agent preamble.
// Token printed as its first 8 hex chars (enough to correlate, not
// enough to use). Kept at Info so the default deployment captures it.
tokenFp := authToken
if len(tokenFp) > 8 {
tokenFp = tokenFp[:8]
}
if authedLog != nil {
authedLog.Infof("VNC IPC: dispatched preamble to agent socket=%s peer_uid=%d view_only=%v token_fp=%s", socketPath, peerUID, viewOnly, tokenFp)
}
defer client.Close()
defer agentConn.Close()
log.Debugf("proxy connected to agent, starting bidirectional copy")
done := make(chan struct{}, 2)
cp := func(label string, dst, src net.Conn) {
n, err := io.Copy(dst, src)
log.Debugf("proxy %s: %d bytes, err=%v", label, n, err)
done <- struct{}{}
}
go cp("client->agent", agentConn, client)
go cp("agent->client", client, agentConn)
<-done
return nil
}
// relogAgentStream reads log lines from the agent's stderr and re-emits
// them through the daemon's logrus, so the merged log keeps a single
// format. JSON lines (the agent's normal output) are parsed and dispatched
// by level; plain-text lines (cobra errors, panic traces) are forwarded
// verbatim so early-startup failures stay visible.
func relogAgentStream(r io.Reader) {
entry := log.WithField("component", "vnc-agent")
scanner := bufio.NewScanner(r)
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
for scanner.Scan() {
line := scanner.Bytes()
if len(line) == 0 {
continue
}
if line[0] != '{' {
entry.Warn(string(line))
continue
}
var m map[string]any
if err := json.Unmarshal(line, &m); err != nil {
entry.Warn(string(line))
continue
}
msg, _ := m["msg"].(string)
if msg == "" {
continue
}
fields := make(log.Fields)
for k, v := range m {
switch k {
case "msg", "level", "time", "func":
continue
case "caller":
fields["source"] = v
default:
fields[k] = v
}
}
e := entry.WithFields(fields)
switch m["level"] {
case "error":
e.Error(msg)
case "warning":
e.Warn(msg)
case "debug":
e.Debug(msg)
case "trace":
e.Trace(msg)
default:
e.Info(msg)
}
}
}
// dialAgentWithRetry retries the Unix-socket connect for up to ~10 s so the
// daemon does not race the agent's first listen. Returns the live conn or
// the final error. Aborts early when ctx is cancelled so a Stop() during
// service-mode startup doesn't leave a goroutine sleeping for 10 s.
func dialAgentWithRetry(ctx context.Context, addr string) (net.Conn, error) {
var d net.Dialer
var lastErr error
for range 50 {
if err := ctx.Err(); err != nil {
if lastErr == nil {
lastErr = err
}
return nil, lastErr
}
dialCtx, cancel := context.WithTimeout(ctx, time.Second)
c, err := d.DialContext(dialCtx, "unix", addr)
cancel()
if err == nil {
return c, nil
}
lastErr = err
select {
case <-ctx.Done():
if errors.Is(lastErr, context.Canceled) || errors.Is(lastErr, context.DeadlineExceeded) {
lastErr = ctx.Err()
}
return nil, lastErr
case <-time.After(200 * time.Millisecond):
}
}
return nil, lastErr
}
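The daemon→agent preamble described on proxyToAgent (raw token bytes followed by a single view-only flag byte) is small enough to sketch on its own. This is an assumption-labelled illustration, not the removed implementation: buildPreamble is a hypothetical helper, and agentTokenLen is redeclared here with the value 32 taken from the constant block above.

```go
package main

import (
	"encoding/hex"
	"fmt"
	"strings"
)

// agentTokenLen matches the 32-byte token size from the diff.
const agentTokenLen = 32

// buildPreamble decodes the hex token and appends one flag byte:
// 1 for a view-only session, 0 otherwise. The result is what would be
// written to the agent socket before any RFB byte.
func buildPreamble(hexToken string, viewOnly bool) ([]byte, error) {
	tok, err := hex.DecodeString(hexToken)
	if err != nil {
		return nil, fmt.Errorf("decode token: %w", err)
	}
	if len(tok) != agentTokenLen {
		return nil, fmt.Errorf("token is %d bytes, want %d", len(tok), agentTokenLen)
	}
	p := make([]byte, agentTokenLen+1)
	copy(p, tok)
	if viewOnly {
		p[agentTokenLen] = 1
	}
	return p, nil
}

func main() {
	// A fixed token for illustration only; the daemon draws it from crypto/rand.
	token := strings.Repeat("ab", agentTokenLen)
	p, err := buildPreamble(token, true)
	if err != nil {
		fmt.Println("error:", err)
		return
	}
	fmt.Printf("preamble length=%d view-only flag=%d\n", len(p), p[agentTokenLen])
}
```

Checking the decode error and the length separately keeps each message meaningful, since a well-formed hex string of the wrong length carries no decode error to wrap.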


@@ -1,46 +0,0 @@
//go:build darwin && !ios
package server
import (
"fmt"
"net"
"golang.org/x/sys/unix"
)
// validateAgentPeer checks that the peer behind the just-connected Unix
// socket runs under expectedUID, using the credentials the kernel recorded
// for the socket (LOCAL_PEERCRED). It returns a non-nil error when the peer
// uid differs, e.g. when a process of another user won the listen race or
// squatted the path before the agent. This keeps the daemon from shipping
// its per-spawn auth token to a process outside the expected account; a
// different process running under the same uid is not distinguished.
func validateAgentPeer(conn net.Conn, expectedUID uint32) error {
uconn, ok := conn.(*net.UnixConn)
if !ok {
return fmt.Errorf("peer cred: expected *net.UnixConn, got %T", conn)
}
raw, err := uconn.SyscallConn()
if err != nil {
return fmt.Errorf("peer cred: syscall conn: %w", err)
}
var cred *unix.Xucred
var inner error
ctlErr := raw.Control(func(fd uintptr) {
cred, inner = unix.GetsockoptXucred(int(fd), unix.SOL_LOCAL, unix.LOCAL_PEERCRED)
})
if ctlErr != nil {
return fmt.Errorf("peer cred: control: %w", ctlErr)
}
if inner != nil {
return fmt.Errorf("peer cred: getsockopt LOCAL_PEERCRED: %w", inner)
}
if cred == nil {
return fmt.Errorf("peer cred: nil xucred")
}
if cred.Uid != expectedUID {
return fmt.Errorf("peer cred: agent uid %d does not match expected %d", cred.Uid, expectedUID)
}
return nil
}

@@ -1,115 +0,0 @@
//go:build darwin && !ios
package server
import (
"net"
"os"
"path/filepath"
"strings"
"sync"
"testing"
)
// TestValidateAgentPeerAcceptsOwnUID confirms the happy path: a Unix
// socket whose peer is the current process must validate when the
// expected uid matches the process's own. Both sides of a unix-socket
// pair share the same kernel cred, so this exercises the real getsockopt
// LOCAL_PEERCRED path.
func TestValidateAgentPeerAcceptsOwnUID(t *testing.T) {
dir := t.TempDir()
sockPath := filepath.Join(dir, "test.sock")
ln, err := net.Listen("unix", sockPath)
if err != nil {
t.Fatalf("listen: %v", err)
}
defer ln.Close()
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
c, err := ln.Accept()
if err == nil {
_ = c.Close()
}
}()
c, err := net.Dial("unix", sockPath)
if err != nil {
t.Fatalf("dial: %v", err)
}
defer c.Close()
if err := validateAgentPeer(c, uint32(os.Getuid())); err != nil {
t.Fatalf("validateAgentPeer rejected own uid: %v", err)
}
wg.Wait()
}
// TestValidateAgentPeerRejectsWrongUID ensures the validator fails when
// the expected uid differs from the kernel-reported peer uid. This is
// the path that catches a hostile process that won the listen race.
func TestValidateAgentPeerRejectsWrongUID(t *testing.T) {
dir := t.TempDir()
sockPath := filepath.Join(dir, "test.sock")
ln, err := net.Listen("unix", sockPath)
if err != nil {
t.Fatalf("listen: %v", err)
}
defer ln.Close()
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
c, err := ln.Accept()
if err == nil {
_ = c.Close()
}
}()
c, err := net.Dial("unix", sockPath)
if err != nil {
t.Fatalf("dial: %v", err)
}
defer c.Close()
// Pick a uid the test process certainly isn't running as.
wrongUID := uint32(os.Getuid()) + 1
err = validateAgentPeer(c, wrongUID)
if err == nil {
t.Fatal("expected mismatch error, got nil")
}
if !strings.Contains(err.Error(), "does not match expected") {
t.Fatalf("error should mention uid mismatch, got: %v", err)
}
wg.Wait()
}
// TestValidateAgentPeerRejectsNonUnix protects against being handed a
// non-Unix-socket connection (the validator can't enforce anything on
// e.g. a *net.TCPConn so it must refuse rather than silently pass).
func TestValidateAgentPeerRejectsNonUnix(t *testing.T) {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen tcp: %v", err)
}
defer ln.Close()
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
c, err := ln.Accept()
if err == nil {
_ = c.Close()
}
}()
c, err := net.Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatalf("dial tcp: %v", err)
}
defer c.Close()
if err := validateAgentPeer(c, 0); err == nil {
t.Fatal("expected refusal on non-unix conn, got nil")
}
wg.Wait()
}

@@ -1,31 +0,0 @@
//go:build windows
package server
import (
"net"
)
// validateAgentPeer is a documented no-op on Windows. AF_UNIX on Windows
// exposes no SO_PEERCRED equivalent and no supported API to recover the
// peer process from an accepted AF_UNIX connection, so the daemon cannot
// match the connected peer against the agent PID it spawned the way the
// darwin path does via LOCAL_PEERCRED. The Windows trust model therefore
// rests on three other measures, none of which assume the socket path is
// secret:
//
// - the socket lives in a dedicated directory (agentSocketDir) created
// with a DACL granting only SYSTEM and Administrators, so an
// unprivileged local user cannot create or squat a socket there;
// - each spawn uses a cryptographically random socket name, so the path
// is unguessable before the agent binds it;
// - the daemon publishes the path only after confirming the spawned
// agent is listening (see waitForAgentListening), and gates every
// connection on the per-spawn auth-token preamble that follows this
// call.
//
// If a future Windows release exposes peer-PID retrieval for AF_UNIX,
// this function should verify the peer against the spawned agent PID.
func validateAgentPeer(_ net.Conn, _ uint32) error {
return nil
}

@@ -1,747 +0,0 @@
//go:build windows
package server
import (
"context"
crand "crypto/rand"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
"net"
"os"
"path/filepath"
"runtime"
"sync"
"time"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
const (
stillActive = 259
tokenPrimary = 1
securityImpersonation = 2
tokenSessionID = 12
createUnicodeEnvironment = 0x00000400
createNoWindow = 0x08000000
createSuspended = 0x00000004
createBreakawayFromJob = 0x01000000
)
var (
kernel32 = windows.NewLazySystemDLL("kernel32.dll")
advapi32 = windows.NewLazySystemDLL("advapi32.dll")
userenv = windows.NewLazySystemDLL("userenv.dll")
procWTSGetActiveConsoleSessionId = kernel32.NewProc("WTSGetActiveConsoleSessionId")
procCreateJobObjectW = kernel32.NewProc("CreateJobObjectW")
procSetInformationJobObject = kernel32.NewProc("SetInformationJobObject")
procAssignProcessToJobObject = kernel32.NewProc("AssignProcessToJobObject")
procSetTokenInformation = advapi32.NewProc("SetTokenInformation")
procCreateEnvironmentBlock = userenv.NewProc("CreateEnvironmentBlock")
procDestroyEnvironmentBlock = userenv.NewProc("DestroyEnvironmentBlock")
wtsapi32 = windows.NewLazySystemDLL("wtsapi32.dll")
procWTSEnumerateSessionsW = wtsapi32.NewProc("WTSEnumerateSessionsW")
procWTSFreeMemory = wtsapi32.NewProc("WTSFreeMemory")
procWTSQuerySessionInformation = wtsapi32.NewProc("WTSQuerySessionInformationW")
)
// GetCurrentSessionID returns the session ID of the current process.
func GetCurrentSessionID() uint32 {
var token windows.Token
if err := windows.OpenProcessToken(windows.CurrentProcess(),
windows.TOKEN_QUERY, &token); err != nil {
return 0
}
defer token.Close()
var id uint32
var ret uint32
_ = windows.GetTokenInformation(token, windows.TokenSessionId,
(*byte)(unsafe.Pointer(&id)), 4, &ret)
return id
}
func getConsoleSessionID() uint32 {
r, _, _ := procWTSGetActiveConsoleSessionId.Call()
return uint32(r)
}
const (
wtsActive = 0
wtsConnected = 1
wtsDisconnected = 4
)
// getActiveSessionID returns the session ID of the best session to attach to.
// On a Windows Server with no console display attached, session 1 still
// reports WTSActive (login screen "owns" the console), so a naive
// first-active-wins pick lands on a session with no actual rendering.
// Preference order:
// 1. Active session with a user logged in (RDP user in session ≥2)
// 2. Active session without a user (console at login screen)
// 3. Console session ID
func getActiveSessionID() uint32 {
var sessionInfo uintptr
var count uint32
r, _, _ := procWTSEnumerateSessionsW.Call(
0, // WTS_CURRENT_SERVER_HANDLE
0, // reserved
1, // version
uintptr(unsafe.Pointer(&sessionInfo)),
uintptr(unsafe.Pointer(&count)),
)
if r == 0 || count == 0 {
return getConsoleSessionID()
}
defer func() { _, _, _ = procWTSFreeMemory.Call(sessionInfo) }()
type wtsSession struct {
SessionID uint32
Station *uint16
State uint32
}
sessions := unsafe.Slice((*wtsSession)(unsafe.Pointer(sessionInfo)), count)
var withUser uint32
var withUserFound bool
var anyActive uint32
var anyActiveFound bool
for _, s := range sessions {
if s.SessionID == 0 {
continue
}
if s.State != wtsActive {
continue
}
if !anyActiveFound {
anyActive = s.SessionID
anyActiveFound = true
}
if !withUserFound && wtsSessionHasUser(s.SessionID) {
withUser = s.SessionID
withUserFound = true
}
}
if withUserFound {
return withUser
}
if anyActiveFound {
return anyActive
}
return getConsoleSessionID()
}
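
The preference order documented for getActiveSessionID can be isolated from the WTS calls and exercised with plain data. A minimal sketch, assuming a hypothetical `session` type and `pickSession` helper that stand in for the enumerated WTS records:

```go
package main

import "fmt"

// session is a hypothetical, already-decoded view of one WTS session.
type session struct {
	ID      uint32
	Active  bool
	HasUser bool
}

// pickSession applies the documented preference order: an active session
// with a logged-in user, then any active session, then the console ID.
func pickSession(sessions []session, consoleID uint32) uint32 {
	var firstActive uint32
	found := false
	for _, s := range sessions {
		if s.ID == 0 || !s.Active {
			continue // session 0 and inactive sessions are skipped, as above
		}
		if s.HasUser {
			return s.ID
		}
		if !found {
			firstActive, found = s.ID, true
		}
	}
	if found {
		return firstActive
	}
	return consoleID
}

func main() {
	sessions := []session{
		{ID: 1, Active: true},                // console at the login screen
		{ID: 2, Active: true, HasUser: true}, // RDP user
	}
	fmt.Println(pickSession(sessions, 1))
}
```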
const wtsUserName = 5 // WTS_INFO_CLASS value WTSUserName
// wtsSessionHasUser returns true if the session has a non-empty user name,
// i.e. someone is logged in (vs. the login/Welcome screen). The console
// session at the login screen has WTSUserName == "".
func wtsSessionHasUser(sessionID uint32) bool {
var buf uintptr
var bytesReturned uint32
r, _, _ := procWTSQuerySessionInformation.Call(
0, // WTS_CURRENT_SERVER_HANDLE
uintptr(sessionID),
uintptr(wtsUserName),
uintptr(unsafe.Pointer(&buf)),
uintptr(unsafe.Pointer(&bytesReturned)),
)
if r == 0 || buf == 0 {
return false
}
defer func() { _, _, _ = procWTSFreeMemory.Call(buf) }()
// First UTF-16 code unit non-zero ⇒ non-empty username.
return *(*uint16)(unsafe.Pointer(buf)) != 0
}
// getSystemTokenForSession duplicates the current SYSTEM token and sets its
// session ID so the spawned process runs in the target session. Using a SYSTEM
// token gives access to both Default and Winlogon desktops plus UIPI bypass.
func getSystemTokenForSession(sessionID uint32) (windows.Token, error) {
var cur windows.Token
if err := windows.OpenProcessToken(windows.CurrentProcess(),
windows.MAXIMUM_ALLOWED, &cur); err != nil {
return 0, fmt.Errorf("OpenProcessToken: %w", err)
}
defer cur.Close()
var dup windows.Token
if err := windows.DuplicateTokenEx(cur, windows.MAXIMUM_ALLOWED, nil,
securityImpersonation, tokenPrimary, &dup); err != nil {
return 0, fmt.Errorf("DuplicateTokenEx: %w", err)
}
sid := sessionID
r, _, err := procSetTokenInformation.Call(
uintptr(dup),
uintptr(tokenSessionID),
uintptr(unsafe.Pointer(&sid)),
unsafe.Sizeof(sid),
)
if r == 0 {
dup.Close()
return 0, fmt.Errorf("SetTokenInformation(SessionId=%d): %w", sessionID, err)
}
return dup, nil
}
// injectEnvVar appends a KEY=VALUE entry to a Unicode environment block.
// The block is a sequence of null-terminated UTF-16 strings, terminated by
// an extra null. Returns the new []uint16 backing slice; the caller must
// hold the returned slice alive until CreateProcessAsUser completes.
func injectEnvVar(envBlock uintptr, key, value string) []uint16 {
entry := key + "=" + value
// Walk the existing block to find its total length.
ptr := (*uint16)(unsafe.Pointer(envBlock))
var totalChars int
for {
ch := *(*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + uintptr(totalChars)*2))
if ch == 0 {
// Check for double-null terminator.
next := *(*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + uintptr(totalChars+1)*2))
totalChars++
if next == 0 {
// End of block (don't count the final null yet, we'll rebuild).
break
}
} else {
totalChars++
}
}
entryUTF16, _ := windows.UTF16FromString(entry)
// New block: existing entries + new entry (null-terminated) + final null.
newLen := totalChars + len(entryUTF16) + 1
newBlock := make([]uint16, newLen)
// Copy existing entries (up to but not including the final null).
for i := range totalChars {
newBlock[i] = *(*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + uintptr(i)*2))
}
copy(newBlock[totalChars:], entryUTF16)
newBlock[newLen-1] = 0 // final null terminator
return newBlock
}
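
The block format injectEnvVar manipulates (NUL-terminated UTF-16 strings followed by one extra NUL) can be explored on ordinary slices, without unsafe pointers. This is a sketch only; `splitEnvBlock`, `buildEnvBlock`, and `appendEnvEntry` are hypothetical helpers and trade the single-pass copy above for readability.

```go
package main

import (
	"fmt"
	"unicode/utf16"
)

// splitEnvBlock decodes a double-NUL-terminated UTF-16 environment block
// into its KEY=VALUE entries.
func splitEnvBlock(block []uint16) []string {
	var entries []string
	start := 0
	for i, u := range block {
		if u != 0 {
			continue
		}
		if i == start {
			break // empty string: end of block
		}
		entries = append(entries, string(utf16.Decode(block[start:i])))
		start = i + 1
	}
	return entries
}

// buildEnvBlock encodes entries back into the block format.
func buildEnvBlock(entries []string) []uint16 {
	if len(entries) == 0 {
		return []uint16{0, 0}
	}
	var block []uint16
	for _, e := range entries {
		block = append(block, utf16.Encode([]rune(e))...)
		block = append(block, 0) // entry terminator
	}
	return append(block, 0) // block terminator
}

// appendEnvEntry returns a new block with key=value added at the end.
func appendEnvEntry(block []uint16, key, value string) []uint16 {
	return buildEnvBlock(append(splitEnvBlock(block), key+"="+value))
}

func main() {
	block := buildEnvBlock([]string{"A=1", "B=2"})
	fmt.Println(splitEnvBlock(appendEnvEntry(block, "C", "3")))
}
```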
func spawnAgentInSession(sessionID uint32, socketPath, authToken string, jobHandle windows.Handle) (windows.Handle, error) {
token, err := getSystemTokenForSession(sessionID)
if err != nil {
return 0, fmt.Errorf("get SYSTEM token for session %d: %w", sessionID, err)
}
defer token.Close()
var envBlock uintptr
r, _, e := procCreateEnvironmentBlock.Call(
uintptr(unsafe.Pointer(&envBlock)),
uintptr(token),
0,
)
if r == 0 {
// Without an environment block we cannot inject NB_VNC_AGENT_TOKEN;
// the agent would start unauthenticated. Abort instead of launching.
return 0, fmt.Errorf("CreateEnvironmentBlock: %w", e)
}
defer func() { _, _, _ = procDestroyEnvironmentBlock.Call(envBlock) }()
// Inject the auth token into the environment block so it doesn't appear
// in the process command line (visible via tasklist/wmic). injectedBlock
// must stay alive until CreateProcessAsUser returns.
injectedBlock := injectEnvVar(envBlock, agentTokenEnvVar, authToken)
exePath, err := os.Executable()
if err != nil {
return 0, fmt.Errorf("get executable path: %w", err)
}
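// Quote the path with plain double quotes: %q would apply Go string
// escaping and double every backslash in the Windows path.
cmdLine := fmt.Sprintf(`"%s" %s --socket "%s"`, exePath, vncAgentSubcommand, socketPath)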
cmdLineW, err := windows.UTF16PtrFromString(cmdLine)
if err != nil {
return 0, fmt.Errorf("UTF16 cmdline: %w", err)
}
// Create an inheritable pipe for the agent's stderr so we can relog
// its output in the service process.
var sa windows.SecurityAttributes
sa.Length = uint32(unsafe.Sizeof(sa))
sa.InheritHandle = 1
var stderrRead, stderrWrite windows.Handle
if err := windows.CreatePipe(&stderrRead, &stderrWrite, &sa, 0); err != nil {
return 0, fmt.Errorf("create stderr pipe: %w", err)
}
// The read end must NOT be inherited by the child.
_ = windows.SetHandleInformation(stderrRead, windows.HANDLE_FLAG_INHERIT, 0)
desktop, _ := windows.UTF16PtrFromString(`WinSta0\Default`)
si := windows.StartupInfo{
Cb: uint32(unsafe.Sizeof(windows.StartupInfo{})),
Desktop: desktop,
Flags: windows.STARTF_USESHOWWINDOW | windows.STARTF_USESTDHANDLES,
ShowWindow: 0,
StdErr: stderrWrite,
StdOutput: stderrWrite,
}
var pi windows.ProcessInformation
var envPtr *uint16
if len(injectedBlock) > 0 {
envPtr = &injectedBlock[0]
} else if envBlock != 0 {
envPtr = (*uint16)(unsafe.Pointer(envBlock))
}
// CREATE_SUSPENDED so we can assign the process to our Job Object
// before it executes. Without this the agent could spawn its own child
// processes and have them inherit the SCM service-job (not ours), or
// briefly listen on the agent socket before we tear it down on rollback.
// CREATE_BREAKAWAY_FROM_JOB lets the child leave the SCM-managed
// service job; harmless if that job allows breakaway, and is required
// before AssignProcessToJobObject can succeed in the no-nested-jobs case.
err = windows.CreateProcessAsUser(
token, nil, cmdLineW,
nil, nil, true, // inheritHandles=true for the pipe
createUnicodeEnvironment|createNoWindow|createSuspended|createBreakawayFromJob,
envPtr, nil, &si, &pi,
)
runtime.KeepAlive(injectedBlock)
// Close the write end in the parent so reads will get EOF when the child exits.
_ = windows.CloseHandle(stderrWrite)
if err != nil {
_ = windows.CloseHandle(stderrRead)
return 0, fmt.Errorf("CreateProcessAsUser: %w", err)
}
if jobHandle != 0 {
r, _, e := procAssignProcessToJobObject.Call(uintptr(jobHandle), uintptr(pi.Process))
if r == 0 {
log.Warnf("assign agent to job object: %v (orphan possible on service crash)", e)
}
}
if _, err := windows.ResumeThread(pi.Thread); err != nil {
_ = windows.CloseHandle(pi.Thread)
_ = windows.TerminateProcess(pi.Process, 1)
_ = windows.CloseHandle(pi.Process)
_ = windows.CloseHandle(stderrRead)
return 0, fmt.Errorf("ResumeThread: %w", err)
}
_ = windows.CloseHandle(pi.Thread)
// Relog agent output through the service logger, tagged component=vnc-agent.
go relogAgentOutput(stderrRead)
log.Infof("spawned agent PID=%d in session %d on %s", pi.ProcessId, sessionID, socketPath)
return pi.Process, nil
}
// sessionManager monitors the active console session and ensures a VNC agent
// process is running in it. When the session changes (e.g., user switch, RDP
// connect/disconnect), it kills the old agent and spawns a new one. Each
// spawn picks a fresh Unix-socket path that the agent binds and the
// daemon dials over local IPC.
type sessionManager struct {
mu sync.Mutex
agentProc windows.Handle
everSpawned bool
agentStartedAt time.Time
spawnFailures int
nextSpawnAt time.Time
sessionID uint32
authToken string
socketPath string
done chan struct{}
// jobHandle owns the agent processes via a Windows Job Object with
// JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE. When the service exits or crashes,
// the OS closes the handle and terminates every assigned agent: no
// orphaned agent processes holding a socket across restarts.
jobHandle windows.Handle
}
const (
// agentSocketDir is a dedicated subdirectory under C:\Windows\Temp that
// the daemon creates with a restrictive DACL (SYSTEM + Administrators
// only). The default ACL on C:\Windows\Temp grants BUILTIN\Users
// create-file rights, so the agent socket must not live directly there:
// an unprivileged local user could pre-create a predictable path and
// intercept the daemon→agent stream. Both the daemon and the agent run
// as SYSTEM, so a SYSTEM-write-only directory is sufficient.
agentSocketDir = `C:\Windows\Temp\netbird-vnc`
// agentSocketDirSDDL grants full access (FA) to Local System (SY) and the
// Builtin Administrators group (BA) only. The PAI prefix sets two DACL
// control flags: P protects the DACL from inheritance so the parent's
// BUILTIN\Users grant does not flow in, and AI marks the DACL as
// auto-inherited.
agentSocketDirSDDL = "D:PAI(A;;FA;;;SY)(A;;FA;;;BA)"
// agentSocketRandomLen is the number of random bytes mixed into each
// per-spawn socket name so the path is unguessable before the agent
// owns it.
agentSocketRandomLen = 16
// agentReadyTimeout bounds how long the daemon waits for the freshly
// spawned agent to bind and accept on its socket before treating the
// spawn as failed.
agentReadyTimeout = 5 * time.Second
)
func newSessionManager() *sessionManager {
m := &sessionManager{sessionID: ^uint32(0), done: make(chan struct{})}
if h, err := createKillOnCloseJob(); err != nil {
log.Warnf("create job object for vnc-agent (orphan agents possible after crash): %v", err)
} else {
m.jobHandle = h
}
return m
}
// createKillOnCloseJob returns a Job Object configured so that closing its
// handle (process exit or explicit Close) terminates every process assigned
// to it. Used to keep orphaned vnc-agent processes from outliving the service.
func createKillOnCloseJob() (windows.Handle, error) {
r, _, e := procCreateJobObjectW.Call(0, 0)
if r == 0 {
return 0, fmt.Errorf("CreateJobObject: %w", e)
}
job := windows.Handle(r)
// JOBOBJECT_EXTENDED_LIMIT_INFORMATION on amd64 = 144 bytes.
//
// JOBOBJECT_BASIC_LIMIT_INFORMATION (64 bytes with alignment padding)
// PerProcessUserTimeLimit LARGE_INTEGER off 0
// PerJobUserTimeLimit LARGE_INTEGER off 8
// LimitFlags DWORD off 16
// [4 byte pad to align SIZE_T]
// MinimumWorkingSetSize SIZE_T off 24
// MaximumWorkingSetSize SIZE_T off 32
// ActiveProcessLimit DWORD off 40
// [4 byte pad to align ULONG_PTR]
// Affinity ULONG_PTR off 48
// PriorityClass DWORD off 56
// SchedulingClass DWORD off 60
// IO_COUNTERS (48) + 4 * SIZE_T (32) follow: 64 + 48 + 32 = 144 total.
//
// We only set LimitFlags; the rest stays zero.
const sizeofExtended = 144
const offsetLimitFlags = 16
const jobObjectExtendedLimitInformation = 9
const jobObjectLimitKillOnJobClose = 0x00002000
var info [sizeofExtended]byte
binary.LittleEndian.PutUint32(info[offsetLimitFlags:offsetLimitFlags+4], jobObjectLimitKillOnJobClose)
r, _, e = procSetInformationJobObject.Call(
uintptr(job),
uintptr(jobObjectExtendedLimitInformation),
uintptr(unsafe.Pointer(&info[0])),
uintptr(sizeofExtended),
)
if r == 0 {
_ = windows.CloseHandle(job)
return 0, fmt.Errorf("SetInformationJobObject(KILL_ON_JOB_CLOSE): %w", e)
}
return job, nil
}
// Resolve returns the current agent socket path, shared token, and the
// uid the agent runs under (0 on Windows since the agent runs as
// SYSTEM in the interactive session; see validateAgentPeer for the
// Windows trust model). The path is only published after the spawned
// agent is confirmed listening, so a caller never receives a socket a
// squatter could be holding. When no agent is spawned yet (initial
// boot, between session switches, or permanently disabled when
// SE_TCB_NAME is missing) it surfaces a distinct error so the daemon
// can reject the connection with a meaningful message instead of timing
// out the proxy dial.
func (m *sessionManager) Resolve(_ context.Context) (string, string, uint32, error) {
m.mu.Lock()
defer m.mu.Unlock()
if m.socketPath == "" {
return "", "", 0, errAgentNotReady
}
return m.socketPath, m.authToken, 0, nil
}
var errAgentNotReady = errors.New("VNC agent not running yet")
// Stop signals the session manager to exit its polling loop and closes the
// Job Object handle, which Windows uses as the trigger to terminate every
// agent process this manager spawned.
func (m *sessionManager) Stop() {
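// Hold the mutex around the closed-check so concurrent Stop calls cannot
// both reach close(m.done) and panic on a double close.
m.mu.Lock()
defer m.mu.Unlock()
select {
case <-m.done:
default:
close(m.done)
}
if m.jobHandle != 0 {
_ = windows.CloseHandle(m.jobHandle)
m.jobHandle = 0
}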
}
func (m *sessionManager) run() {
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
for {
if !m.tick() {
return
}
select {
case <-m.done:
m.mu.Lock()
m.killAgent()
m.mu.Unlock()
return
case <-ticker.C:
}
}
}
// tick performs one session/agent-state update. Returns false if the manager
// should permanently stop (e.g. missing SYSTEM privileges).
func (m *sessionManager) tick() bool {
sid := getActiveSessionID()
m.mu.Lock()
defer m.mu.Unlock()
m.handleSessionChange(sid)
m.reapExitedAgent()
return m.maybeSpawnAgent(sid)
}
func (m *sessionManager) handleSessionChange(sid uint32) {
if sid == m.sessionID {
return
}
log.Infof("active session changed: %d -> %d", m.sessionID, sid)
m.killAgent()
m.sessionID = sid
}
func (m *sessionManager) reapExitedAgent() {
if m.agentProc == 0 {
return
}
var code uint32
if err := windows.GetExitCodeProcess(m.agentProc, &code); err != nil {
log.Debugf("GetExitCodeProcess: %v", err)
return
}
if code == stillActive {
return
}
m.scheduleNextSpawn(code, time.Since(m.agentStartedAt))
if err := windows.CloseHandle(m.agentProc); err != nil {
log.Debugf("close agent handle: %v", err)
}
m.agentProc = 0
m.authToken = ""
m.socketPath = ""
}
// scheduleNextSpawn applies an exponential backoff on fast crashes (<5s),
// doubling from 2 s up to a 30 s cap, and clears the backoff otherwise so
// the next tick respawns immediately.
func (m *sessionManager) scheduleNextSpawn(exitCode uint32, lifetime time.Duration) {
if lifetime < 5*time.Second {
m.spawnFailures++
backoff := time.Duration(1<<min(m.spawnFailures, 5)) * time.Second
if backoff > 30*time.Second {
backoff = 30 * time.Second
}
m.nextSpawnAt = time.Now().Add(backoff)
log.Warnf("agent exited (code=%d) after %v, retrying in %v (failures=%d)", exitCode, lifetime.Round(time.Millisecond), backoff, m.spawnFailures)
return
}
m.spawnFailures = 0
m.nextSpawnAt = time.Time{}
log.Infof("agent exited (code=%d) after %v, respawning", exitCode, lifetime.Round(time.Second))
}
// maybeSpawnAgent spawns a new agent if there's no current one and the backoff
// window has elapsed. Returns false to permanently stop the manager when the
// service lacks the privileges needed to spawn cross-session.
func (m *sessionManager) maybeSpawnAgent(sid uint32) bool {
if m.agentProc != 0 || sid == 0xFFFFFFFF || !time.Now().After(m.nextSpawnAt) {
return true
}
if err := ensureAgentSocketDir(); err != nil {
log.Warnf("prepare agent socket dir: %v", err)
m.nextSpawnAt = time.Now().Add(5 * time.Second)
return true
}
// The leaf name carries a cryptographically random component so a local
// user cannot pre-create the path at a guessable location. The session
// id is kept for diagnostics only; security does not rely on it.
socketPath, err := newAgentSocketPath(sid)
if err != nil {
log.Warnf("generate agent socket path: %v", err)
return true
}
// Covers a previous-run crash that escaped Job Object kill-on-close.
if err := os.Remove(socketPath); err != nil && !os.IsNotExist(err) {
log.Debugf("clear stale agent socket %s: %v", socketPath, err)
}
token, err := generateAuthToken()
if err != nil {
log.Warnf("generate agent auth token: %v", err)
return true
}
h, err := spawnAgentInSession(sid, socketPath, token, m.jobHandle)
if err != nil {
if errors.Is(err, windows.ERROR_PRIVILEGE_NOT_HELD) {
// SE_TCB_NAME (token-impersonation across sessions) is only
// granted to SYSTEM. Without it spawnAgentInSession will fail on
// every 2-second tick forever: log once and give up.
log.Warnf("VNC service mode disabled: agent spawn requires SYSTEM privileges (got: %v)", err)
return false
}
log.Warnf("spawn agent in session %d: %v", sid, err)
return true
}
// Gate on listen-readiness before publishing the path: do not hand a
// caller a socket the agent has not bound yet. On timeout, fail closed
// by killing the agent and leaving socketPath/authToken unset so
// Resolve keeps returning errAgentNotReady.
if err := waitForAgentListening(socketPath, agentReadyTimeout); err != nil {
log.Warnf("agent in session %d did not start listening: %v", sid, err)
_ = windows.TerminateProcess(h, 1)
_ = windows.CloseHandle(h)
if rmErr := os.Remove(socketPath); rmErr != nil && !os.IsNotExist(rmErr) {
log.Debugf("clear unready agent socket %s: %v", socketPath, rmErr)
}
m.scheduleNextSpawn(0, 0)
return true
}
m.authToken = token
m.socketPath = socketPath
m.agentProc = h
m.agentStartedAt = time.Now()
m.everSpawned = true
return true
}
// ensureAgentSocketDir creates the dedicated socket directory with a
// restrictive DACL (SYSTEM + Administrators only). A pre-existing directory
// is torn down and recreated rather than reused: it may have been created by
// an unprivileged user with a permissive ACL, and it only ever holds our
// transient sockets, so removing it loses nothing. Fails closed: returns an
// error if the directory cannot be created with the intended security.
func ensureAgentSocketDir() error {
sd, err := windows.SecurityDescriptorFromString(agentSocketDirSDDL)
if err != nil {
return fmt.Errorf("parse socket dir SDDL: %w", err)
}
var sa windows.SecurityAttributes
sa.Length = uint32(unsafe.Sizeof(sa))
sa.SecurityDescriptor = sd
dirW, err := windows.UTF16PtrFromString(agentSocketDir)
if err != nil {
return fmt.Errorf("encode socket dir path: %w", err)
}
err = windows.CreateDirectory(dirW, &sa)
if errors.Is(err, windows.ERROR_ALREADY_EXISTS) {
if rmErr := os.RemoveAll(agentSocketDir); rmErr != nil {
return fmt.Errorf("remove pre-existing socket dir %s: %w", agentSocketDir, rmErr)
}
err = windows.CreateDirectory(dirW, &sa)
}
if err != nil {
return fmt.Errorf("create socket dir %s: %w", agentSocketDir, err)
}
return nil
}
// newAgentSocketPath returns a per-spawn socket path inside the secured
// socket directory. The leaf name mixes a cryptographically random component
// with the session id (for diagnostics) so the path is unguessable before the
// agent binds it.
func newAgentSocketPath(sessionID uint32) (string, error) {
b := make([]byte, agentSocketRandomLen)
if _, err := crand.Read(b); err != nil {
return "", fmt.Errorf("read random: %w", err)
}
name := fmt.Sprintf("netbird-vnc-%d-%s.sock", sessionID, hex.EncodeToString(b))
return filepath.Join(agentSocketDir, name), nil
}
// waitForAgentListening dials the agent's Unix socket until it answers or the
// timeout elapses. Mirrors the darwin readiness gate so the daemon never
// exposes a socket path before the legitimate agent owns it.
func waitForAgentListening(socketPath string, wait time.Duration) error {
var d net.Dialer
deadline := time.Now().Add(wait)
var lastErr error
for time.Now().Before(deadline) {
c, err := d.Dial("unix", socketPath)
if err == nil {
_ = c.Close()
return nil
}
lastErr = err
time.Sleep(100 * time.Millisecond)
}
if lastErr == nil {
lastErr = fmt.Errorf("timeout")
}
return fmt.Errorf("dial %s: %w", socketPath, lastErr)
}
func (m *sessionManager) killAgent() {
if m.agentProc == 0 {
return
}
_ = windows.TerminateProcess(m.agentProc, 0)
_ = windows.CloseHandle(m.agentProc)
m.agentProc = 0
m.authToken = ""
m.socketPath = ""
log.Info("killed old agent")
}
// relogAgentOutput reads log lines from the agent's stderr pipe and
// relogs them with the service's formatter. The *os.File owns the
// underlying handle, so closing it suffices.
func relogAgentOutput(pipe windows.Handle) {
f := os.NewFile(uintptr(pipe), "vnc-agent-stderr")
defer func() { _ = f.Close() }()
relogAgentStream(f)
}
// logCleanupCall invokes a Windows syscall used solely as a cleanup primitive
// (CloseClipboard, ReleaseDC, etc.) and logs failures at trace level. The
// indirection lets us satisfy errcheck without scattering ignored returns at
// each call site, while still capturing diagnostic info when the OS reports
// a failure.
func logCleanupCall(name string, proc *windows.LazyProc) {
r, _, err := proc.Call()
if r == 0 && err != nil && err != windows.NTE_OP_OK {
log.Tracef("%s: %v", name, err)
}
}
// logCleanupCallArgs is logCleanupCall for procs that take arguments; the
// common pattern for release-by-handle syscalls.
func logCleanupCallArgs(name string, proc *windows.LazyProc, args ...uintptr) {
r, _, err := proc.Call(args...)
if r == 0 && err != nil && err != windows.NTE_OP_OK {
log.Tracef("%s: %v", name, err)
}
}

@@ -1,631 +0,0 @@
//go:build darwin && !ios
package server
import (
"errors"
"fmt"
"hash/maphash"
"image"
"os"
"runtime"
"strconv"
"sync"
"sync/atomic"
"time"
"unsafe"
"github.com/ebitengine/purego"
log "github.com/sirupsen/logrus"
)
var darwinCaptureOnce sync.Once
var (
cgMainDisplayID func() uint32
cgDisplayPixelsWide func(uint32) uintptr
cgDisplayPixelsHigh func(uint32) uintptr
cgDisplayCreateImage func(uint32) uintptr
cgImageGetWidth func(uintptr) uintptr
cgImageGetHeight func(uintptr) uintptr
cgImageGetBytesPerRow func(uintptr) uintptr
cgImageGetBitsPerPixel func(uintptr) uintptr
cgImageGetDataProvider func(uintptr) uintptr
cgDataProviderCopyData func(uintptr) uintptr
cgImageRelease func(uintptr)
cfDataGetLength func(uintptr) int64
cfDataGetBytePtr func(uintptr) uintptr
cfRelease func(uintptr)
cgRequestScreenCaptureAccess func() bool
cgEventCreate func(uintptr) uintptr
cgEventGetLocation func(uintptr) cgPoint
darwinCaptureReady bool
)
// cgPoint mirrors CoreGraphics CGPoint: two doubles, 16 bytes, returned
// in registers on Darwin amd64/arm64. Used to receive cursor coordinates
// from CGEventGetLocation via purego.
type cgPoint struct {
X, Y float64
}
func initDarwinCapture() {
darwinCaptureOnce.Do(func() {
cg, err := purego.Dlopen("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics", purego.RTLD_NOW|purego.RTLD_GLOBAL)
if err != nil {
log.Debugf("load CoreGraphics: %v", err)
return
}
cf, err := purego.Dlopen("/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation", purego.RTLD_NOW|purego.RTLD_GLOBAL)
if err != nil {
log.Debugf("load CoreFoundation: %v", err)
return
}
purego.RegisterLibFunc(&cgMainDisplayID, cg, "CGMainDisplayID")
purego.RegisterLibFunc(&cgDisplayPixelsWide, cg, "CGDisplayPixelsWide")
purego.RegisterLibFunc(&cgDisplayPixelsHigh, cg, "CGDisplayPixelsHigh")
purego.RegisterLibFunc(&cgDisplayCreateImage, cg, "CGDisplayCreateImage")
purego.RegisterLibFunc(&cgImageGetWidth, cg, "CGImageGetWidth")
purego.RegisterLibFunc(&cgImageGetHeight, cg, "CGImageGetHeight")
purego.RegisterLibFunc(&cgImageGetBytesPerRow, cg, "CGImageGetBytesPerRow")
purego.RegisterLibFunc(&cgImageGetBitsPerPixel, cg, "CGImageGetBitsPerPixel")
purego.RegisterLibFunc(&cgImageGetDataProvider, cg, "CGImageGetDataProvider")
purego.RegisterLibFunc(&cgDataProviderCopyData, cg, "CGDataProviderCopyData")
purego.RegisterLibFunc(&cgImageRelease, cg, "CGImageRelease")
purego.RegisterLibFunc(&cfDataGetLength, cf, "CFDataGetLength")
purego.RegisterLibFunc(&cfDataGetBytePtr, cf, "CFDataGetBytePtr")
purego.RegisterLibFunc(&cfRelease, cf, "CFRelease")
// CGRequestScreenCaptureAccess (macOS 11+) prompts on first call and
// is a cheap no-op once granted. The Preflight companion is unreliable
// on Sequoia (returns false even when access is granted), so we drive
// the permission flow from actual capture failures instead.
if sym, err := purego.Dlsym(cg, "CGRequestScreenCaptureAccess"); err == nil {
purego.RegisterFunc(&cgRequestScreenCaptureAccess, sym)
}
// CGEventCreate / CGEventGetLocation feed the cursor position used
// by remote-cursor compositing. Optional; absence reports as a
// position-source error and disables that feature on this host.
if sym, err := purego.Dlsym(cg, "CGEventCreate"); err == nil {
purego.RegisterFunc(&cgEventCreate, sym)
}
if sym, err := purego.Dlsym(cg, "CGEventGetLocation"); err == nil {
purego.RegisterFunc(&cgEventGetLocation, sym)
}
darwinCaptureReady = true
})
}
// CGCapturer captures the macOS main display using Core Graphics.
type CGCapturer struct {
displayID uint32
w, h int
// downscale is 1 for pixel-perfect, 2 for Retina 2:1 box-filter downscale.
downscale int
hashSeed maphash.Seed
lastHash uint64
hasHash bool
// cursor lazily binds the private CGSCreateCurrentCursorImage symbol
// so we can emit the Cursor pseudo-encoding without a per-frame cost
// on builds that never query it.
cursorOnce sync.Once
cursor *cgCursor
}
// PrimeScreenCapturePermission triggers the macOS Screen Recording
// permission prompt without creating a full capturer. The platform wiring
// calls this at VNC-server enable time so the user sees the prompt the
// moment they turn the feature on. CGRequestScreenCaptureAccess is a
// no-op when the grant already exists, so calling it on every enable is
// cheap and safe.
func PrimeScreenCapturePermission() {
initDarwinCapture()
if !darwinCaptureReady {
return
}
if cgRequestScreenCaptureAccess != nil {
cgRequestScreenCaptureAccess()
}
}
// notifyScreenRecordingMissing nudges the user once per agent process to
// approve Screen Recording. The capturer init retries on backoff when the
// grant is missing; without the sync.Once we would reopen System Settings
// every tick and flood the daemon log with the same warning.
var screenRecordingNotifyOnce sync.Once
func notifyScreenRecordingMissing() {
screenRecordingNotifyOnce.Do(func() {
if cgRequestScreenCaptureAccess != nil {
cgRequestScreenCaptureAccess()
}
openPrivacyPane("Privacy_ScreenCapture")
log.Warn("Screen Recording permission not granted. " +
"Opened System Settings > Privacy & Security > Screen Recording; enable netbird and restart.")
})
}
// NewCGCapturer creates a screen capturer for the main display.
func NewCGCapturer() (*CGCapturer, error) {
initDarwinCapture()
if !darwinCaptureReady {
return nil, fmt.Errorf("CoreGraphics not available")
}
displayID := cgMainDisplayID()
c := &CGCapturer{displayID: displayID, downscale: 1, hashSeed: maphash.MakeSeed()}
img, err := c.Capture()
if err != nil {
notifyScreenRecordingMissing()
return nil, fmt.Errorf("probe capture: %w", err)
}
nativeW := img.Rect.Dx()
nativeH := img.Rect.Dy()
c.hasHash = false
if nativeW == 0 || nativeH == 0 {
return nil, errors.New("display dimensions are zero")
}
logicalW := int(cgDisplayPixelsWide(displayID))
logicalH := int(cgDisplayPixelsHigh(displayID))
// Enable 2:1 downscale on Retina unless explicitly disabled. Cuts pixel
// count 4x, shrinking convert, diff, and wire data proportionally.
if !retinaDownscaleDisabled() && nativeW >= 2*logicalW && nativeH >= 2*logicalH && nativeW%2 == 0 && nativeH%2 == 0 {
c.downscale = 2
}
c.w = nativeW / c.downscale
c.h = nativeH / c.downscale
log.Infof("macOS capturer ready: %dx%d (native %dx%d, logical %dx%d, downscale=%d, display=%d)",
c.w, c.h, nativeW, nativeH, logicalW, logicalH, c.downscale, displayID)
return c, nil
}
func retinaDownscaleDisabled() bool {
v := os.Getenv(EnvVNCDisableDownscale)
if v == "" {
return false
}
disabled, err := strconv.ParseBool(v)
if err != nil {
log.Warnf("parse %s: %v", EnvVNCDisableDownscale, err)
return false
}
return disabled
}
// Width returns the screen width.
func (c *CGCapturer) Width() int { return c.w }
// Height returns the screen height.
func (c *CGCapturer) Height() int { return c.h }
// CaptureInto writes a fresh frame directly into dst, skipping the
// per-frame image.RGBA allocation that Capture() does. It always fills
// dst: the capturer is shared across all sessions, so dedup here would
// starve every consumer but the first one to poll after a change.
// Per-session prevFrame diffing in the session layer handles no-op frames.
func (c *CGCapturer) CaptureInto(dst *image.RGBA) error {
cgImage := cgDisplayCreateImage(c.displayID)
if cgImage == 0 {
return fmt.Errorf("CGDisplayCreateImage returned nil (screen recording permission?)")
}
defer cgImageRelease(cgImage)
w := int(cgImageGetWidth(cgImage))
h := int(cgImageGetHeight(cgImage))
bytesPerRow := int(cgImageGetBytesPerRow(cgImage))
bpp := int(cgImageGetBitsPerPixel(cgImage))
provider := cgImageGetDataProvider(cgImage)
if provider == 0 {
return fmt.Errorf("CGImageGetDataProvider returned nil")
}
cfData := cgDataProviderCopyData(provider)
if cfData == 0 {
return fmt.Errorf("CGDataProviderCopyData returned nil")
}
defer cfRelease(cfData)
dataLen := int(cfDataGetLength(cfData))
dataPtr := cfDataGetBytePtr(cfData)
if dataPtr == 0 || dataLen == 0 {
return fmt.Errorf("empty image data")
}
src := unsafe.Slice((*byte)(unsafe.Pointer(dataPtr)), dataLen)
ds := c.downscale
if ds < 1 {
ds = 1
}
outW := w / ds
outH := h / ds
if dst.Rect.Dx() != outW || dst.Rect.Dy() != outH {
return fmt.Errorf("dst size mismatch: dst=%dx%d capturer=%dx%d",
dst.Rect.Dx(), dst.Rect.Dy(), outW, outH)
}
bytesPerPixel := bpp / 8
if bytesPerPixel == 4 && ds == 1 {
convertBGRAToRGBA(dst.Pix, dst.Stride, src, bytesPerRow, w, h)
return nil
}
if bytesPerPixel == 4 && ds == 2 {
convertBGRAToRGBADownscale2(dst.Pix, dst.Stride, src, bytesPerRow, outW, outH)
return nil
}
// Slow path for other pixel sizes or scale factors; shares the per-pixel
// fallback with Capture instead of duplicating the loop here.
convertBGRAToRGBAGeneric(dst.Pix, dst.Stride, src, bytesPerRow, bgraDownscaleParams{outW: outW, outH: outH, bytesPerPixel: bytesPerPixel, ds: ds})
return nil
}
func (c *CGCapturer) Capture() (*image.RGBA, error) {
cgImage := cgDisplayCreateImage(c.displayID)
if cgImage == 0 {
return nil, fmt.Errorf("CGDisplayCreateImage returned nil (screen recording permission?)")
}
defer cgImageRelease(cgImage)
w := int(cgImageGetWidth(cgImage))
h := int(cgImageGetHeight(cgImage))
bytesPerRow := int(cgImageGetBytesPerRow(cgImage))
bpp := int(cgImageGetBitsPerPixel(cgImage))
provider := cgImageGetDataProvider(cgImage)
if provider == 0 {
return nil, fmt.Errorf("CGImageGetDataProvider returned nil")
}
cfData := cgDataProviderCopyData(provider)
if cfData == 0 {
return nil, fmt.Errorf("CGDataProviderCopyData returned nil")
}
defer cfRelease(cfData)
dataLen := int(cfDataGetLength(cfData))
dataPtr := cfDataGetBytePtr(cfData)
if dataPtr == 0 || dataLen == 0 {
return nil, fmt.Errorf("empty image data")
}
src := unsafe.Slice((*byte)(unsafe.Pointer(dataPtr)), dataLen)
hash := maphash.Bytes(c.hashSeed, src)
if c.hasHash && hash == c.lastHash {
return nil, errFrameUnchanged
}
c.lastHash = hash
c.hasHash = true
ds := c.downscale
if ds < 1 {
ds = 1
}
outW := w / ds
outH := h / ds
img := image.NewRGBA(image.Rect(0, 0, outW, outH))
bytesPerPixel := bpp / 8
switch {
case bytesPerPixel == 4 && ds == 1:
convertBGRAToRGBA(img.Pix, img.Stride, src, bytesPerRow, w, h)
case bytesPerPixel == 4 && ds == 2:
convertBGRAToRGBADownscale2(img.Pix, img.Stride, src, bytesPerRow, outW, outH)
default:
convertBGRAToRGBAGeneric(img.Pix, img.Stride, src, bytesPerRow, bgraDownscaleParams{outW: outW, outH: outH, bytesPerPixel: bytesPerPixel, ds: ds})
}
return img, nil
}
type bgraDownscaleParams struct {
outW, outH, bytesPerPixel, ds int
}
// convertBGRAToRGBAGeneric is the slow per-pixel fallback for sources that
// are not 4 bytes per pixel or whose downscale factor is neither 1 nor 2.
// It point-samples the top-left pixel of each ds x ds block and only
// assumes that B, G, R occupy the first three bytes of every pixel.
func convertBGRAToRGBAGeneric(dst []byte, dstStride int, src []byte, srcStride int, p bgraDownscaleParams) {
for row := 0; row < p.outH; row++ {
srcOff := row * p.ds * srcStride
dstOff := row * dstStride
for col := 0; col < p.outW; col++ {
si := srcOff + col*p.ds*p.bytesPerPixel
di := dstOff + col*4
dst[di+0] = src[si+2]
dst[di+1] = src[si+1]
dst[di+2] = src[si+0]
dst[di+3] = 0xff
}
}
}
// convertBGRAToRGBADownscale2 averages every 2x2 BGRA block into one RGBA
// output pixel, parallelised across GOMAXPROCS cores. outW and outH are the
// destination dimensions (source is 2*outW by 2*outH).
func convertBGRAToRGBADownscale2(dst []byte, dstStride int, src []byte, srcStride, outW, outH int) {
workers := runtime.GOMAXPROCS(0)
if workers > outH {
workers = outH
}
if workers < 1 || outH < 32 {
workers = 1
}
convertRows := func(y0, y1 int) {
for row := y0; row < y1; row++ {
srcRow0 := 2 * row * srcStride
srcRow1 := srcRow0 + srcStride
dstOff := row * dstStride
for col := 0; col < outW; col++ {
s0 := srcRow0 + col*8
s1 := srcRow1 + col*8
b := (uint32(src[s0]) + uint32(src[s0+4]) + uint32(src[s1]) + uint32(src[s1+4])) >> 2
g := (uint32(src[s0+1]) + uint32(src[s0+5]) + uint32(src[s1+1]) + uint32(src[s1+5])) >> 2
r := (uint32(src[s0+2]) + uint32(src[s0+6]) + uint32(src[s1+2]) + uint32(src[s1+6])) >> 2
di := dstOff + col*4
dst[di+0] = byte(r)
dst[di+1] = byte(g)
dst[di+2] = byte(b)
dst[di+3] = 0xff
}
}
}
if workers == 1 {
convertRows(0, outH)
return
}
var wg sync.WaitGroup
chunk := (outH + workers - 1) / workers
for i := 0; i < workers; i++ {
y0 := i * chunk
y1 := y0 + chunk
if y1 > outH {
y1 = outH
}
if y0 >= y1 {
break
}
wg.Add(1)
go func(y0, y1 int) {
defer wg.Done()
convertRows(y0, y1)
}(y0, y1)
}
wg.Wait()
}
// convertBGRAToRGBA swaps R/B channels using uint32 word operations, and
// parallelises across GOMAXPROCS cores for large images.
func convertBGRAToRGBA(dst []byte, dstStride int, src []byte, srcStride, w, h int) {
workers := runtime.GOMAXPROCS(0)
if workers > h {
workers = h
}
if workers < 1 || h < 64 {
workers = 1
}
convertRows := func(y0, y1 int) {
rowBytes := w * 4
for row := y0; row < y1; row++ {
dstRow := dst[row*dstStride : row*dstStride+rowBytes]
srcRow := src[row*srcStride : row*srcStride+rowBytes]
dstU := unsafe.Slice((*uint32)(unsafe.Pointer(&dstRow[0])), w)
srcU := unsafe.Slice((*uint32)(unsafe.Pointer(&srcRow[0])), w)
for i, p := range srcU {
dstU[i] = (p & 0xff00ff00) | ((p & 0x000000ff) << 16) | ((p & 0x00ff0000) >> 16) | 0xff000000
}
}
}
if workers == 1 {
convertRows(0, h)
return
}
var wg sync.WaitGroup
chunk := (h + workers - 1) / workers
for i := 0; i < workers; i++ {
y0 := i * chunk
y1 := y0 + chunk
if y1 > h {
y1 = h
}
if y0 >= y1 {
break
}
wg.Add(1)
go func(y0, y1 int) {
defer wg.Done()
convertRows(y0, y1)
}(y0, y1)
}
wg.Wait()
}
// MacPoller wraps CGCapturer with a staleness-cached on-demand Capture:
// sessions drive captures themselves from their encoder goroutine, so we
// don't need a background ticker. The last result is cached for a short
// window so concurrent sessions coalesce into one capture.
//
// The capturer is allocated lazily on first use and released when all
// clients disconnect. Init is retried with backoff because the user may
// grant Screen Recording permission while the server is already running.
type MacPoller struct {
mu sync.Mutex
capturer *CGCapturer
w, h int
lastFrame *image.RGBA
lastAt time.Time
clients atomic.Int32
initFails int
initBackoffUntil time.Time
closed bool
}
// macInitRetryBackoffFor returns the delay we wait between init attempts
// after consecutive failures. Screen Recording permission is a one-shot
// user grant, so after several failures we back off aggressively.
func macInitRetryBackoffFor(fails int) time.Duration {
switch {
case fails > 15:
return 30 * time.Second
case fails > 5:
return 10 * time.Second
default:
return 2 * time.Second
}
}
// NewMacPoller creates a lazy on-demand capturer for the macOS display.
func NewMacPoller() *MacPoller {
return &MacPoller{}
}
// Wake is a no-op retained for API compatibility. With on-demand capture
// there is no background retry loop to kick: init happens on the next
// Capture/ClientConnect call.
func (p *MacPoller) Wake() {
// intentional no-op
}
// ClientConnect increments the active client count and eagerly initialises
// the capturer so the first FBUpdateRequest doesn't pay the init cost.
func (p *MacPoller) ClientConnect() {
if p.clients.Add(1) == 1 {
p.mu.Lock()
_ = p.ensureCapturerLocked()
p.mu.Unlock()
}
}
// ClientDisconnect decrements the active client count. On the last
// disconnect the capturer is released.
func (p *MacPoller) ClientDisconnect() {
if p.clients.Add(-1) == 0 {
p.mu.Lock()
p.capturer = nil
p.lastFrame = nil
p.mu.Unlock()
}
}
// Close releases all resources.
func (p *MacPoller) Close() {
p.mu.Lock()
p.closed = true
p.capturer = nil
p.lastFrame = nil
p.mu.Unlock()
}
// Width returns the screen width. Triggers lazy init if needed.
func (p *MacPoller) Width() int {
p.mu.Lock()
defer p.mu.Unlock()
_ = p.ensureCapturerLocked()
return p.w
}
// Height returns the screen height. Triggers lazy init if needed.
func (p *MacPoller) Height() int {
p.mu.Lock()
defer p.mu.Unlock()
_ = p.ensureCapturerLocked()
return p.h
}
// CaptureInto fills dst directly via the underlying capturer, bypassing
// the freshness cache.
func (p *MacPoller) CaptureInto(dst *image.RGBA) error {
p.mu.Lock()
defer p.mu.Unlock()
if err := p.ensureCapturerLocked(); err != nil {
return err
}
if err := p.capturer.CaptureInto(dst); err != nil {
p.capturer = nil
return fmt.Errorf("macos capture: %w", err)
}
return nil
}
// Capture returns a fresh frame, serving from the short-lived cache if a
// previous caller captured within freshWindow. Handles the
// errFrameUnchanged return from CGCapturer by reusing the cached frame.
func (p *MacPoller) Capture() (*image.RGBA, error) {
p.mu.Lock()
defer p.mu.Unlock()
if p.lastFrame != nil && time.Since(p.lastAt) < freshWindow {
return p.lastFrame, nil
}
if err := p.ensureCapturerLocked(); err != nil {
return nil, err
}
img, err := p.capturer.Capture()
if errors.Is(err, errFrameUnchanged) {
if p.lastFrame != nil {
p.lastAt = time.Now()
return p.lastFrame, nil
}
return nil, err
}
if err != nil {
// Drop the capturer so the next call retries init; the display stream
// can die if the session changes or permissions are revoked.
p.capturer = nil
return nil, fmt.Errorf("macos capture: %w", err)
}
p.lastFrame = img
p.lastAt = time.Now()
return img, nil
}
// ensureCapturerLocked initialises the underlying CGCapturer if needed.
// Caller must hold p.mu.
func (p *MacPoller) ensureCapturerLocked() error {
if p.closed {
return fmt.Errorf("poller closed")
}
if p.capturer != nil {
return nil
}
if time.Now().Before(p.initBackoffUntil) {
return fmt.Errorf("macOS capturer unavailable (retry scheduled)")
}
c, err := NewCGCapturer()
if err != nil {
p.initFails++
p.initBackoffUntil = time.Now().Add(macInitRetryBackoffFor(p.initFails))
if p.initFails == 1 || p.initFails%10 == 0 {
log.Warnf("macOS capturer: %v (attempt %d)", err, p.initFails)
} else {
log.Debugf("macOS capturer: %v (attempt %d)", err, p.initFails)
}
return err
}
p.initFails = 0
p.capturer = c
p.w, p.h = c.Width(), c.Height()
return nil
}
var _ ScreenCapturer = (*MacPoller)(nil)
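
The word-wise channel swap in convertBGRAToRGBA can be checked in isolation. The following is a minimal sketch, not the code from this change: `swapBGRAWord` copies the mask-and-shift expression, while `swapBGRARow` is a hypothetical helper that decodes each pixel with encoding/binary instead of unsafe, assuming the same little-endian byte order the capturer relies on.

```go
package main

import (
	"encoding/binary"
	"fmt"
)

// swapBGRAWord converts one little-endian BGRA pixel word to RGBA and
// forces alpha to opaque, using the same expression as convertBGRAToRGBA.
func swapBGRAWord(p uint32) uint32 {
	return (p & 0xff00ff00) | ((p & 0x000000ff) << 16) | ((p & 0x00ff0000) >> 16) | 0xff000000
}

// swapBGRARow is a hypothetical helper: it converts a row of 4-byte
// pixels without unsafe by decoding each pixel as a little-endian uint32.
func swapBGRARow(dst, src []byte) {
	for i := 0; i+4 <= len(src) && i+4 <= len(dst); i += 4 {
		p := binary.LittleEndian.Uint32(src[i : i+4])
		binary.LittleEndian.PutUint32(dst[i:i+4], swapBGRAWord(p))
	}
}

func main() {
	// Two pixels in memory order B, G, R, A.
	src := []byte{0x10, 0x20, 0x30, 0x00, 0xaa, 0xbb, 0xcc, 0x7f}
	dst := make([]byte, len(src))
	swapBGRARow(dst, src)
	fmt.Printf("% x\n", dst) // prints "30 20 10 ff cc bb aa ff"
}
```

The green and alpha bytes keep their positions, so only the low byte and the third byte trade places; the final OR overwrites whatever alpha the source carried.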

@@ -1,99 +0,0 @@
//go:build windows
package server
import (
"errors"
"fmt"
"image"
"github.com/kirides/go-d3d/d3d11"
"github.com/kirides/go-d3d/outputduplication"
)
// dxgiCapturer captures the desktop using DXGI Desktop Duplication.
// Provides GPU-accelerated capture with native dirty rect tracking.
// Only works from the interactive user session, not Session 0.
//
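// DXGI writes each frame into img; capture then copies it into one of two
// preallocated output buffers and hands that buffer out. Alternating
// between the two avoids allocating a new image.RGBA per frame (about
// 8 MB at 1080p, so roughly 250 MB/s of garbage at 30 fps). A returned
// frame stays valid only until the second following capture reuses its buffer.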
type dxgiCapturer struct {
dup *outputduplication.OutputDuplicator
device *d3d11.ID3D11Device
ctx *d3d11.ID3D11DeviceContext
img *image.RGBA
out [2]*image.RGBA
outIdx int
width int
height int
}
func newDXGICapturer() (*dxgiCapturer, error) {
device, deviceCtx, err := d3d11.NewD3D11Device()
if err != nil {
return nil, fmt.Errorf("create D3D11 device: %w", err)
}
dup, err := outputduplication.NewIDXGIOutputDuplication(device, deviceCtx, 0)
if err != nil {
device.Release()
deviceCtx.Release()
return nil, fmt.Errorf("create output duplication: %w", err)
}
w, h := screenSize()
if w == 0 || h == 0 {
dup.Release()
device.Release()
deviceCtx.Release()
return nil, fmt.Errorf("screen dimensions are zero")
}
rect := image.Rect(0, 0, w, h)
c := &dxgiCapturer{
dup: dup,
device: device,
ctx: deviceCtx,
img: image.NewRGBA(rect),
out: [2]*image.RGBA{image.NewRGBA(rect), image.NewRGBA(rect)},
width: w,
height: h,
}
// Grab the initial frame with a longer timeout so the first capture
// usually has content. Best effort: the error is ignored, so c.img may
// still be blank if no frame arrives within 2s.
_ = dup.GetImage(c.img, 2000)
return c, nil
}
func (c *dxgiCapturer) capture() (*image.RGBA, error) {
err := c.dup.GetImage(c.img, 100)
if err != nil && !errors.Is(err, outputduplication.ErrNoImageYet) {
return nil, err
}
// Copy into the next output buffer. The DesktopCapturer hands out the
// returned pointer to VNC sessions that read pixels concurrently, so we
// alternate between two pre-allocated buffers instead of allocating per frame.
out := c.out[c.outIdx]
c.outIdx ^= 1
copy(out.Pix, c.img.Pix)
return out, nil
}
func (c *dxgiCapturer) close() {
if c.dup != nil {
c.dup.Release()
c.dup = nil
}
if c.ctx != nil {
c.ctx.Release()
c.ctx = nil
}
if c.device != nil {
c.device.Release()
c.device = nil
}
}

@@ -1,148 +0,0 @@
//go:build freebsd
package server
import (
"fmt"
"image"
"sync"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
// FreeBSD vt(4) framebuffer ioctl numbers from sys/fbio.h.
//
// #define FBIOGTYPE _IOR('F', 0, struct fbtype)
//
// _IOR(g, n, t) on FreeBSD (sys/ioccom.h): IOC_OUT (0x40000000)
// | (sizeof(t) & 0x1fff)<<16 | (g<<8) | n. With g='F' (0x46), n=0 and
// sizeof(struct fbtype)=24 (0x18) this gives 0x40184600.
const fbioGType = 0x40184600
func defaultFBPath() string { return "/dev/ttyv0" }
// fbType mirrors FreeBSD's struct fbtype.
type fbType struct {
FbType int32
FbHeight int32
FbWidth int32
FbDepth int32
FbCMSize int32
FbSize int32
}
// FBCapturer reads pixels from FreeBSD's vt(4) framebuffer device. The
// vt(4) console exposes the active framebuffer via ttyv0 with FBIOGTYPE
// for geometry and mmap for backing memory. Pixel layout is assumed to
// be 32bpp BGRA (the common case for KMS-backed vt); fbtype doesn't
// expose channel offsets, so we don't try to handle exotic layouts here.
type FBCapturer struct {
mu sync.Mutex
path string
fd int
mmap []byte
w, h int
bpp int
stride int
closeOnce sync.Once
}
// NewFBCapturer opens the given vt(4) device and queries its geometry.
func NewFBCapturer(path string) (*FBCapturer, error) {
if path == "" {
path = defaultFBPath()
}
fd, err := unix.Open(path, unix.O_RDWR, 0)
if err != nil {
return nil, fmt.Errorf("open %s: %w", path, err)
}
var fbt fbType
if _, _, e := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), fbioGType, uintptr(unsafe.Pointer(&fbt))); e != 0 {
unix.Close(fd)
return nil, fmt.Errorf("FBIOGTYPE: %v", e)
}
if fbt.FbDepth != 16 && fbt.FbDepth != 24 && fbt.FbDepth != 32 {
unix.Close(fd)
return nil, fmt.Errorf("unsupported framebuffer depth: %d", fbt.FbDepth)
}
if fbt.FbWidth <= 0 || fbt.FbHeight <= 0 || fbt.FbSize <= 0 {
unix.Close(fd)
return nil, fmt.Errorf("invalid framebuffer geometry: %dx%d size=%d", fbt.FbWidth, fbt.FbHeight, fbt.FbSize)
}
mm, err := unix.Mmap(fd, 0, int(fbt.FbSize), unix.PROT_READ, unix.MAP_SHARED)
if err != nil {
unix.Close(fd)
return nil, fmt.Errorf("mmap %s: %w (vt may not support mmap on this driver, e.g. virtio_gpu)", path, err)
}
bpp := int(fbt.FbDepth)
stride := int(fbt.FbWidth) * (bpp / 8)
c := &FBCapturer{
path: path,
fd: fd, // valid fd >= 0; we use -1 as the closed sentinel
mmap: mm,
w: int(fbt.FbWidth),
h: int(fbt.FbHeight),
bpp: bpp,
stride: stride,
}
log.Infof("framebuffer capturer ready: %s %dx%d bpp=%d (freebsd vt)", path, c.w, c.h, c.bpp)
return c, nil
}
// Width returns the framebuffer width.
func (c *FBCapturer) Width() int { return c.w }
// Height returns the framebuffer height.
func (c *FBCapturer) Height() int { return c.h }
// Capture allocates a fresh image and fills it with the current
// framebuffer contents.
func (c *FBCapturer) Capture() (*image.RGBA, error) {
img := image.NewRGBA(image.Rect(0, 0, c.w, c.h))
if err := c.CaptureInto(img); err != nil {
return nil, err
}
return img, nil
}
// CaptureInto reads the framebuffer directly into dst.Pix. Assumes BGRA
// for 32bpp; the FreeBSD fbtype struct doesn't expose channel offsets.
func (c *FBCapturer) CaptureInto(dst *image.RGBA) error {
c.mu.Lock()
defer c.mu.Unlock()
if dst.Rect.Dx() != c.w || dst.Rect.Dy() != c.h {
return fmt.Errorf("dst size mismatch: dst=%dx%d fb=%dx%d",
dst.Rect.Dx(), dst.Rect.Dy(), c.w, c.h)
}
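// Guard against use after Close (mmap is nil) and against a mapping
// smaller than the geometry reported by FBIOGTYPE.
if len(c.mmap) < c.h*c.stride {
return fmt.Errorf("framebuffer mapping unavailable: have %d bytes, need %d", len(c.mmap), c.h*c.stride)
}
switch c.bpp {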
case 32:
// vt(4) on KMS framebuffers is BGRA: byte 0=B, 1=G, 2=R.
swizzleBGRAtoRGBA(dst.Pix, c.mmap[:c.h*c.stride])
case 24:
swizzleFB24(dst.Pix, dst.Stride, c.mmap, c.stride, c.w, c.h)
case 16:
swizzleFB16RGB565(dst.Pix, dst.Stride, c.mmap, c.stride, c.w, c.h)
}
return nil
}
// Close releases the framebuffer mmap and file descriptor. Serialized with
// CaptureInto via c.mu so an in-flight capture can't read freed memory.
func (c *FBCapturer) Close() {
c.closeOnce.Do(func() {
c.mu.Lock()
defer c.mu.Unlock()
if c.mmap != nil {
_ = unix.Munmap(c.mmap)
c.mmap = nil
}
if c.fd >= 0 {
_ = unix.Close(c.fd)
c.fd = -1
}
})
}

@@ -1,229 +0,0 @@
//go:build linux && !android
package server
import (
"encoding/binary"
"fmt"
"image"
"sync"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
// Linux framebuffer ioctls (linux/fb.h).
const (
fbioGetVScreenInfo = 0x4600
fbioGetFScreenInfo = 0x4602
)
func defaultFBPath() string { return "/dev/fb0" }
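// fbVarScreenInfo mirrors the kernel's fb_var_screeninfo field for field
// (40 uint32 values, 160 bytes) so the ioctl cannot write past the struct.
// Only the resolution, depth, and channel offsets are read; _pad covers
// the kernel's reserved[4].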
type fbVarScreenInfo struct {
Xres, Yres uint32
XresVirtual, YresVirtual uint32
XOffset, YOffset uint32
BitsPerPixel uint32
Grayscale uint32
RedOffset, RedLen, RedMSBR uint32
GreenOffset, GreenLen, GreenMSBR uint32
BlueOffset, BlueLen, BlueMSBR uint32
TranspOffset, TranspLen, TranspM uint32
NonStd uint32
Activate uint32
Height, Width uint32
AccelFlags uint32
PixClock uint32
LeftMargin, RightMargin uint32
UpperMargin, LowerMargin uint32
HsyncLen, VsyncLen uint32
Sync uint32
Vmode uint32
Rotate uint32
Colorspace uint32
_pad [4]uint32
}
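// fbFixScreenInfo mirrors fb_fix_screeninfo. We only need LineLength.
// SmemStart and MmioStart are unsigned long in the kernel header, so this
// layout assumes a 64-bit ABI where unsigned long is 8 bytes.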
type fbFixScreenInfo struct {
IDStr [16]byte
SmemStart uint64
SmemLen uint32
Type uint32
TypeAux uint32
Visual uint32
XPanStep uint16
YPanStep uint16
YWrapStep uint16
_pad0 uint16
LineLength uint32
MmioStart uint64
MmioLen uint32
Accel uint32
Capabilities uint16
_reserved [2]uint16
}
// FBCapturer reads pixels straight from the Linux framebuffer device.
// Used as a fallback when X11 isn't available, e.g. on a headless box at
// the kernel console or the display manager's pre-login screen on machines
// without an Xorg server. The framebuffer must be mmap()-able under our
// process privileges (typically the netbird service runs as root).
type FBCapturer struct {
mu sync.Mutex
path string
fd int
mmap []byte
w, h int
bpp int
stride int
rOff uint32
gOff uint32
bOff uint32
rLen uint32
gLen uint32
bLen uint32
closeOnce sync.Once
}
// NewFBCapturer opens the given framebuffer device (/dev/fbN) and
// queries its current geometry + pixel format.
func NewFBCapturer(path string) (*FBCapturer, error) {
if path == "" {
path = defaultFBPath()
}
fd, err := unix.Open(path, unix.O_RDONLY, 0)
if err != nil {
return nil, fmt.Errorf("open %s: %w", path, err)
}
var vinfo fbVarScreenInfo
if _, _, e := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), fbioGetVScreenInfo, uintptr(unsafe.Pointer(&vinfo))); e != 0 {
unix.Close(fd)
return nil, fmt.Errorf("FBIOGET_VSCREENINFO: %v", e)
}
var finfo fbFixScreenInfo
if _, _, e := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), fbioGetFScreenInfo, uintptr(unsafe.Pointer(&finfo))); e != 0 {
unix.Close(fd)
return nil, fmt.Errorf("FBIOGET_FSCREENINFO: %v", e)
}
bpp := int(vinfo.BitsPerPixel)
if bpp != 16 && bpp != 24 && bpp != 32 {
unix.Close(fd)
return nil, fmt.Errorf("unsupported framebuffer bpp: %d", bpp)
}
size := int(finfo.LineLength) * int(vinfo.Yres)
if size <= 0 {
unix.Close(fd)
return nil, fmt.Errorf("invalid framebuffer dimensions: stride=%d h=%d", finfo.LineLength, vinfo.Yres)
}
mm, err := unix.Mmap(fd, 0, size, unix.PROT_READ, unix.MAP_SHARED)
if err != nil {
unix.Close(fd)
return nil, fmt.Errorf("mmap %s: %w", path, err)
}
c := &FBCapturer{
path: path,
fd: fd,
mmap: mm,
w: int(vinfo.Xres),
h: int(vinfo.Yres),
bpp: bpp,
stride: int(finfo.LineLength),
rOff: vinfo.RedOffset,
gOff: vinfo.GreenOffset,
bOff: vinfo.BlueOffset,
rLen: vinfo.RedLen,
gLen: vinfo.GreenLen,
bLen: vinfo.BlueLen,
}
log.Infof("framebuffer capturer ready: %s %dx%d bpp=%d r=%d/%d g=%d/%d b=%d/%d",
path, c.w, c.h, c.bpp, c.rOff, c.rLen, c.gOff, c.gLen, c.bOff, c.bLen)
return c, nil
}
// Width returns the framebuffer width in pixels.
func (c *FBCapturer) Width() int { return c.w }
// Height returns the framebuffer height in pixels.
func (c *FBCapturer) Height() int { return c.h }
// Capture allocates a fresh image and fills it with the current
// framebuffer contents.
func (c *FBCapturer) Capture() (*image.RGBA, error) {
img := image.NewRGBA(image.Rect(0, 0, c.w, c.h))
if err := c.CaptureInto(img); err != nil {
return nil, err
}
return img, nil
}
// CaptureInto reads the framebuffer directly into dst.Pix.
func (c *FBCapturer) CaptureInto(dst *image.RGBA) error {
c.mu.Lock()
defer c.mu.Unlock()
if dst.Rect.Dx() != c.w || dst.Rect.Dy() != c.h {
return fmt.Errorf("dst size mismatch: dst=%dx%d fb=%dx%d",
dst.Rect.Dx(), dst.Rect.Dy(), c.w, c.h)
}
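// Guard against use after Close, which sets mmap to nil.
if len(c.mmap) < c.h*c.stride {
return fmt.Errorf("framebuffer mapping unavailable: have %d bytes, need %d", len(c.mmap), c.h*c.stride)
}
switch c.bpp {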
case 32:
swizzleFB32(dst.Pix, dst.Stride, c.mmap, c.stride, c.w, c.h, channelShifts{R: c.rOff, G: c.gOff, B: c.bOff})
case 24:
swizzleFB24(dst.Pix, dst.Stride, c.mmap, c.stride, c.w, c.h)
case 16:
swizzleFB16RGB565(dst.Pix, dst.Stride, c.mmap, c.stride, c.w, c.h)
}
return nil
}
// Close releases the framebuffer mmap and file descriptor. Serialized with
// CaptureInto via c.mu so an in-flight capture can't read freed memory.
func (c *FBCapturer) Close() {
c.closeOnce.Do(func() {
c.mu.Lock()
defer c.mu.Unlock()
if c.mmap != nil {
_ = unix.Munmap(c.mmap)
c.mmap = nil
}
if c.fd >= 0 {
_ = unix.Close(c.fd)
c.fd = -1
}
})
}
// channelShifts groups the bit offsets for the R/G/B channels in a packed
// uint32 framebuffer pixel. Bundling avoids drowning per-row callers in a
// 9-parameter signature.
type channelShifts struct {
R, G, B uint32
}
// swizzleFB32 handles 32-bit framebuffers with arbitrary R/G/B channel
// offsets. It reads one little-endian uint32 per pixel, shifts each
// channel down by its bit offset, and truncates to a byte, which assumes
// 8-bit channels. Alpha is forced to opaque.
func swizzleFB32(dst []byte, dstStride int, src []byte, srcStride, w, h int, shifts channelShifts) {
for y := 0; y < h; y++ {
srcRow := src[y*srcStride : y*srcStride+w*4]
dstRow := dst[y*dstStride:]
for x := 0; x < w; x++ {
pix := binary.LittleEndian.Uint32(srcRow[x*4 : x*4+4])
dstRow[x*4+0] = byte(pix >> shifts.R)
dstRow[x*4+1] = byte(pix >> shifts.G)
dstRow[x*4+2] = byte(pix >> shifts.B)
dstRow[x*4+3] = 0xff
}
}
}
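
How the channel offsets reported by FBIOGET_VSCREENINFO select the output bytes is easiest to see on a single pixel. A minimal sketch: `channelShifts` repeats the struct from the file above, and `unpackPixel` is a hypothetical one-pixel version of swizzleFB32. The two layouts shown are illustrative values, not ones read from a device.

```go
package main

import (
	"encoding/binary"
	"fmt"
)

// channelShifts mirrors the struct in the file above: the bit offset of
// each channel inside the packed 32-bit pixel.
type channelShifts struct{ R, G, B uint32 }

// unpackPixel is a hypothetical one-pixel version of swizzleFB32. It
// returns the pixel in RGBA byte order with alpha forced to opaque.
func unpackPixel(px []byte, s channelShifts) [4]byte {
	pix := binary.LittleEndian.Uint32(px)
	return [4]byte{byte(pix >> s.R), byte(pix >> s.G), byte(pix >> s.B), 0xff}
}

func main() {
	raw := []byte{0x11, 0x22, 0x33, 0x00} // four bytes as they sit in the mmap

	xrgb := channelShifts{R: 16, G: 8, B: 0} // red in bits 16..23
	xbgr := channelShifts{R: 0, G: 8, B: 16} // red in bits 0..7

	fmt.Println(unpackPixel(raw, xrgb)) // prints "[51 34 17 255]"
	fmt.Println(unpackPixel(raw, xbgr)) // prints "[17 34 51 255]"
}
```

The same four bytes decode to different colors depending on the offsets, which is why the Linux path carries them instead of assuming BGRA.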

@@ -1,149 +0,0 @@
//go:build (linux && !android) || freebsd
package server
import (
"image"
"sync"
)
// FBPoller wraps FBCapturer with the same lifecycle (ClientConnect /
// ClientDisconnect, lazy init) as X11Poller, so it slots into the same
// session plumbing without code changes upstream. The concrete
// FBCapturer is platform-specific (capture_fb_linux.go / _freebsd.go);
// this file owns the cross-platform glue.
type FBPoller struct {
mu sync.Mutex
path string
capturer *FBCapturer
w, h int
clients int32
}
// NewFBPoller returns a poller that opens path on first use. Empty path
// defaults to /dev/fb0 on Linux and /dev/ttyv0 on FreeBSD.
func NewFBPoller(path string) *FBPoller {
if path == "" {
path = defaultFBPath()
}
return &FBPoller{path: path}
}
// ClientConnect eagerly initialises the capturer on first connect.
func (p *FBPoller) ClientConnect() {
p.mu.Lock()
defer p.mu.Unlock()
p.clients++
if p.clients == 1 {
_ = p.ensureCapturerLocked()
}
}
// ClientDisconnect closes the capturer when the last client leaves.
func (p *FBPoller) ClientDisconnect() {
p.mu.Lock()
defer p.mu.Unlock()
p.clients--
if p.clients <= 0 && p.capturer != nil {
p.capturer.Close()
p.capturer = nil
}
}
// Width returns the framebuffer width, doing lazy init if needed.
func (p *FBPoller) Width() int {
p.mu.Lock()
defer p.mu.Unlock()
_ = p.ensureCapturerLocked()
return p.w
}
// Height returns the framebuffer height, doing lazy init if needed.
func (p *FBPoller) Height() int {
p.mu.Lock()
defer p.mu.Unlock()
_ = p.ensureCapturerLocked()
return p.h
}
// Capture takes a fresh frame.
func (p *FBPoller) Capture() (*image.RGBA, error) {
p.mu.Lock()
defer p.mu.Unlock()
if err := p.ensureCapturerLocked(); err != nil {
return nil, err
}
return p.capturer.Capture()
}
// CaptureInto fills dst directly.
func (p *FBPoller) CaptureInto(dst *image.RGBA) error {
p.mu.Lock()
defer p.mu.Unlock()
if err := p.ensureCapturerLocked(); err != nil {
return err
}
return p.capturer.CaptureInto(dst)
}
// Close releases all framebuffer resources.
func (p *FBPoller) Close() {
p.mu.Lock()
defer p.mu.Unlock()
if p.capturer != nil {
p.capturer.Close()
p.capturer = nil
}
}
func (p *FBPoller) ensureCapturerLocked() error {
if p.capturer != nil {
return nil
}
c, err := NewFBCapturer(p.path)
if err != nil {
return err
}
p.capturer = c
p.w, p.h = c.Width(), c.Height()
return nil
}
var _ ScreenCapturer = (*FBPoller)(nil)
var _ captureIntoer = (*FBPoller)(nil)
// swizzleFB24 handles 24-bit packed framebuffers (B,G,R triplets).
// Shared between Linux and FreeBSD framebuffer paths.
func swizzleFB24(dst []byte, dstStride int, src []byte, srcStride, w, h int) {
for y := 0; y < h; y++ {
srcRow := src[y*srcStride : y*srcStride+w*3]
dstRow := dst[y*dstStride:]
for x := 0; x < w; x++ {
b := srcRow[x*3+0]
g := srcRow[x*3+1]
r := srcRow[x*3+2]
dstRow[x*4+0] = r
dstRow[x*4+1] = g
dstRow[x*4+2] = b
dstRow[x*4+3] = 0xff
}
}
}
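// swizzleFB16RGB565 handles 16bpp RGB 565 framebuffers stored little-endian.
// Each 5- or 6-bit channel is widened to 8 bits by replicating its top
// bits into the low bits, so full intensity maps to 0xff.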
func swizzleFB16RGB565(dst []byte, dstStride int, src []byte, srcStride, w, h int) {
for y := 0; y < h; y++ {
srcRow := src[y*srcStride : y*srcStride+w*2]
dstRow := dst[y*dstStride:]
for x := 0; x < w; x++ {
pix := uint16(srcRow[x*2]) | uint16(srcRow[x*2+1])<<8
r := byte((pix >> 11) & 0x1f)
g := byte((pix >> 5) & 0x3f)
b := byte(pix & 0x1f)
dstRow[x*4+0] = (r << 3) | (r >> 2)
dstRow[x*4+1] = (g << 2) | (g >> 4)
dstRow[x*4+2] = (b << 3) | (b >> 2)
dstRow[x*4+3] = 0xff
}
}
}
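
The bit replication in swizzleFB16RGB565 maps the 5- and 6-bit channel ranges onto the full 0..255 range. A minimal sketch on single pixels, where `expand565` is a hypothetical helper that applies the same shifts as the loop above:

```go
package main

import "fmt"

// expand565 is a hypothetical one-pixel version of swizzleFB16RGB565.
// It widens each channel to 8 bits by copying its top bits into the
// low bits that the left shift leaves empty.
func expand565(pix uint16) (r, g, b byte) {
	r5 := byte((pix >> 11) & 0x1f)
	g6 := byte((pix >> 5) & 0x3f)
	b5 := byte(pix & 0x1f)
	return (r5 << 3) | (r5 >> 2), (g6 << 2) | (g6 >> 4), (b5 << 3) | (b5 >> 2)
}

func main() {
	for _, pix := range []uint16{0x0000, 0xf800, 0x07e0, 0x001f, 0xffff} {
		r, g, b := expand565(pix)
		fmt.Printf("%#04x -> %d %d %d\n", pix, r, g, b)
	}
}
```

A plain left shift would top out at 248 (or 252 for green); replicating the high bits lets white reach 255 on every channel.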

@@ -1,586 +0,0 @@
//go:build windows
package server
import (
"fmt"
"image"
"runtime"
"sync"
"sync/atomic"
"time"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
var (
gdi32 = windows.NewLazySystemDLL("gdi32.dll")
user32 = windows.NewLazySystemDLL("user32.dll")
procGetDC = user32.NewProc("GetDC")
procReleaseDC = user32.NewProc("ReleaseDC")
procCreateCompatDC = gdi32.NewProc("CreateCompatibleDC")
procCreateDIBSection = gdi32.NewProc("CreateDIBSection")
procSelectObject = gdi32.NewProc("SelectObject")
procDeleteObject = gdi32.NewProc("DeleteObject")
procDeleteDC = gdi32.NewProc("DeleteDC")
procBitBlt = gdi32.NewProc("BitBlt")
procGetSystemMetrics = user32.NewProc("GetSystemMetrics")
// Desktop switching for service/Session 0 capture.
procOpenInputDesktop = user32.NewProc("OpenInputDesktop")
procSetThreadDesktop = user32.NewProc("SetThreadDesktop")
procCloseDesktop = user32.NewProc("CloseDesktop")
procOpenWindowStation = user32.NewProc("OpenWindowStationW")
procSetProcessWindowStation = user32.NewProc("SetProcessWindowStation")
procCloseWindowStation = user32.NewProc("CloseWindowStation")
procGetUserObjectInformationW = user32.NewProc("GetUserObjectInformationW")
)
const uoiName = 2
const (
smCxScreen = 0
smCyScreen = 1
srccopy = 0x00CC0020
captureBlt = 0x40000000
dibRgbColors = 0
)
type bitmapInfoHeader struct {
Size uint32
Width int32
Height int32
Planes uint16
BitCount uint16
Compression uint32
SizeImage uint32
XPelsPerMeter int32
YPelsPerMeter int32
ClrUsed uint32
ClrImportant uint32
}
type bitmapInfo struct {
Header bitmapInfoHeader
}
// setupInteractiveWindowStation associates the current process with WinSta0,
// the interactive window station. This is required for a SYSTEM service in
// Session 0 to call OpenInputDesktop for screen capture and input injection.
func setupInteractiveWindowStation() error {
name, err := windows.UTF16PtrFromString("WinSta0")
if err != nil {
return fmt.Errorf("UTF16 WinSta0: %w", err)
}
hWinSta, _, err := procOpenWindowStation.Call(
uintptr(unsafe.Pointer(name)),
0,
uintptr(windows.MAXIMUM_ALLOWED),
)
if hWinSta == 0 {
return fmt.Errorf("OpenWindowStation(WinSta0): %w", err)
}
r, _, err := procSetProcessWindowStation.Call(hWinSta)
if r == 0 {
_, _, _ = procCloseWindowStation.Call(hWinSta)
return fmt.Errorf("SetProcessWindowStation: %w", err)
}
log.Info("process window station set to WinSta0 (interactive)")
return nil
}
func screenSize() (int, int) {
w, _, _ := procGetSystemMetrics.Call(uintptr(smCxScreen))
h, _, _ := procGetSystemMetrics.Call(uintptr(smCyScreen))
return int(w), int(h)
}
func getDesktopName(hDesk uintptr) string {
var buf [256]uint16
var needed uint32
_, _, _ = procGetUserObjectInformationW.Call(hDesk, uoiName,
uintptr(unsafe.Pointer(&buf[0])), 512,
uintptr(unsafe.Pointer(&needed)))
return windows.UTF16ToString(buf[:])
}
// switchToInputDesktop opens the desktop currently receiving user input
// and sets it as the calling OS thread's desktop. Must be called from a
// goroutine locked to its OS thread via runtime.LockOSThread().
func switchToInputDesktop() (bool, string) {
hDesk, _, _ := procOpenInputDesktop.Call(0, 0, uintptr(windows.MAXIMUM_ALLOWED))
if hDesk == 0 {
return false, ""
}
name := getDesktopName(hDesk)
ret, _, _ := procSetThreadDesktop.Call(hDesk)
_, _, _ = procCloseDesktop.Call(hDesk)
return ret != 0, name
}
// gdiCapturer captures the desktop screen using GDI BitBlt.
// GDI objects (DC, DIBSection) are allocated once and reused across frames.
type gdiCapturer struct {
mu sync.Mutex
width int
height int
// Pre-allocated GDI resources, reused across captures.
memDC uintptr
bmp uintptr
bits uintptr
}
func newGDICapturer() (*gdiCapturer, error) {
w, h := screenSize()
if w == 0 || h == 0 {
return nil, fmt.Errorf("screen dimensions are zero")
}
c := &gdiCapturer{width: w, height: h}
if err := c.allocGDI(); err != nil {
return nil, err
}
return c, nil
}
// allocGDI pre-allocates the compatible DC and DIB section for reuse.
func (c *gdiCapturer) allocGDI() error {
screenDC, _, _ := procGetDC.Call(0)
if screenDC == 0 {
return fmt.Errorf("GetDC returned 0")
}
defer func() { _, _, _ = procReleaseDC.Call(0, screenDC) }()
memDC, _, _ := procCreateCompatDC.Call(screenDC)
if memDC == 0 {
return fmt.Errorf("CreateCompatibleDC returned 0")
}
bi := bitmapInfo{
Header: bitmapInfoHeader{
Size: uint32(unsafe.Sizeof(bitmapInfoHeader{})),
Width: int32(c.width),
Height: -int32(c.height), // negative = top-down DIB
Planes: 1,
BitCount: 32,
},
}
var bits uintptr
bmp, _, _ := procCreateDIBSection.Call(
screenDC,
uintptr(unsafe.Pointer(&bi)),
dibRgbColors,
uintptr(unsafe.Pointer(&bits)),
0, 0,
)
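if bmp == 0 || bits == 0 {
// Release the bitmap too if it was created but no pixel pointer came back.
if bmp != 0 {
_, _, _ = procDeleteObject.Call(bmp)
}
_, _, _ = procDeleteDC.Call(memDC)
return fmt.Errorf("CreateDIBSection returned 0")
}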
_, _, _ = procSelectObject.Call(memDC, bmp)
c.memDC = memDC
c.bmp = bmp
c.bits = bits
return nil
}
func (c *gdiCapturer) close() { c.freeGDI() }
// freeGDI releases pre-allocated GDI resources.
func (c *gdiCapturer) freeGDI() {
if c.bmp != 0 {
_, _, _ = procDeleteObject.Call(c.bmp)
c.bmp = 0
}
if c.memDC != 0 {
_, _, _ = procDeleteDC.Call(c.memDC)
c.memDC = 0
}
c.bits = 0
}
func (c *gdiCapturer) capture() (*image.RGBA, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.memDC == 0 {
return nil, fmt.Errorf("GDI resources not allocated")
}
screenDC, _, _ := procGetDC.Call(0)
if screenDC == 0 {
return nil, fmt.Errorf("GetDC returned 0")
}
defer func() { _, _, _ = procReleaseDC.Call(0, screenDC) }()
// SRCCOPY|CAPTUREBLT: CAPTUREBLT forces inclusion of layered/topmost
// windows in the capture and is required for GDI BitBlt to return live
// pixels when the session is rendered through RDP / DWM-composited
// surfaces. Without it BitBlt reads the backing-store DIB which is
// often empty (all-black) on RDP and headless sessions.
ret, _, _ := procBitBlt.Call(c.memDC, 0, 0, uintptr(c.width), uintptr(c.height),
screenDC, 0, 0, srccopy|captureBlt)
if ret == 0 {
return nil, fmt.Errorf("BitBlt returned 0")
}
n := c.width * c.height * 4
raw := unsafe.Slice((*byte)(unsafe.Pointer(c.bits)), n)
// GDI gives BGRA, the RFB encoder expects RGBA (img.Pix layout).
// Swap R and B in bulk using uint32 operations (one load + mask + shift
// per pixel instead of three separate byte assignments).
img := image.NewRGBA(image.Rect(0, 0, c.width, c.height))
swizzleBGRAtoRGBA(img.Pix, raw)
return img, nil
}
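The body of `swizzleBGRAtoRGBA` is not part of this file, so the following is only a sketch of the technique the comment describes: load each pixel as one little-endian uint32, keep the G and A bytes in place, and exchange the B and R bytes with two shifts. The name `swizzleSketch` is hypothetical, and the real helper may differ (for example in how it treats alpha).

```go
package main

import (
	"encoding/binary"
	"fmt"
)

// swizzleSketch converts BGRA pixels in src to RGBA pixels in dst.
// Hypothetical stand-in for swizzleBGRAtoRGBA; alpha is copied unchanged.
func swizzleSketch(dst, src []byte) {
	n := len(src)
	if len(dst) < n {
		n = len(dst)
	}
	n &^= 3 // only whole 4-byte pixels
	for i := 0; i < n; i += 4 {
		// Little-endian load: v = A<<24 | R<<16 | G<<8 | B.
		v := binary.LittleEndian.Uint32(src[i:])
		// Keep A and G, move R down to byte 0 and B up to byte 2.
		v = (v & 0xFF00FF00) | ((v >> 16) & 0xFF) | ((v & 0xFF) << 16)
		binary.LittleEndian.PutUint32(dst[i:], v)
	}
}

func main() {
	src := []byte{1, 2, 3, 4} // B=1 G=2 R=3 A=4
	dst := make([]byte, len(src))
	swizzleSketch(dst, src)
	fmt.Println(dst) // R, G, B, A order
}
```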
// DesktopCapturer captures the interactive desktop, handling desktop transitions
// (login screen, UAC prompts). Frames are captured on demand by a dedicated
// OS-locked worker goroutine (required because DXGI's D3D11 device context
// is not thread-safe). Sessions drive timing by calling Capture(); a short
// staleness cache coalesces concurrent requests. Capture pauses automatically
// when no clients are connected.
type DesktopCapturer struct {
mu sync.Mutex
w, h int
// lastFrame/lastAt implement a small staleness cache so multiple
// near-simultaneous Capture calls share one DXGI round-trip.
lastFrame *image.RGBA
lastAt time.Time
// clients tracks the number of active VNC sessions. When zero, the
// worker goroutine releases the underlying capturer.
clients atomic.Int32
// reqCh carries capture requests from sessions to the OS-locked worker.
reqCh chan captureReq
// wake is signaled when a client connects and the worker should resume.
wake chan struct{}
// done is closed when Close is called, terminating the worker.
done chan struct{}
// cursorState holds the latest cursor sprite sampled by the worker.
// The worker calls GetCursorInfo every capture and decodes a new
// sprite only when the HCURSOR changes.
cursorState cursorState
}
// captureReq is a single capture request awaiting a reply. The reply channel
// is buffered to size 1 so the worker never blocks on a requester that has
// already gone away.
type captureReq struct {
reply chan captureReply
}
type captureReply struct {
img *image.RGBA
err error
}
// NewDesktopCapturer creates an on-demand capturer for the active desktop.
func NewDesktopCapturer() *DesktopCapturer {
c := &DesktopCapturer{
wake: make(chan struct{}, 1),
done: make(chan struct{}),
reqCh: make(chan captureReq),
}
go c.worker()
return c
}
// ClientConnect increments the active client count, resuming capture if needed.
func (c *DesktopCapturer) ClientConnect() {
c.clients.Add(1)
select {
case c.wake <- struct{}{}:
default:
}
}
// ClientDisconnect decrements the active client count.
func (c *DesktopCapturer) ClientDisconnect() {
c.clients.Add(-1)
}
// Close stops the capture loop and releases resources.
func (c *DesktopCapturer) Close() {
select {
case <-c.done:
default:
close(c.done)
}
}
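The select/default guard in Close makes repeated calls safe when they happen one after another. If two goroutines could call Close at the same time (the code shown here does not establish whether they can), both might take the default branch and the second close would panic. One common alternative is sync.Once; the sketch below uses a hypothetical `closer` type rather than DesktopCapturer itself.

```go
package main

import (
	"fmt"
	"sync"
)

// closer is a hypothetical type showing an idempotent, concurrency-safe Close.
type closer struct {
	once sync.Once
	done chan struct{}
}

func newCloser() *closer { return &closer{done: make(chan struct{})} }

// Close closes done exactly once, no matter how many goroutines call it.
func (c *closer) Close() { c.once.Do(func() { close(c.done) }) }

func main() {
	c := newCloser()
	var wg sync.WaitGroup
	for i := 0; i < 8; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			c.Close()
		}()
	}
	wg.Wait()
	<-c.done // returns immediately because done is closed
	fmt.Println("closed once")
}
```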
// Width returns the current screen width, triggering a capture if the
// worker hasn't initialised yet. validateCapturer depends on Width/Height
// becoming non-zero promptly after ClientConnect so it doesn't reject
// brand-new sessions.
func (c *DesktopCapturer) Width() int {
c.mu.Lock()
w := c.w
c.mu.Unlock()
if w == 0 && c.clients.Load() > 0 {
_, _ = c.Capture()
c.mu.Lock()
w = c.w
c.mu.Unlock()
}
return w
}
// Height returns the current screen height, triggering a capture if the
// worker hasn't initialised yet (see Width). Returns 0 while no client is
// connected so callers don't deadlock against a parked worker.
func (c *DesktopCapturer) Height() int {
c.mu.Lock()
h := c.h
c.mu.Unlock()
if h == 0 && c.clients.Load() > 0 {
_, _ = c.Capture()
c.mu.Lock()
h = c.h
c.mu.Unlock()
}
return h
}
// Capture returns a freshly captured frame, serving from a short staleness
// cache when multiple sessions ask within freshWindow of each other. All
// real DXGI/GDI work happens on the OS-locked worker goroutine.
func (c *DesktopCapturer) Capture() (*image.RGBA, error) {
c.mu.Lock()
if c.lastFrame != nil && time.Since(c.lastAt) < freshWindow {
img := c.lastFrame
c.mu.Unlock()
return img, nil
}
c.mu.Unlock()
reply := make(chan captureReply, 1)
select {
case c.reqCh <- captureReq{reply: reply}:
case <-c.done:
return nil, fmt.Errorf("capturer closed")
}
select {
case r := <-reply:
if r.err != nil {
return nil, r.err
}
c.mu.Lock()
c.lastFrame = r.img
c.lastAt = time.Now()
c.mu.Unlock()
return r.img, nil
case <-c.done:
return nil, fmt.Errorf("capturer closed")
}
}
// waitForClient blocks until a client connects or the capturer is closed.
func (c *DesktopCapturer) waitForClient() bool {
if c.clients.Load() > 0 {
return true
}
select {
case <-c.wake:
return true
case <-c.done:
return false
}
}
// worker owns DXGI/GDI state on its OS-locked thread and services capture
// requests from sessions. No background ticker: a capture happens only when
// a session asks for one (throttled by Capture()'s staleness cache).
func (c *DesktopCapturer) worker() {
runtime.LockOSThread()
// When running as a Windows service (Session 0), we need to attach to the
// interactive window station before OpenInputDesktop will succeed.
if err := setupInteractiveWindowStation(); err != nil {
log.Warnf("attach to interactive window station: %v", err)
}
w := &captureWorker{c: c}
defer w.closeCapturer()
for {
if !c.waitForClient() {
return
}
// Drop the capturer when all clients have disconnected so we don't
// hold the DXGI duplication or GDI DC on an idle peer.
if c.clients.Load() <= 0 {
w.closeCapturer()
continue
}
if !w.handleNextRequest() {
return
}
}
}
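The request/reply shape used between Capture and the worker can be shown in isolation. In this sketch the worker pins itself to one OS thread, serves requests from an unbuffered channel, and answers on a reply channel of capacity 1 so that it never blocks if the caller has stopped waiting. The names `req`, `worker`, and `call` are hypothetical, and the payload is a counter rather than a frame.

```go
package main

import (
	"fmt"
	"runtime"
)

// req carries a buffered reply channel, like captureReq above.
type req struct{ reply chan int }

// worker owns its state on a single OS thread and serves requests until done closes.
func worker(reqs <-chan req, done <-chan struct{}) {
	runtime.LockOSThread()
	defer runtime.UnlockOSThread()
	n := 0
	for {
		select {
		case <-done:
			return
		case r := <-reqs:
			n++
			r.reply <- n // capacity 1: never blocks, even if the caller left
		}
	}
}

// call sends one request and waits for the answer or for shutdown.
func call(reqs chan<- req, done <-chan struct{}) (int, bool) {
	r := req{reply: make(chan int, 1)}
	select {
	case reqs <- r:
	case <-done:
		return 0, false
	}
	select {
	case v := <-r.reply:
		return v, true
	case <-done:
		return 0, false
	}
}

func main() {
	reqs := make(chan req)
	done := make(chan struct{})
	go worker(reqs, done)
	a, _ := call(reqs, done)
	b, _ := call(reqs, done)
	fmt.Println(a, b)
	close(done)
}
```

Keeping all thread-affine calls inside the worker means callers never need to lock an OS thread themselves.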
// frameCapturer is the per-backend interface used by the worker. DXGI and
// GDI implementations both satisfy it.
type frameCapturer interface {
capture() (*image.RGBA, error)
close()
}
// captureWorker owns the worker goroutine's mutable state. Extracted into a
// struct so the request/desktop/init logic can live on small methods and the
// outer worker() stays a thin loop.
type captureWorker struct {
c *DesktopCapturer
cap frameCapturer
desktopFails int
lastDesktop string
nextInitRetry time.Time
cursor cursorSampler
// lastBackend records the last capturer kind that came out of
// createCapturer ("dxgi" or "gdi"); used to demote repeat "using X"
// and DXGI-unavailable logs to debug when nothing changed.
lastBackend string
// lastDXGIErr is the textual DXGI failure printed in the most recent
// fallback warning; suppresses repeat warns when DXGI keeps failing
// the same way across desktop changes (login -> lock -> login).
lastDXGIErr string
}
// handleNextRequest waits for either shutdown or a capture request and runs
// the request through prepCapturer/capture. Returns false when the worker
// should exit.
func (w *captureWorker) handleNextRequest() bool {
select {
case <-w.c.done:
return false
case req := <-w.c.reqCh:
w.serveRequest(req)
return true
}
}
func (w *captureWorker) serveRequest(req captureReq) {
fc, err := w.prepCapturer()
if err != nil {
req.reply <- captureReply{err: err}
return
}
img, err := fc.capture()
if err != nil {
log.Debugf("capture: %v", err)
w.closeCapturer()
w.nextInitRetry = time.Now().Add(100 * time.Millisecond)
req.reply <- captureReply{err: err}
return
}
if snap, err := w.cursor.sample(); err != nil {
w.c.cursorState.store(&cursorSnapshot{err: err})
} else {
w.c.cursorState.store(snap)
}
req.reply <- captureReply{img: img}
}
// prepCapturer switches to the input desktop, handles desktop-change
// teardown, and creates the underlying capturer on demand. Backoff state is
// tracked across calls via w.nextInitRetry.
func (w *captureWorker) prepCapturer() (frameCapturer, error) {
if err := w.refreshDesktop(); err != nil {
return nil, err
}
if w.cap != nil {
return w.cap, nil
}
if time.Now().Before(w.nextInitRetry) {
return nil, fmt.Errorf("capturer init backing off")
}
fc, err := w.createCapturer()
if err != nil {
w.nextInitRetry = time.Now().Add(500 * time.Millisecond)
return nil, err
}
w.cap = fc
sw, sh := screenSize()
w.c.mu.Lock()
sizeChanged := w.c.w != sw || w.c.h != sh
w.c.w, w.c.h = sw, sh
w.c.mu.Unlock()
if sizeChanged {
log.Infof("screen capturer ready: %dx%d", sw, sh)
} else {
log.Debugf("screen capturer ready: %dx%d", sw, sh)
}
return w.cap, nil
}
// refreshDesktop tracks the active input desktop. When it changes (lock
// screen, fast-user-switch) the existing capturer is dropped so the next
// call rebuilds one against the new desktop.
func (w *captureWorker) refreshDesktop() error {
ok, desk := switchToInputDesktop()
if !ok {
w.desktopFails++
if w.desktopFails == 1 || w.desktopFails%100 == 0 {
log.Warnf("switchToInputDesktop failed (count=%d), no interactive desktop session?", w.desktopFails)
}
return fmt.Errorf("no interactive desktop")
}
if w.desktopFails > 0 {
log.Infof("switchToInputDesktop recovered after %d failures, desktop=%q", w.desktopFails, desk)
w.desktopFails = 0
}
if desk != w.lastDesktop {
log.Infof("desktop changed: %q -> %q", w.lastDesktop, desk)
w.lastDesktop = desk
w.closeCapturer()
}
return nil
}
func (w *captureWorker) createCapturer() (frameCapturer, error) {
dc, err := newDXGICapturer()
if err == nil {
if w.lastBackend != "dxgi" {
log.Info("using DXGI Desktop Duplication for capture")
} else {
log.Debug("using DXGI Desktop Duplication for capture")
}
w.lastBackend = "dxgi"
w.lastDXGIErr = ""
return dc, nil
}
errStr := err.Error()
if errStr != w.lastDXGIErr {
log.Warnf("DXGI Desktop Duplication unavailable, falling back to slower GDI BitBlt: %v", err)
w.lastDXGIErr = errStr
} else {
log.Debugf("DXGI Desktop Duplication still unavailable, falling back to slower GDI BitBlt: %v", err)
}
gc, err := newGDICapturer()
if err != nil {
return nil, err
}
if w.lastBackend != "gdi" {
log.Info("using GDI BitBlt for capture")
} else {
log.Debug("using GDI BitBlt for capture")
}
w.lastBackend = "gdi"
return gc, nil
}
func (w *captureWorker) closeCapturer() {
if w.cap != nil {
w.cap.close()
w.cap = nil
}
}

@@ -1,544 +0,0 @@
//go:build (linux && !android) || freebsd
package server
import (
"fmt"
"image"
"os"
"os/exec"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
"github.com/jezek/xgb"
"github.com/jezek/xgb/xproto"
)
const (
// x11SocketDir is the well-known directory where X servers create
// their UNIX-domain socket files, named "X<display>". Used both
// for auto-detecting an existing display and for placing/probing
// sockets of virtual sessions we spawn.
x11SocketDir = "/tmp/.X11-unix"
// envDisplay is the X11 display selector environment variable.
envDisplay = "DISPLAY"
// envXAuthority points X clients at the cookie file used to
// authenticate against the running X server.
envXAuthority = "XAUTHORITY"
)
// X11Capturer captures the screen from an X11 display using the MIT-SHM extension.
type X11Capturer struct {
mu sync.Mutex
conn *xgb.Conn
screen *xproto.ScreenInfo
w, h int
shmID int
shmAddr []byte
shmSeg uint32
useSHM bool
// bufs double-buffers output images so the X11Poller's capture loop can
// overwrite one while the session is still encoding the other. Before
// this, a single reused buffer would race with the reader. Allocation
// happens on first use and on geometry change.
bufs [2]*image.RGBA
cur int
// cursor is the XFixes binding used to report the current sprite.
// Allocated lazily on the first Cursor call. cursorInitErr latches
// a permanent init failure so we stop retrying every frame.
cursor *xfixesCursor
cursorInitErr error
}
// detectX11Display finds the active X11 display and sets DISPLAY/XAUTHORITY
// environment variables if needed. This is required when running as a system
// service where these vars aren't set.
func detectX11Display() {
if os.Getenv(envDisplay) != "" {
return
}
// Try /proc first (Linux), then fall back to scanning the X11 socket
// directory and asking ps for the auth file (FreeBSD and others).
if detectX11FromProc() {
return
}
if detectX11FromSockets() {
return
}
}
// detectX11FromProc scans /proc/*/cmdline for Xorg (Linux).
func detectX11FromProc() bool {
entries, err := os.ReadDir("/proc")
if err != nil {
return false
}
for _, e := range entries {
if !e.IsDir() {
continue
}
cmdline, err := os.ReadFile("/proc/" + e.Name() + "/cmdline")
if err != nil {
continue
}
if display, auth := parseXorgArgs(splitCmdline(cmdline)); display != "" {
setDisplayEnv(display, auth)
return true
}
}
return false
}
// detectX11FromSockets checks /tmp/.X11-unix/ for X sockets and uses ps
// to find the auth file. Works on FreeBSD and other systems without /proc.
func detectX11FromSockets() bool {
entries, err := os.ReadDir(x11SocketDir)
if err != nil {
return false
}
// Pick the lowest numeric display rather than the lexically first
// entry, so X10 doesn't win over X2.
minDisplay := -1
for _, e := range entries {
name := e.Name()
if len(name) < 2 || name[0] != 'X' {
continue
}
n, err := strconv.Atoi(name[1:])
if err != nil {
continue
}
if minDisplay < 0 || n < minDisplay {
minDisplay = n
}
}
if minDisplay < 0 {
return false
}
display := ":" + strconv.Itoa(minDisplay)
os.Setenv(envDisplay, display)
auth := findXorgAuthFromPS()
if auth != "" {
os.Setenv(envXAuthority, auth)
log.Infof("auto-detected DISPLAY=%s (from socket) XAUTHORITY=%s (from ps)", display, auth)
} else {
log.Infof("auto-detected DISPLAY=%s (from socket)", display)
}
return true
}
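The display selection above compares the numeric suffix, not the string, so that "X2" beats "X10". A small standalone sketch of that rule, operating on a list of names instead of a directory listing; `lowestDisplay` is a hypothetical helper, and the extra check for negative numbers is an addition of this sketch.

```go
package main

import (
	"fmt"
	"strconv"
)

// lowestDisplay returns the smallest display number among socket names of the
// form "X<n>", or -1 if none qualify. Hypothetical helper for illustration.
func lowestDisplay(names []string) int {
	lowest := -1
	for _, name := range names {
		if len(name) < 2 || name[0] != 'X' {
			continue
		}
		n, err := strconv.Atoi(name[1:])
		if err != nil || n < 0 {
			continue // not a display socket
		}
		if lowest < 0 || n < lowest {
			lowest = n
		}
	}
	return lowest
}

func main() {
	fmt.Println(lowestDisplay([]string{"X10", "X2", "notes", "X"}))
	fmt.Println(lowestDisplay(nil))
}
```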
// findXorgAuthFromPS runs ps to find Xorg and extract its -auth argument.
func findXorgAuthFromPS() string {
out, err := exec.Command("ps", "auxww").Output()
if err != nil {
return ""
}
for _, line := range strings.Split(string(out), "\n") {
if !strings.Contains(line, "Xorg") && !strings.Contains(line, "/X ") {
continue
}
fields := strings.Fields(line)
for i, f := range fields {
if f == "-auth" && i+1 < len(fields) {
return fields[i+1]
}
}
}
return ""
}
func parseXorgArgs(args []string) (display, auth string) {
if len(args) == 0 {
return "", ""
}
base := args[0]
if !(base == "Xorg" || base == "X" || len(base) > 0 && base[len(base)-1] == 'X' ||
strings.Contains(base, "/Xorg") || strings.Contains(base, "/X")) {
return "", ""
}
for i, arg := range args[1:] {
if len(arg) > 0 && arg[0] == ':' {
display = arg
}
if arg == "-auth" && i+2 < len(args) {
auth = args[i+2]
}
}
return display, auth
}
func setDisplayEnv(display, auth string) {
os.Setenv(envDisplay, display)
if auth != "" {
os.Setenv(envXAuthority, auth)
log.Infof("auto-detected DISPLAY=%s XAUTHORITY=%s", display, auth)
return
}
log.Infof("auto-detected DISPLAY=%s", display)
}
func splitCmdline(data []byte) []string {
var args []string
for _, b := range splitNull(data) {
if len(b) > 0 {
args = append(args, string(b))
}
}
return args
}
func splitNull(data []byte) [][]byte {
var parts [][]byte
start := 0
for i, b := range data {
if b == 0 {
parts = append(parts, data[start:i])
start = i + 1
}
}
if start < len(data) {
parts = append(parts, data[start:])
}
return parts
}
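/proc/<pid>/cmdline stores arguments separated by NUL bytes, which is what splitCmdline and splitNull unpack by hand. As a sketch of an equivalent approach using only the standard library, bytes.Split can do the splitting; `splitArgs` is a hypothetical name and the sample path is made up for the example.

```go
package main

import (
	"bytes"
	"fmt"
)

// splitArgs turns a NUL-separated cmdline buffer into its non-empty arguments.
// Hypothetical alternative to splitCmdline/splitNull, for illustration only.
func splitArgs(data []byte) []string {
	var args []string
	for _, part := range bytes.Split(data, []byte{0}) {
		if len(part) > 0 {
			args = append(args, string(part))
		}
	}
	return args
}

func main() {
	// Illustrative buffer; the auth path is a made-up value.
	raw := []byte("Xorg\x00:0\x00-auth\x00/run/example/xauth\x00")
	fmt.Printf("%q\n", splitArgs(raw))
}
```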
// NewX11Capturer connects to the X11 display and sets up shared memory capture.
// Empty cookieHex falls back to XAUTHORITY env lookup.
func NewX11Capturer(display, cookieHex string) (*X11Capturer, error) {
if display == "" {
detectX11Display()
display = os.Getenv(envDisplay)
}
if display == "" {
return nil, fmt.Errorf("DISPLAY not set and no Xorg process found")
}
var conn *xgb.Conn
var err error
if cookieHex != "" {
conn, err = dialXUnixWithCookie(display, cookieHex)
} else {
conn, err = xgb.NewConnDisplay(display)
}
if err != nil {
return nil, fmt.Errorf("connect to X11 display %s: %w", display, err)
}
setup := xproto.Setup(conn)
if len(setup.Roots) == 0 {
conn.Close()
return nil, fmt.Errorf("no X11 screens")
}
screen := setup.Roots[0]
c := &X11Capturer{
conn: conn,
screen: &screen,
w: int(screen.WidthInPixels),
h: int(screen.HeightInPixels),
}
if err := c.initSHM(); err != nil {
log.Debugf("X11 SHM not available, using slow GetImage: %v", err)
}
log.Infof("X11 capturer ready: %dx%d (display=%s, shm=%v)", c.w, c.h, display, c.useSHM)
return c, nil
}
// initSHM is implemented in capture_x11_shm_linux.go (requires SysV SHM).
// On platforms without SysV SHM (FreeBSD), a stub returns an error and
// the capturer falls back to GetImage.
// Width returns the screen width.
func (c *X11Capturer) Width() int { return c.w }
// Height returns the screen height.
func (c *X11Capturer) Height() int { return c.h }
// Capture returns the current screen as an RGBA image.
func (c *X11Capturer) Capture() (*image.RGBA, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.useSHM {
return c.captureSHM()
}
return c.captureGetImage()
}
// CaptureInto fills the caller's destination buffer in one pass. The
// source path (SHM or fallback GetImage) writes directly into dst.Pix
// instead of going through the X11Capturer's internal double-buffer,
// saving one full-frame memcpy per capture.
func (c *X11Capturer) CaptureInto(dst *image.RGBA) error {
c.mu.Lock()
defer c.mu.Unlock()
if dst.Rect.Dx() != c.w || dst.Rect.Dy() != c.h {
return fmt.Errorf("dst size mismatch: dst=%dx%d capturer=%dx%d",
dst.Rect.Dx(), dst.Rect.Dy(), c.w, c.h)
}
if c.useSHM {
return c.captureSHMInto(dst)
}
return c.captureGetImageInto(dst)
}
func (c *X11Capturer) captureGetImageInto(dst *image.RGBA) error {
cookie := xproto.GetImage(c.conn, xproto.ImageFormatZPixmap,
xproto.Drawable(c.screen.Root),
0, 0, uint16(c.w), uint16(c.h), 0xFFFFFFFF)
reply, err := cookie.Reply()
if err != nil {
return fmt.Errorf("GetImage: %w", err)
}
n := c.w * c.h * 4
if len(reply.Data) < n {
return fmt.Errorf("GetImage returned %d bytes, expected %d", len(reply.Data), n)
}
swizzleBGRAtoRGBA(dst.Pix, reply.Data)
return nil
}
// captureSHM is implemented in capture_x11_shm_linux.go.
func (c *X11Capturer) captureGetImage() (*image.RGBA, error) {
cookie := xproto.GetImage(c.conn, xproto.ImageFormatZPixmap,
xproto.Drawable(c.screen.Root),
0, 0, uint16(c.w), uint16(c.h), 0xFFFFFFFF)
reply, err := cookie.Reply()
if err != nil {
return nil, fmt.Errorf("GetImage: %w", err)
}
data := reply.Data
n := c.w * c.h * 4
if len(data) < n {
return nil, fmt.Errorf("GetImage returned %d bytes, expected %d", len(data), n)
}
img := c.nextBuffer()
swizzleBGRAtoRGBA(img.Pix, data)
return img, nil
}
// nextBuffer returns the *image.RGBA the next capture should fill, advancing
// the double-buffer index. Reallocates on geometry change.
func (c *X11Capturer) nextBuffer() *image.RGBA {
c.cur ^= 1
b := c.bufs[c.cur]
if b == nil || b.Rect.Dx() != c.w || b.Rect.Dy() != c.h {
b = image.NewRGBA(image.Rect(0, 0, c.w, c.h))
c.bufs[c.cur] = b
}
return b
}
// Close releases X11 resources.
func (c *X11Capturer) Close() {
c.closeSHM()
c.conn.Close()
}
// closeSHM is implemented in capture_x11_shm_linux.go.
// X11Poller wraps X11Capturer with a staleness-cached on-demand Capture:
// sessions drive captures themselves through the encoder goroutine, so we
// don't need a background ticker. The last result is cached for a short
// window so concurrent sessions coalesce into one capture.
//
// The capturer is allocated lazily on first use and released when all
// clients disconnect, so an idle peer holds no X connection or SHM segment.
type X11Poller struct {
mu sync.Mutex
capturer *X11Capturer
w, h int
// done is closed by Close so callers can stop waiting on retry backoff.
done chan struct{}
// lastFrame/lastAt implement a small cache: multiple near-simultaneous
// Capture calls (multi-client, or input-coalesced) return the same
// frame instead of hammering the X server.
lastFrame *image.RGBA
lastAt time.Time
// initBackoffUntil throttles capturer re-init when the X server is
// unavailable or flapping.
initBackoffUntil time.Time
clients atomic.Int32
display string
// cookieHex authenticates the X11 connection; empty falls back to XAUTHORITY env.
cookieHex string
}
// initRetryBackoff gates capturer re-init attempts after a failure so we
// don't spin on X server errors.
const initRetryBackoff = 2 * time.Second
// NewX11Poller creates a lazy on-demand capturer for the given X display.
// Empty cookieHex falls back to XAUTHORITY env lookup.
func NewX11Poller(display, cookieHex string) *X11Poller {
return &X11Poller{
display: display,
cookieHex: cookieHex,
done: make(chan struct{}),
}
}
// ClientConnect increments the active client count. The first client triggers
// eager capturer initialisation so that the first FBUpdateRequest doesn't
// pay the X11 connect + SHM attach latency.
func (p *X11Poller) ClientConnect() {
if p.clients.Add(1) == 1 {
p.mu.Lock()
_ = p.ensureCapturerLocked()
p.mu.Unlock()
}
}
// ClientDisconnect decrements the active client count. On the last
// disconnect we close the underlying capturer so idle peers cost nothing.
func (p *X11Poller) ClientDisconnect() {
if p.clients.Add(-1) == 0 {
p.mu.Lock()
if p.capturer != nil {
p.capturer.Close()
p.capturer = nil
p.lastFrame = nil
}
p.mu.Unlock()
}
}
// Close releases all resources. Subsequent Capture calls will fail.
func (p *X11Poller) Close() {
p.mu.Lock()
defer p.mu.Unlock()
select {
case <-p.done:
default:
close(p.done)
}
if p.capturer != nil {
p.capturer.Close()
p.capturer = nil
}
// Drop the cached frame so Capture cannot keep serving it after Close.
p.lastFrame = nil
}
// Width returns the screen width. Triggers lazy init if needed.
func (p *X11Poller) Width() int {
p.mu.Lock()
defer p.mu.Unlock()
_ = p.ensureCapturerLocked()
return p.w
}
// Height returns the screen height. Triggers lazy init if needed.
func (p *X11Poller) Height() int {
p.mu.Lock()
defer p.mu.Unlock()
_ = p.ensureCapturerLocked()
return p.h
}
// Cursor satisfies cursorSource by forwarding to the lazily-initialised
// X11Capturer. Asking for the cursor on an idle poller triggers the same
// lazy X11 connection setup as a capture would.
func (p *X11Poller) Cursor() (*image.RGBA, int, int, uint64, error) {
p.mu.Lock()
defer p.mu.Unlock()
if err := p.ensureCapturerLocked(); err != nil {
return nil, 0, 0, 0, err
}
return p.capturer.Cursor()
}
// CursorPos satisfies cursorPositionSource by forwarding to the X11Capturer.
func (p *X11Poller) CursorPos() (int, int, error) {
p.mu.Lock()
defer p.mu.Unlock()
if err := p.ensureCapturerLocked(); err != nil {
return 0, 0, err
}
return p.capturer.CursorPos()
}
// Capture returns a fresh frame, serving from the short-lived cache if a
// previous caller captured within freshWindow.
func (p *X11Poller) Capture() (*image.RGBA, error) {
p.mu.Lock()
defer p.mu.Unlock()
if p.lastFrame != nil && time.Since(p.lastAt) < freshWindow {
return p.lastFrame, nil
}
if err := p.ensureCapturerLocked(); err != nil {
return nil, err
}
img, err := p.capturer.Capture()
if err != nil {
// Drop the capturer so the next call re-inits; the X connection may
// have died (e.g. Xorg restart).
p.capturer.Close()
p.capturer = nil
p.initBackoffUntil = time.Now().Add(initRetryBackoff)
return nil, fmt.Errorf("x11 capture: %w", err)
}
p.lastFrame = img
p.lastAt = time.Now()
return img, nil
}
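The caching in Capture can be reduced to a few lines. The sketch below assumes a hypothetical `frameCache` type with a pluggable `grab` function and a configurable window; `freshWindow` itself is defined elsewhere in the package and its value is not shown in this file.

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// frameCache returns the last result while it is younger than window and
// otherwise calls grab. Hypothetical type for illustration only.
type frameCache struct {
	mu     sync.Mutex
	window time.Duration
	grab   func() ([]byte, error)
	last   []byte
	lastAt time.Time
}

func newFrameCache(window time.Duration, grab func() ([]byte, error)) *frameCache {
	return &frameCache{window: window, grab: grab}
}

func (c *frameCache) Get() ([]byte, error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.last != nil && time.Since(c.lastAt) < c.window {
		return c.last, nil // fresh enough: share the previous frame
	}
	f, err := c.grab()
	if err != nil {
		return nil, err
	}
	c.last, c.lastAt = f, time.Now()
	return f, nil
}

func main() {
	grabs := 0
	c := newFrameCache(time.Hour, func() ([]byte, error) {
		grabs++
		return []byte{byte(grabs)}, nil
	})
	a, _ := c.Get()
	b, _ := c.Get()
	fmt.Println("grabs:", grabs, "same frame:", a[0] == b[0])
}
```

Because callers share the returned buffer, they should treat it as read-only.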
// CaptureInto fills dst directly via the underlying capturer, bypassing
// the freshness cache. The session's prevFrame/curFrame swap means each
// session needs its own buffer anyway, so caching wouldn't help.
func (p *X11Poller) CaptureInto(dst *image.RGBA) error {
p.mu.Lock()
defer p.mu.Unlock()
if err := p.ensureCapturerLocked(); err != nil {
return err
}
if err := p.capturer.CaptureInto(dst); err != nil {
p.capturer.Close()
p.capturer = nil
p.initBackoffUntil = time.Now().Add(initRetryBackoff)
return fmt.Errorf("x11 capture: %w", err)
}
return nil
}
// ensureCapturerLocked initialises the underlying X11Capturer if not
// already open. Caller must hold p.mu.
func (p *X11Poller) ensureCapturerLocked() error {
if p.capturer != nil {
return nil
}
select {
case <-p.done:
return fmt.Errorf("x11 capturer closed")
default:
}
if time.Now().Before(p.initBackoffUntil) {
return fmt.Errorf("x11 capturer unavailable (retry scheduled)")
}
c, err := NewX11Capturer(p.display, p.cookieHex)
if err != nil {
p.initBackoffUntil = time.Now().Add(initRetryBackoff)
log.Debugf("X11 capturer: %v", err)
return err
}
p.capturer = c
p.w, p.h = c.Width(), c.Height()
return nil
}
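The backoff in ensureCapturerLocked is a simple time gate: after a failed init, further attempts are refused until a deadline passes. A deterministic sketch of the same idea follows, with the clock passed in explicitly; `initGate`, `at`, and both error values are hypothetical names for this example.

```go
package main

import (
	"errors"
	"fmt"
	"time"
)

var (
	errBackoff = errors.New("init backing off")
	errInit    = errors.New("init failed")
)

// initGate refuses to run open again until backoff has elapsed since the
// last failure. Hypothetical type for illustration only.
type initGate struct {
	backoff time.Duration
	until   time.Time
}

func (g *initGate) try(now time.Time, open func() error) error {
	if now.Before(g.until) {
		return errBackoff
	}
	if err := open(); err != nil {
		g.until = now.Add(g.backoff)
		return err
	}
	return nil
}

// at builds a fixed timestamp so the example does not depend on wall time.
func at(sec int64) time.Time { return time.Unix(sec, 0) }

func main() {
	calls := 0
	fail := func() error { calls++; return errInit }
	g := &initGate{backoff: 2 * time.Second}
	fmt.Println(g.try(at(0), fail)) // open runs and fails
	fmt.Println(g.try(at(1), fail)) // inside the window: open is skipped
	fmt.Println(g.try(at(2), fail)) // window elapsed: open runs again
	fmt.Println("open calls:", calls)
}
```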

@@ -1,96 +0,0 @@
//go:build linux && !android
package server
import (
"fmt"
"image"
"github.com/jezek/xgb/shm"
"github.com/jezek/xgb/xproto"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
func (c *X11Capturer) initSHM() error {
if err := shm.Init(c.conn); err != nil {
return fmt.Errorf("init SHM extension: %w", err)
}
size := c.w * c.h * 4
id, err := unix.SysvShmGet(unix.IPC_PRIVATE, size, unix.IPC_CREAT|0600)
if err != nil {
return fmt.Errorf("shmget: %w", err)
}
addr, err := unix.SysvShmAttach(id, 0, 0)
if err != nil {
if _, ctlErr := unix.SysvShmCtl(id, unix.IPC_RMID, nil); ctlErr != nil {
log.Debugf("shmctl IPC_RMID on attach failure: %v", ctlErr)
}
return fmt.Errorf("shmat: %w", err)
}
if _, err := unix.SysvShmCtl(id, unix.IPC_RMID, nil); err != nil {
log.Debugf("shmctl IPC_RMID: %v", err)
}
seg, err := shm.NewSegId(c.conn)
if err != nil {
if detachErr := unix.SysvShmDetach(addr); detachErr != nil {
log.Debugf("shmdt on new-seg failure: %v", detachErr)
}
return fmt.Errorf("new SHM seg: %w", err)
}
if err := shm.AttachChecked(c.conn, seg, uint32(id), false).Check(); err != nil {
if detachErr := unix.SysvShmDetach(addr); detachErr != nil {
log.Debugf("shmdt on attach-checked failure: %v", detachErr)
}
return fmt.Errorf("SHM attach to X: %w", err)
}
c.shmID = id
c.shmAddr = addr
c.shmSeg = uint32(seg)
c.useSHM = true
return nil
}
func (c *X11Capturer) captureSHM() (*image.RGBA, error) {
if err := c.fillSHM(); err != nil {
return nil, err
}
img := c.nextBuffer()
swizzleBGRAtoRGBA(img.Pix, c.shmAddr[:c.w*c.h*4])
return img, nil
}
// captureSHMInto runs a single SHM GetImage and swizzles directly into the
// caller-provided destination, skipping the internal double-buffer.
func (c *X11Capturer) captureSHMInto(dst *image.RGBA) error {
if err := c.fillSHM(); err != nil {
return err
}
swizzleBGRAtoRGBA(dst.Pix, c.shmAddr[:c.w*c.h*4])
return nil
}
func (c *X11Capturer) fillSHM() error {
cookie := shm.GetImage(c.conn, xproto.Drawable(c.screen.Root),
0, 0, uint16(c.w), uint16(c.h), 0xFFFFFFFF,
xproto.ImageFormatZPixmap, shm.Seg(c.shmSeg), 0)
if _, err := cookie.Reply(); err != nil {
return fmt.Errorf("SHM GetImage: %w", err)
}
return nil
}
func (c *X11Capturer) closeSHM() {
if c.useSHM {
shm.Detach(c.conn, shm.Seg(c.shmSeg))
if err := unix.SysvShmDetach(c.shmAddr); err != nil {
log.Debugf("shmdt on close: %v", err)
}
}
}

@@ -1,24 +0,0 @@
//go:build freebsd
package server
import (
"fmt"
"image"
)
func (c *X11Capturer) initSHM() error {
return fmt.Errorf("SysV SHM not available on this platform")
}
func (c *X11Capturer) captureSHM() (*image.RGBA, error) {
return nil, fmt.Errorf("SHM capture not available on this platform")
}
func (c *X11Capturer) captureSHMInto(_ *image.RGBA) error {
return fmt.Errorf("SHM capture not available on this platform")
}
func (c *X11Capturer) closeSHM() {
// no SHM to close on this platform
}

@@ -1,77 +0,0 @@
//go:build !js && !ios && !android
package server
import (
"reflect"
"testing"
)
func TestCoalesceRects(t *testing.T) {
cases := []struct {
name string
in [][4]int
want [][4]int
}{
{
name: "empty",
in: nil,
want: nil,
},
{
name: "single",
in: [][4]int{{0, 0, 64, 64}},
want: [][4]int{{0, 0, 64, 64}},
},
{
name: "horizontal_run",
in: [][4]int{{0, 0, 64, 64}, {64, 0, 64, 64}, {128, 0, 64, 64}},
want: [][4]int{{0, 0, 192, 64}},
},
{
name: "vertical_run",
in: [][4]int{{0, 0, 64, 64}, {0, 64, 64, 64}, {0, 128, 64, 64}},
want: [][4]int{{0, 0, 64, 192}},
},
{
name: "block_2x2",
in: [][4]int{
{0, 0, 64, 64}, {64, 0, 64, 64},
{0, 64, 64, 64}, {64, 64, 64, 64},
},
want: [][4]int{{0, 0, 128, 128}},
},
{
name: "no_merge_gap",
in: [][4]int{{0, 0, 64, 64}, {192, 0, 64, 64}},
want: [][4]int{{0, 0, 64, 64}, {192, 0, 64, 64}},
},
{
name: "two_disjoint_columns",
in: [][4]int{
{0, 0, 64, 64}, {192, 0, 64, 64},
{0, 64, 64, 64}, {192, 64, 64, 64},
},
want: [][4]int{{0, 0, 64, 128}, {192, 0, 64, 128}},
},
{
name: "misaligned_widths_no_vertical_merge",
in: [][4]int{
{0, 0, 128, 64},
{0, 64, 64, 64},
},
want: [][4]int{{0, 0, 128, 64}, {0, 64, 64, 64}},
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got := coalesceRects(tc.in)
if len(got) == 0 && len(tc.want) == 0 {
return
}
if !reflect.DeepEqual(got, tc.want) {
t.Fatalf("got %v want %v", got, tc.want)
}
})
}
}
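The table above fixes the expected behaviour of `coalesceRects`, but the function body is not part of this diff. One way to satisfy those cases is a two-pass merge: first join horizontally adjacent rects on the same row, then stack runs that share the same x and width. The sketch below assumes the input arrives in row-major order (as in the test cases); `coalesceSketch` is a hypothetical name and may differ from the real implementation.

```go
package main

import "fmt"

// coalesceSketch merges [x, y, w, h] rects. Hypothetical stand-in for
// coalesceRects; assumes the input is sorted row-major (by y, then x).
func coalesceSketch(in [][4]int) [][4]int {
	// Pass 1: extend the previous rect when the next one touches its
	// right edge on the same row with the same height.
	var runs [][4]int
	for _, r := range in {
		if n := len(runs); n > 0 {
			p := &runs[n-1]
			if p[1] == r[1] && p[3] == r[3] && p[0]+p[2] == r[0] {
				p[2] += r[2]
				continue
			}
		}
		runs = append(runs, r)
	}
	// Pass 2: grow an existing rect downward when a run has the same x
	// and width and starts exactly at its bottom edge.
	var out [][4]int
	for _, r := range runs {
		merged := false
		for i := range out {
			o := &out[i]
			if o[0] == r[0] && o[2] == r[2] && o[1]+o[3] == r[1] {
				o[3] += r[3]
				merged = true
				break
			}
		}
		if !merged {
			out = append(out, r)
		}
	}
	return out
}

func main() {
	block := [][4]int{
		{0, 0, 64, 64}, {64, 0, 64, 64},
		{0, 64, 64, 64}, {64, 64, 64, 64},
	}
	fmt.Println(coalesceSketch(block)) // [[0 0 128 128]]
}
```

The vertical pass is quadratic in the number of runs, which is acceptable for a sketch; a production version would likely index open rects by (x, w).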

@@ -1,10 +0,0 @@
package server
// interactiveUserError returns nil when a user is logged into the console
// (i.e. an Aqua session is active). At the loginwindow there is nobody to
// display an approval prompt to, so callers can decline without waiting on
// the broker. Any error (including errNoConsoleUser) is treated as decline.
func interactiveUserError() error {
_, err := consoleUserID()
return err
}

@@ -1,7 +0,0 @@
//go:build !darwin && !windows
package server
// interactiveUserError is unused outside service mode (darwin/windows) but
// the symbol must exist so gateApproval compiles on all platforms.
func interactiveUserError() error { return nil }

@@ -1,15 +0,0 @@
package server
// interactiveUserError returns nil when there is a logged-in user session
// on the box. At the lock/login screen WTSQueryUserName is empty, which
// means there is nobody to display an approval prompt to.
func interactiveUserError() error {
sid := getActiveSessionID()
if sid == 0 {
return errNoConsoleUser
}
if !wtsSessionHasUser(sid) {
return errNoConsoleUser
}
return nil
}

@@ -1,228 +0,0 @@
//go:build !js && !ios && !android
package server
import (
"hash/maphash"
"image"
)
// copyRectDetector finds tiles in the current frame that match the content
// of some tile-aligned region of the previous frame, so we can emit them as
// CopyRect rectangles (16 wire bytes) instead of re-encoding the pixels.
//
// The detector keeps two structures:
// - tileHash, a flat slice of one hash per tile-aligned position, used as
// the source of truth for the previous frame's tile content.
// - prevTiles, a hash → position lookup used during findTileMatch.
//
// updateDirty rehashes only the tiles that changed this frame, so the
// steady-state cost is proportional to the dirty set, not the framebuffer.
// A full rebuild from scratch is only done on the first frame or when the
// detector has not yet been initialized for the current resolution.
//
// Limitations:
// - Only tile-aligned source positions are considered, so a move is
// detected only when its offset is a whole multiple of tileSize on both
// axes. Moves by any other amount (e.g. a window dragged by 7 pixels, or
// a scroll by a non-multiple of tileSize) are not detected.
// - 64-bit maphash collisions are assumed not to happen. By the birthday
// bound, the chance of any collision among one frame's tile hashes is
// about tileCount² / 2^65, which is vanishingly small at typical
// resolutions; if we ever observe one we can fall back to a full
// memcmp verification.
type copyRectDetector struct {
seed maphash.Seed
tileSize int
w, h int
cols, rows int
// tileHash[ty*cols + tx] is the current hash of the tile at (tx, ty)
// in the previous frame. Lookup uses this to detect stale prevTiles
// entries: incremental updates may leave hash→pos entries pointing
// at a tile whose content has since changed.
tileHash []uint64
// prevTiles maps a tile hash to a (x, y) origin in the previous frame.
prevTiles map[uint64][2]int
// hash is reused across hash computations to keep the per-tile lookup
// path allocation-free.
hash maphash.Hash
}
func newCopyRectDetector(tileSize int) *copyRectDetector {
d := &copyRectDetector{
seed: maphash.MakeSeed(),
tileSize: tileSize,
prevTiles: make(map[uint64][2]int),
}
d.hash.SetSeed(d.seed)
return d
}
// resize ensures the per-tile tables match the given framebuffer size.
// Called from rebuild before each full hash sweep.
func (d *copyRectDetector) resize(w, h int) {
if d.w == w && d.h == h && d.tileHash != nil {
return
}
d.w, d.h = w, h
d.cols = w / d.tileSize
d.rows = h / d.tileSize
d.tileHash = make([]uint64, d.cols*d.rows)
}
// hashTile computes the 64-bit maphash of one tile-aligned tile of frame.
func (d *copyRectDetector) hashTile(frame *image.RGBA, tx, ty int) uint64 {
d.hash.Reset()
ts := d.tileSize
stride := frame.Stride
rowBytes := ts * 4
base := ty*stride + tx*4
for row := 0; row < ts; row++ {
off := base + row*stride
_, _ = d.hash.Write(frame.Pix[off : off+rowBytes])
}
return d.hash.Sum64()
}
// rebuild discards everything and rehashes the whole frame. O(w*h). Use
// for the first frame or after the detector has been resized. Steady-state
// updates should go through updateDirty instead.
func (d *copyRectDetector) rebuild(frame *image.RGBA, w, h int) {
d.resize(w, h)
if d.prevTiles == nil {
d.prevTiles = make(map[uint64][2]int)
} else {
clear(d.prevTiles)
}
ts := d.tileSize
for ty := 0; ty+ts <= h; ty += ts {
for tx := 0; tx+ts <= w; tx += ts {
sum := d.hashTile(frame, tx, ty)
d.tileHash[(ty/ts)*d.cols+(tx/ts)] = sum
if _, exists := d.prevTiles[sum]; !exists {
d.prevTiles[sum] = [2]int{tx, ty}
}
}
}
}
// updateDirty rehashes only the tiles named in dirty. Each entry is
// [x, y, w, h]; entries whose w or h differs from tileSize, or that extend
// past the frame, are skipped. The work is O(len(dirty)) tile hashes, which
// in the common case is a tiny fraction of the whole framebuffer.
//
// A prevTiles entry is overwritten when its hash is seen again (latest-wins
// rather than first-wins), so a newly hashed tile claims the slot. Stale
// entries pointing at tiles that no longer carry that hash are filtered
// at lookup time via tileHash.
func (d *copyRectDetector) updateDirty(frame *image.RGBA, w, h int, dirty [][4]int) {
if d.w != w || d.h != h || d.tileHash == nil {
d.rebuild(frame, w, h)
return
}
ts := d.tileSize
for _, r := range dirty {
if r[2] != ts || r[3] != ts {
continue
}
tx, ty := r[0], r[1]
if tx+ts > w || ty+ts > h {
continue
}
sum := d.hashTile(frame, tx, ty)
d.tileHash[(ty/ts)*d.cols+(tx/ts)] = sum
// Latest-wins on collision: ensures the most recent owner of this
// hash is the one we'll return on lookup. The previous owner's
// entry, if any, gets shadowed; if its content has changed it's
// stale anyway and findTileMatch's verification will skip it.
d.prevTiles[sum] = [2]int{tx, ty}
}
}
// findTileMatch hashes the current-frame tile at (dstX, dstY) and looks up
// its hash in the previous-frame map. Returns (srcX, srcY, true) when a
// matching tile-aligned tile exists at a different position whose stored
// hash still equals the requested hash (so the result is not stale).
func (d *copyRectDetector) findTileMatch(cur *image.RGBA, dstX, dstY int) (int, int, bool) {
if len(d.prevTiles) == 0 || d.tileHash == nil {
return 0, 0, false
}
ts := d.tileSize
if dstX+ts > cur.Rect.Dx() || dstY+ts > cur.Rect.Dy() {
return 0, 0, false
}
sum := d.hashTile(cur, dstX, dstY)
pos, ok := d.prevTiles[sum]
if !ok {
return 0, 0, false
}
if pos[0] == dstX && pos[1] == dstY {
return 0, 0, false
}
// Reject source coords that fall outside the current framebuffer
// (frame may have shrunk since the source position was recorded). A
// CopyRect with an out-of-range source would have the client copy
// from undefined pixels, so drop the match and let the encoder send
// the rect normally.
if pos[0] < 0 || pos[1] < 0 || pos[0]+ts > cur.Rect.Dx() || pos[1]+ts > cur.Rect.Dy() {
return 0, 0, false
}
// Reject stale entries: the position the map points at must still
// carry the same hash according to our per-tile array.
hashIdx := (pos[1]/ts)*d.cols + pos[0]/ts
if hashIdx < 0 || hashIdx >= len(d.tileHash) {
return 0, 0, false
}
if d.tileHash[hashIdx] != sum {
return 0, 0, false
}
return pos[0], pos[1], true
}
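// copyRectMove is one tile-sized CopyRect candidate: the tile at
// (srcX, srcY) in the previous frame reappears at (dstX, dstY) in the
// current frame.
type copyRectMove struct {
srcX, srcY int
dstX, dstY int
}
// extractCopyRectTiles examines the diff-produced (per-tile) dirty list and
// pulls out any tiles whose current-frame content matches a prev-frame tile
// at a different position. Returns the CopyRect candidates and the residual
// dirty tiles that still need pixel encoding. remaining reuses the backing
// array of dirtyTiles, so dirtyTiles is overwritten in place.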
func (d *copyRectDetector) extractCopyRectTiles(cur *image.RGBA, dirtyTiles [][4]int) (moves []copyRectMove, remaining [][4]int) {
ts := d.tileSize
remaining = dirtyTiles[:0:cap(dirtyTiles)]
for _, r := range dirtyTiles {
if r[2] == ts && r[3] == ts {
if sx, sy, ok := d.findTileMatch(cur, r[0], r[1]); ok {
// The client applies moves sequentially against its live
// framebuffer. If this move's source overlaps the
// destination of any move already queued, that destination
// has overwritten the source pixels client-side, so the
// copy would read corrupted data. Drop it and let the tile
// fall through to normal pixel encoding instead.
if tileOverlapsPriorDst(moves, sx, sy, ts) {
remaining = append(remaining, r)
continue
}
moves = append(moves, copyRectMove{
srcX: sx, srcY: sy, dstX: r[0], dstY: r[1],
})
continue
}
}
remaining = append(remaining, r)
}
return moves, remaining
}
// tileOverlapsPriorDst reports whether the tileSize-square source rectangle
// at (srcX, srcY) intersects the destination rectangle of any move already
// emitted. All move rectangles are ts×ts, so the test reduces to a
// per-axis distance check.
func tileOverlapsPriorDst(moves []copyRectMove, srcX, srcY, ts int) bool {
for _, m := range moves {
dx := srcX - m.dstX
dy := srcY - m.dstY
if dx > -ts && dx < ts && dy > -ts && dy < ts {
return true
}
}
return false
}
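The overlap guard in `extractCopyRectTiles` exists because the client applies moves one at a time against its live framebuffer. The following sketch reduces the problem to one dimension, with one byte standing in for a tile, to show how a downward scroll corrupts data when every move is emitted and how dropping the overlapping move avoids it. The names `move`, `applyMoves`, and `dropOverlapping` are hypothetical and exist only for this illustration.

```go
package main

import "fmt"

// move copies one cell from src to dst; a 1-D stand-in for a tile copy.
type move struct{ src, dst int }

// applyMoves mimics a client applying moves sequentially against its
// live framebuffer, so later moves see the results of earlier ones.
func applyMoves(fb []byte, moves []move) {
	for _, m := range moves {
		fb[m.dst] = fb[m.src]
	}
}

// dropOverlapping keeps a move only if its source has not already been
// written by a kept move. Dropped destinations are returned as residual
// cells that would need normal pixel encoding.
func dropOverlapping(moves []move) (kept []move, residual []int) {
	written := map[int]bool{}
	for _, m := range moves {
		if written[m.src] {
			residual = append(residual, m.dst)
			continue
		}
		kept = append(kept, m)
		written[m.dst] = true
	}
	return kept, residual
}

func main() {
	prev := []byte("ABC")
	// Downward scroll by one cell: cell 1 <- cell 0, cell 2 <- cell 1.
	// The intended result for cells 1 and 2 is "AB".
	scroll := []move{{0, 1}, {1, 2}}

	naive := append([]byte(nil), prev...)
	applyMoves(naive, scroll)
	fmt.Printf("naive:   %s\n", naive) // AAA: cell 2 read the overwritten cell 1

	kept, residual := dropOverlapping(scroll)
	guarded := append([]byte(nil), prev...)
	applyMoves(guarded, kept)
	fmt.Printf("guarded: %s, re-encode cells %v\n", guarded, residual)
}
```

In the guarded run only the first move is kept, so cell 1 is correct and cell 2 is left for pixel encoding, mirroring what the downward-scroll test expects.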

@@ -1,225 +0,0 @@
//go:build !js && !ios && !android
package server
import (
"image"
"testing"
)
// fillTile paints the ts×ts block of img at (x, y) with the opaque colour
// (r, g, b) so the tests can construct tiles with distinct content.
func fillTile(img *image.RGBA, x, y, ts int, r, g, b byte) {
for row := 0; row < ts; row++ {
off := (y+row)*img.Stride + x*4
for col := 0; col < ts; col++ {
img.Pix[off+col*4+0] = r
img.Pix[off+col*4+1] = g
img.Pix[off+col*4+2] = b
img.Pix[off+col*4+3] = 0xff
}
}
}
// copyTile copies a tileSize×tileSize block from src(sx,sy) to dst(dx,dy).
func copyTile(dst, src *image.RGBA, sx, sy, dx, dy, ts int) {
for row := 0; row < ts; row++ {
srcOff := (sy+row)*src.Stride + sx*4
dstOff := (dy+row)*dst.Stride + dx*4
copy(dst.Pix[dstOff:dstOff+ts*4], src.Pix[srcOff:srcOff+ts*4])
}
}
func TestCopyRectDetector_DetectsVerticalScroll(t *testing.T) {
const w, h = 256, 192 // 4×3 tiles at 64px
const ts = 64
prev := image.NewRGBA(image.Rect(0, 0, w, h))
cur := image.NewRGBA(image.Rect(0, 0, w, h))
// prev: 12 tiles each with a unique colour.
for ty := 0; ty < 3; ty++ {
for tx := 0; tx < 4; tx++ {
fillTile(prev, tx*ts, ty*ts, ts, byte(tx*40), byte(ty*60), 0x80)
}
}
// cur: simulate an upward scroll by one tile row: every tile in the top
// two rows is copied from the row below it in prev; the bottom row is new
// content.
for ty := 0; ty < 2; ty++ {
for tx := 0; tx < 4; tx++ {
copyTile(cur, prev, tx*ts, (ty+1)*ts, tx*ts, ty*ts, ts)
}
}
// Bottom row of cur: new colour, not a match.
for tx := 0; tx < 4; tx++ {
fillTile(cur, tx*ts, 2*ts, ts, 0xff, 0xff, 0xff)
}
d := newCopyRectDetector(ts)
d.rebuild(prev, w, h)
tiles := diffTiles(prev, cur, w, h, ts)
moves, remaining := d.extractCopyRectTiles(cur, tiles)
// Expect 8 CopyRect moves (top two rows) and 4 residual tiles (bottom row).
if len(moves) != 8 {
t.Fatalf("moves: want 8, got %d", len(moves))
}
if len(remaining) != 4 {
t.Fatalf("remaining: want 4, got %d", len(remaining))
}
// Spot-check one move: cur (0, 0) should map to prev (0, 64).
var found bool
for _, m := range moves {
if m.dstX == 0 && m.dstY == 0 {
if m.srcX != 0 || m.srcY != ts {
t.Fatalf("move at (0,0): src=(%d,%d), want (0,%d)", m.srcX, m.srcY, ts)
}
found = true
}
}
if !found {
t.Fatalf("no move for dst (0,0)")
}
}
// tilesOverlap reports whether two ts×ts tiles at the given origins overlap.
func tilesOverlap(ax, ay, bx, by, ts int) bool {
return ax < bx+ts && bx < ax+ts && ay < by+ts && by < ay+ts
}
// TestCopyRectDetector_DownwardScrollNoOverlap exercises a downward scroll,
// where each move's source is the destination of the move one row above it.
// Emitting all of them in order would corrupt the client framebuffer because
// the earlier move overwrites the source pixels the later move reads. The
// detector must drop any move whose source overlaps a prior move's
// destination and route that tile to pixel encoding instead.
func TestCopyRectDetector_DownwardScrollNoOverlap(t *testing.T) {
const w, h = 256, 192 // 4×3 tiles at 64px
const ts = 64
prev := image.NewRGBA(image.Rect(0, 0, w, h))
cur := image.NewRGBA(image.Rect(0, 0, w, h))
// prev: 12 tiles each with a unique colour.
for ty := 0; ty < 3; ty++ {
for tx := 0; tx < 4; tx++ {
fillTile(prev, tx*ts, ty*ts, ts, byte(tx*40), byte(ty*60), 0x80)
}
}
// cur: scroll downward by one row. Rows 1 and 2 are copied from prev
// rows 0 and 1; the top row is new content.
for ty := 1; ty < 3; ty++ {
for tx := 0; tx < 4; tx++ {
copyTile(cur, prev, tx*ts, (ty-1)*ts, tx*ts, ty*ts, ts)
}
}
for tx := 0; tx < 4; tx++ {
fillTile(cur, tx*ts, 0, ts, 0xff, 0xff, 0xff)
}
d := newCopyRectDetector(ts)
d.rebuild(prev, w, h)
tiles := diffTiles(prev, cur, w, h, ts)
wantTiles := len(tiles)
moves, remaining := d.extractCopyRectTiles(cur, tiles)
// No move's source may overlap an earlier move's destination.
for i, m := range moves {
for _, prior := range moves[:i] {
if tilesOverlap(m.srcX, m.srcY, prior.dstX, prior.dstY, ts) {
t.Fatalf("move %d src (%d,%d) overlaps prior dst (%d,%d)",
i, m.srcX, m.srcY, prior.dstX, prior.dstY)
}
}
}
// The dropped row-2 moves must fall through to pixel encoding rather than
// being silently skipped, so the region still updates correctly.
if len(moves)+len(remaining) != wantTiles {
t.Fatalf("moves(%d)+remaining(%d) != dirty tiles(%d): a tile was lost",
len(moves), len(remaining), wantTiles)
}
if len(moves) != 4 {
t.Fatalf("moves: want 4 (top scrolled row only), got %d", len(moves))
}
}
func TestCopyRectDetector_RejectsSelfMatch(t *testing.T) {
const w, h = 128, 128
const ts = 64
prev := image.NewRGBA(image.Rect(0, 0, w, h))
cur := image.NewRGBA(image.Rect(0, 0, w, h))
// prev: 4 tiles, all unique
fillTile(prev, 0, 0, ts, 0x10, 0x20, 0x30)
fillTile(prev, ts, 0, ts, 0x40, 0x50, 0x60)
fillTile(prev, 0, ts, ts, 0x70, 0x80, 0x90)
fillTile(prev, ts, ts, ts, 0xa0, 0xb0, 0xc0)
// cur: tile (0,0) is unchanged; the other three are repainted with new
// content that matches no tile in prev.
fillTile(cur, 0, 0, ts, 0x10, 0x20, 0x30) // self-match
fillTile(cur, ts, 0, ts, 0xff, 0xff, 0xff)
fillTile(cur, 0, ts, ts, 0xff, 0xff, 0xff)
fillTile(cur, ts, ts, ts, 0xff, 0xff, 0xff)
d := newCopyRectDetector(ts)
d.rebuild(prev, w, h)
// Tile (0,0) is not in the dirty list (it's unchanged) so it should not
// produce a move even though its hash matches prev (0,0).
tiles := diffTiles(prev, cur, w, h, ts)
moves, _ := d.extractCopyRectTiles(cur, tiles)
for _, m := range moves {
if m.dstX == 0 && m.dstY == 0 {
t.Fatalf("unexpected move at (0,0)")
}
}
}
func TestCopyRectDetector_PassThroughWhenNoMatch(t *testing.T) {
const w, h = 64, 64
const ts = 64
prev := image.NewRGBA(image.Rect(0, 0, w, h))
cur := image.NewRGBA(image.Rect(0, 0, w, h))
fillTile(prev, 0, 0, ts, 0x11, 0x22, 0x33)
fillTile(cur, 0, 0, ts, 0xaa, 0xbb, 0xcc) // wholly different
d := newCopyRectDetector(ts)
d.rebuild(prev, w, h)
tiles := diffTiles(prev, cur, w, h, ts)
moves, remaining := d.extractCopyRectTiles(cur, tiles)
if len(moves) != 0 {
t.Fatalf("expected 0 moves, got %d", len(moves))
}
if len(remaining) != 1 {
t.Fatalf("expected 1 residual tile, got %d", len(remaining))
}
}
func TestEncodeCopyRectBody_Layout(t *testing.T) {
got := encodeCopyRectBody(100, 200, 300, 400, 64, 48)
if len(got) != 16 {
t.Fatalf("CopyRect body length: want 16, got %d", len(got))
}
// Dest position
if got[0] != 0x01 || got[1] != 0x2c || got[2] != 0x01 || got[3] != 0x90 {
t.Fatalf("bad dest bytes: % x", got[0:4])
}
// Width, height
if got[4] != 0 || got[5] != 64 || got[6] != 0 || got[7] != 48 {
t.Fatalf("bad size bytes: % x", got[4:8])
}
// Encoding type 1 (CopyRect), stored big-endian in bytes 8..11.
if got[8] != 0 || got[9] != 0 || got[10] != 0 || got[11] != 0x01 {
t.Fatalf("bad encoding bytes: % x", got[8:12])
}
// Source position
if got[12] != 0 || got[13] != 100 || got[14] != 0 || got[15] != 200 {
t.Fatalf("bad src bytes: % x", got[12:16])
}
}
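The layout test pins down a 16-byte body: destination x and y, width and height, a 32-bit encoding type, then source x and y. `encodeCopyRectBody` itself is not shown in this diff, so the sketch below is only an encoder that is consistent with those assertions. It assumes the argument order (srcX, srcY, dstX, dstY, w, h) inferred from the test call and big-endian fields throughout; `copyRectBodySketch` is a hypothetical name.

```go
package main

import (
	"encoding/binary"
	"fmt"
)

// copyRectBodySketch builds a 16-byte CopyRect rectangle: a 12-byte
// rectangle header followed by the 4-byte source position.
// Hypothetical stand-in for encodeCopyRectBody, which is not shown here.
func copyRectBodySketch(srcX, srcY, dstX, dstY, w, h int) []byte {
	buf := make([]byte, 16)
	binary.BigEndian.PutUint16(buf[0:2], uint16(dstX))
	binary.BigEndian.PutUint16(buf[2:4], uint16(dstY))
	binary.BigEndian.PutUint16(buf[4:6], uint16(w))
	binary.BigEndian.PutUint16(buf[6:8], uint16(h))
	binary.BigEndian.PutUint32(buf[8:12], 1) // encoding type 1 = CopyRect
	binary.BigEndian.PutUint16(buf[12:14], uint16(srcX))
	binary.BigEndian.PutUint16(buf[14:16], uint16(srcY))
	return buf
}

func main() {
	body := copyRectBodySketch(100, 200, 300, 400, 64, 48)
	fmt.Printf("% x\n", body)
	// 01 2c 01 90 00 40 00 30 00 00 00 01 00 64 00 c8
}
```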

Some files were not shown because too many files have changed in this diff