Compare commits

..

1 Commits

Author SHA1 Message Date
Maycon Santos
5207054701 [management] Enable lazy connections by default on new accounts 2026-06-28 15:42:59 +02:00
182 changed files with 4235 additions and 25783 deletions

View File

@@ -45,7 +45,7 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -coverprofile=coverage.txt -tags 'devcert privileged' -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/testutil/privileged)
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -coverprofile=coverage.txt -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0

View File

@@ -48,14 +48,14 @@ jobs:
export PATH=$PATH:/usr/local/go/bin:$HOME/go/bin
time go build -o netbird client/main.go
# check all component except management, since we do not support management server on freebsd
time go test -tags privileged -timeout 1m -failfast ./base62/...
time go test -timeout 1m -failfast ./base62/...
# NOTE: without -p1 `client/internal/dns` will fail because of `listen udp4 :33100: bind: address already in use`
time go test -tags privileged -timeout 8m -failfast -v -p 1 ./client/...
time go test -tags privileged -timeout 1m -failfast ./dns/...
time go test -tags privileged -timeout 1m -failfast ./encryption/...
time go test -tags privileged -timeout 1m -failfast ./formatter/...
time go test -tags privileged -timeout 1m -failfast ./client/iface/...
time go test -tags privileged -timeout 1m -failfast ./route/...
time go test -tags privileged -timeout 1m -failfast ./sharedsock/...
time go test -tags privileged -timeout 1m -failfast ./util/...
time go test -tags privileged -timeout 1m -failfast ./version/...
time go test -timeout 8m -failfast -v -p 1 ./client/...
time go test -timeout 1m -failfast ./dns/...
time go test -timeout 1m -failfast ./encryption/...
time go test -timeout 1m -failfast ./formatter/...
time go test -timeout 1m -failfast ./client/iface/...
time go test -timeout 1m -failfast ./route/...
time go test -timeout 1m -failfast ./sharedsock/...
time go test -timeout 1m -failfast ./util/...
time go test -timeout 1m -failfast ./version/...

View File

@@ -158,7 +158,7 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -coverprofile=coverage.txt -tags devcert -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -coverprofile=coverage.txt -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
- name: Upload coverage reports to Codecov
if: matrix.arch == 'amd64'
@@ -229,7 +229,7 @@ jobs:
sh -c ' \
apk update; apk add --no-cache \
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
go test -buildvcs=false -tags "devcert privileged" -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server -e /client/testutil/privileged)
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
'
test_relay:

View File

@@ -68,7 +68,7 @@ jobs:
run: |
$packages = go list ./... | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' }
$goExe = "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe"
$cmd = "$goExe test -tags `"devcert privileged`" -timeout 10m -p 1 $($packages -join ' ') > test-out.txt 2>&1"
$cmd = "$goExe test -tags=devcert -timeout 10m -p 1 $($packages -join ' ') > test-out.txt 2>&1"
Set-Content -Path "${{ github.workspace }}\run-tests.cmd" -Value $cmd
- name: test

View File

@@ -1,4 +1,4 @@
.PHONY: lint lint-all lint-install setup-hooks test-unit test-privileged
.PHONY: lint lint-all lint-install setup-hooks
GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
# Install golangci-lint locally if needed
@@ -25,15 +25,3 @@ setup-hooks:
@git config core.hooksPath .githooks
@chmod +x .githooks/pre-push
@echo "✅ Git hooks configured! Pre-push will now run 'make lint'"
# Host-safe unit tests: excludes the privileged-tagged tests (root / system-mutating).
# Runs as a normal user with no sudo and leaves host networking untouched.
test-unit:
@go test -tags devcert -timeout 10m ./...
# Privileged suite: runs the `privileged`-tagged tests inside a --privileged
# --cap-add=NET_ADMIN container via the ory/dockertest harness. Requires Docker.
# Narrow the run with env vars, e.g.:
# PRIV_RUN=TestNftablesManager PRIV_PKGS=./client/firewall/nftables/... make test-privileged
test-privileged:
@go test -tags 'devcert privileged' -timeout 30m -run TestRunPrivilegedSuiteInDocker -v ./client/testutil/privileged/...

View File

@@ -1,196 +0,0 @@
//go:build privileged
package cmd
import (
"context"
"fmt"
"os"
"runtime"
"testing"
"time"
"github.com/kardianos/service"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
serviceStartTimeout = 10 * time.Second
serviceStopTimeout = 5 * time.Second
statusPollInterval = 500 * time.Millisecond
)
// waitForServiceStatus waits for service to reach expected status with timeout
func waitForServiceStatus(expectedStatus service.Status, timeout time.Duration) (bool, error) {
cfg, err := newSVCConfig()
if err != nil {
return false, err
}
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
if err != nil {
return false, err
}
ctx, timeoutCancel := context.WithTimeout(context.Background(), timeout)
defer timeoutCancel()
ticker := time.NewTicker(statusPollInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return false, fmt.Errorf("timeout waiting for service status %v", expectedStatus)
case <-ticker.C:
status, err := s.Status()
if err != nil {
// Continue polling on transient errors
continue
}
if status == expectedStatus {
return true, nil
}
}
}
}
// TestServiceLifecycle tests the complete service lifecycle
func TestServiceLifecycle(t *testing.T) {
// TODO: Add support for Windows and macOS
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
t.Skipf("Skipping service lifecycle test on unsupported OS: %s", runtime.GOOS)
}
if os.Getenv("CONTAINER") == "true" {
t.Skip("Skipping service lifecycle test in container environment")
}
originalServiceName := serviceName
serviceName = "netbirdtest" + fmt.Sprintf("%d", time.Now().Unix())
defer func() {
serviceName = originalServiceName
}()
tempDir := t.TempDir()
configPath = fmt.Sprintf("%s/netbird-test-config.json", tempDir)
logLevel = "info"
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
// Ensure cleanup even if a subtest fails and Stop/Uninstall subtests don't run.
t.Cleanup(func() {
cfg, err := newSVCConfig()
if err != nil {
t.Errorf("cleanup: create service config: %v", err)
return
}
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
if err != nil {
t.Errorf("cleanup: create service: %v", err)
return
}
// If the subtests already cleaned up, there's nothing to do.
if _, err := s.Status(); err != nil {
return
}
if err := s.Stop(); err != nil {
t.Errorf("cleanup: stop service: %v", err)
}
if err := s.Uninstall(); err != nil {
t.Errorf("cleanup: uninstall service: %v", err)
}
})
ctx := context.Background()
t.Run("Install", func(t *testing.T) {
installCmd.SetContext(ctx)
err := installCmd.RunE(installCmd, []string{})
require.NoError(t, err)
cfg, err := newSVCConfig()
require.NoError(t, err)
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
require.NoError(t, err)
status, err := s.Status()
assert.NoError(t, err)
assert.NotEqual(t, service.StatusUnknown, status)
})
t.Run("Start", func(t *testing.T) {
startCmd.SetContext(ctx)
err := startCmd.RunE(startCmd, []string{})
require.NoError(t, err)
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
require.NoError(t, err)
assert.True(t, running)
})
t.Run("Restart", func(t *testing.T) {
restartCmd.SetContext(ctx)
err := restartCmd.RunE(restartCmd, []string{})
require.NoError(t, err)
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
require.NoError(t, err)
assert.True(t, running)
})
t.Run("Reconfigure", func(t *testing.T) {
originalLogLevel := logLevel
logLevel = "debug"
defer func() {
logLevel = originalLogLevel
}()
reconfigureCmd.SetContext(ctx)
err := reconfigureCmd.RunE(reconfigureCmd, []string{})
require.NoError(t, err)
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
require.NoError(t, err)
assert.True(t, running)
})
t.Run("Stop", func(t *testing.T) {
stopCmd.SetContext(ctx)
err := stopCmd.RunE(stopCmd, []string{})
require.NoError(t, err)
stopped, err := waitForServiceStatus(service.StatusStopped, serviceStopTimeout)
require.NoError(t, err)
assert.True(t, stopped)
})
t.Run("Uninstall", func(t *testing.T) {
uninstallCmd.SetContext(ctx)
err := uninstallCmd.RunE(uninstallCmd, []string{})
require.NoError(t, err)
cfg, err := newSVCConfig()
require.NoError(t, err)
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
require.NoError(t, err)
_, err = s.Status()
assert.Error(t, err)
})
}

View File

@@ -1,12 +1,16 @@
package cmd
import (
"context"
"fmt"
"os"
"os/signal"
"runtime"
"syscall"
"testing"
"time"
"github.com/kardianos/service"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -27,6 +31,186 @@ func TestMain(m *testing.M) {
os.Exit(m.Run())
}
const (
serviceStartTimeout = 10 * time.Second
serviceStopTimeout = 5 * time.Second
statusPollInterval = 500 * time.Millisecond
)
// waitForServiceStatus waits for service to reach expected status with timeout
func waitForServiceStatus(expectedStatus service.Status, timeout time.Duration) (bool, error) {
cfg, err := newSVCConfig()
if err != nil {
return false, err
}
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
if err != nil {
return false, err
}
ctx, timeoutCancel := context.WithTimeout(context.Background(), timeout)
defer timeoutCancel()
ticker := time.NewTicker(statusPollInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return false, fmt.Errorf("timeout waiting for service status %v", expectedStatus)
case <-ticker.C:
status, err := s.Status()
if err != nil {
// Continue polling on transient errors
continue
}
if status == expectedStatus {
return true, nil
}
}
}
}
// TestServiceLifecycle tests the complete service lifecycle
func TestServiceLifecycle(t *testing.T) {
// TODO: Add support for Windows and macOS
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
t.Skipf("Skipping service lifecycle test on unsupported OS: %s", runtime.GOOS)
}
if os.Getenv("CONTAINER") == "true" {
t.Skip("Skipping service lifecycle test in container environment")
}
originalServiceName := serviceName
serviceName = "netbirdtest" + fmt.Sprintf("%d", time.Now().Unix())
defer func() {
serviceName = originalServiceName
}()
tempDir := t.TempDir()
configPath = fmt.Sprintf("%s/netbird-test-config.json", tempDir)
logLevel = "info"
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
// Ensure cleanup even if a subtest fails and Stop/Uninstall subtests don't run.
t.Cleanup(func() {
cfg, err := newSVCConfig()
if err != nil {
t.Errorf("cleanup: create service config: %v", err)
return
}
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
if err != nil {
t.Errorf("cleanup: create service: %v", err)
return
}
// If the subtests already cleaned up, there's nothing to do.
if _, err := s.Status(); err != nil {
return
}
if err := s.Stop(); err != nil {
t.Errorf("cleanup: stop service: %v", err)
}
if err := s.Uninstall(); err != nil {
t.Errorf("cleanup: uninstall service: %v", err)
}
})
ctx := context.Background()
t.Run("Install", func(t *testing.T) {
installCmd.SetContext(ctx)
err := installCmd.RunE(installCmd, []string{})
require.NoError(t, err)
cfg, err := newSVCConfig()
require.NoError(t, err)
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
require.NoError(t, err)
status, err := s.Status()
assert.NoError(t, err)
assert.NotEqual(t, service.StatusUnknown, status)
})
t.Run("Start", func(t *testing.T) {
startCmd.SetContext(ctx)
err := startCmd.RunE(startCmd, []string{})
require.NoError(t, err)
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
require.NoError(t, err)
assert.True(t, running)
})
t.Run("Restart", func(t *testing.T) {
restartCmd.SetContext(ctx)
err := restartCmd.RunE(restartCmd, []string{})
require.NoError(t, err)
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
require.NoError(t, err)
assert.True(t, running)
})
t.Run("Reconfigure", func(t *testing.T) {
originalLogLevel := logLevel
logLevel = "debug"
defer func() {
logLevel = originalLogLevel
}()
reconfigureCmd.SetContext(ctx)
err := reconfigureCmd.RunE(reconfigureCmd, []string{})
require.NoError(t, err)
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
require.NoError(t, err)
assert.True(t, running)
})
t.Run("Stop", func(t *testing.T) {
stopCmd.SetContext(ctx)
err := stopCmd.RunE(stopCmd, []string{})
require.NoError(t, err)
stopped, err := waitForServiceStatus(service.StatusStopped, serviceStopTimeout)
require.NoError(t, err)
assert.True(t, stopped)
})
t.Run("Uninstall", func(t *testing.T) {
uninstallCmd.SetContext(ctx)
err := uninstallCmd.RunE(uninstallCmd, []string{})
require.NoError(t, err)
cfg, err := newSVCConfig()
require.NoError(t, err)
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
require.NoError(t, err)
_, err = s.Status()
assert.Error(t, err)
})
}
// TestServiceEnvVars tests environment variable parsing
func TestServiceEnvVars(t *testing.T) {
tests := []struct {

View File

@@ -401,12 +401,6 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
if cmd.Flag(serverSSHAllowedFlag).Changed {
req.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(serverVNCAllowedFlag).Changed {
req.ServerVNCAllowed = &serverVNCAllowed
}
if cmd.Flag(disableVNCApprovalFlag).Changed {
req.DisableVNCApproval = &disableVNCApproval
}
if cmd.Flag(enableSSHRootFlag).Changed {
req.EnableSSHRoot = &enableSSHRoot
}
@@ -513,14 +507,30 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
if cmd.Flag(serverSSHAllowedFlag).Changed {
ic.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(serverVNCAllowedFlag).Changed {
ic.ServerVNCAllowed = &serverVNCAllowed
}
if cmd.Flag(disableVNCApprovalFlag).Changed {
ic.DisableVNCApproval = &disableVNCApproval
if cmd.Flag(enableSSHRootFlag).Changed {
ic.EnableSSHRoot = &enableSSHRoot
}
applySSHFlagsToConfig(cmd, &ic)
if cmd.Flag(enableSSHSFTPFlag).Changed {
ic.EnableSSHSFTP = &enableSSHSFTP
}
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
ic.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
}
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
ic.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
}
if cmd.Flag(disableSSHAuthFlag).Changed {
ic.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
ic.SSHJWTCacheTTL = &sshJWTCacheTTL
}
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil {
@@ -596,49 +606,6 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
return &ic, nil
}
func applySSHFlagsToConfig(cmd *cobra.Command, ic *profilemanager.ConfigInput) {
if cmd.Flag(enableSSHRootFlag).Changed {
ic.EnableSSHRoot = &enableSSHRoot
}
if cmd.Flag(enableSSHSFTPFlag).Changed {
ic.EnableSSHSFTP = &enableSSHSFTP
}
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
ic.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
}
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
ic.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
}
if cmd.Flag(disableSSHAuthFlag).Changed {
ic.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
ic.SSHJWTCacheTTL = &sshJWTCacheTTL
}
}
func applySSHFlagsToLogin(cmd *cobra.Command, req *proto.LoginRequest) {
if cmd.Flag(enableSSHRootFlag).Changed {
req.EnableSSHRoot = &enableSSHRoot
}
if cmd.Flag(enableSSHSFTPFlag).Changed {
req.EnableSSHSFTP = &enableSSHSFTP
}
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
req.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
}
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
req.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
}
if cmd.Flag(disableSSHAuthFlag).Changed {
req.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
ttl := int32(sshJWTCacheTTL)
req.SshJWTCacheTTL = &ttl
}
}
func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte, cmd *cobra.Command) (*proto.LoginRequest, error) {
loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey,
@@ -668,14 +635,31 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
if cmd.Flag(serverSSHAllowedFlag).Changed {
loginRequest.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(serverVNCAllowedFlag).Changed {
loginRequest.ServerVNCAllowed = &serverVNCAllowed
}
if cmd.Flag(disableVNCApprovalFlag).Changed {
loginRequest.DisableVNCApproval = &disableVNCApproval
if cmd.Flag(enableSSHRootFlag).Changed {
loginRequest.EnableSSHRoot = &enableSSHRoot
}
applySSHFlagsToLogin(cmd, &loginRequest)
if cmd.Flag(enableSSHSFTPFlag).Changed {
loginRequest.EnableSSHSFTP = &enableSSHSFTP
}
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
loginRequest.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
}
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
loginRequest.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
}
if cmd.Flag(disableSSHAuthFlag).Changed {
loginRequest.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
loginRequest.SshJWTCacheTTL = &sshJWTCacheTTL32
}
if cmd.Flag(disableAutoConnectFlag).Changed {
loginRequest.DisableAutoConnect = &autoConnectDisabled

View File

@@ -1,100 +0,0 @@
//go:build windows || (darwin && !ios)
package cmd
import (
"fmt"
"net"
"net/netip"
"os"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
)
var (
vncAgentSocket string
vncAgentTargetUID uint32
)
func init() {
vncAgentCmd.Flags().StringVar(&vncAgentSocket, "socket", "", "Unix-domain socket path the agent listens on (required)")
vncAgentCmd.Flags().Uint32Var(&vncAgentTargetUID, "target-uid", 0, "uid the agent should drop privileges to before listening (darwin only; 0 = stay as current uid)")
rootCmd.AddCommand(vncAgentCmd)
}
// vncAgentCmd runs a VNC server inside the user's interactive session,
// listening on a Unix-domain socket. The NetBird service spawns it: on
// Windows via CreateProcessAsUser into the console session, on macOS via
// launchctl asuser into the Aqua session.
var vncAgentCmd = &cobra.Command{
Use: "vnc-agent",
Short: "Run VNC capture agent (internal, spawned by service)",
Hidden: true,
RunE: func(cmd *cobra.Command, args []string) error {
log.SetReportCaller(true)
log.SetFormatter(&log.JSONFormatter{})
log.SetOutput(os.Stderr)
if vncAgentSocket == "" {
return fmt.Errorf("--socket is required")
}
token := os.Getenv("NB_VNC_AGENT_TOKEN")
if token == "" {
return fmt.Errorf("NB_VNC_AGENT_TOKEN not set; agent requires a token from the service")
}
// Purge the token from env so it doesn't leak via /proc/<pid>/environ.
if err := os.Unsetenv("NB_VNC_AGENT_TOKEN"); err != nil {
log.Debugf("unset NB_VNC_AGENT_TOKEN: %v", err)
}
// Drop root privileges to the target console user BEFORE creating
// the listening socket: keeps a post-auth bug in the encoder /
// input / capture paths confined to the user's own privileges
// rather than escalating to host root, and makes the daemon's
// LOCAL_PEERCRED check see the right uid. No-op on Windows
// (both processes run as SYSTEM) and when --target-uid is 0.
if vncAgentTargetUID != 0 {
if err := dropAgentPrivileges(vncAgentTargetUID); err != nil {
return fmt.Errorf("drop privileges to uid %d: %w", vncAgentTargetUID, err)
}
}
if err := os.Remove(vncAgentSocket); err != nil && !os.IsNotExist(err) {
log.Debugf("remove stale socket %s: %v", vncAgentSocket, err)
}
ln, err := net.Listen("unix", vncAgentSocket)
if err != nil {
return fmt.Errorf("listen on %s: %w", vncAgentSocket, err)
}
if err := os.Chmod(vncAgentSocket, 0o600); err != nil {
log.Debugf("chmod %s: %v", vncAgentSocket, err)
}
capturer, injector, err := newAgentResources()
if err != nil {
_ = ln.Close()
return err
}
srv := vncserver.New(vncserver.Config{
Capturer: capturer,
Injector: injector,
DisableAuth: true,
AgentTokenHex: token,
Listener: ln,
})
if err := srv.Start(cmd.Context(), netip.AddrPort{}, netip.Prefix{}); err != nil {
return fmt.Errorf("start vnc server: %w", err)
}
log.Infof("vnc-agent listening on %s, ready", vncAgentSocket)
<-cmd.Context().Done()
log.Info("vnc-agent context cancelled, shutting down")
return srv.Stop()
},
SilenceUsage: true,
}

View File

@@ -1,18 +0,0 @@
//go:build darwin && !ios
package cmd
import (
"fmt"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
)
func newAgentResources() (vncserver.ScreenCapturer, vncserver.InputInjector, error) {
capturer := vncserver.NewMacPoller()
injector, err := vncserver.NewMacInputInjector()
if err != nil {
return nil, nil, fmt.Errorf("macOS input injector: %w", err)
}
return capturer, injector, nil
}

View File

@@ -1,74 +0,0 @@
//go:build darwin && !ios
package cmd
import (
"fmt"
"os"
"os/user"
"strconv"
"syscall"
)
// dropAgentPrivileges drops the vnc-agent process from root (its
// launchctl-asuser-inherited starting uid) to the target console user
// before any other initialisation runs. Without this the agent runs as
// root for the lifetime of the session; any post-auth memory-safety
// issue in the capture/input/encode paths would then be a root-level
// RCE on the host instead of a user-level one. Also makes the daemon's
// LOCAL_PEERCRED check correctly identify the agent as the console user,
// not as root.
//
// Returns an error when the agent is running as a non-root uid that
// differs from targetUID: non-root can only setuid to itself, so a
// mismatch here means the spawn went to the wrong session.
func dropAgentPrivileges(targetUID uint32) error {
if targetUID == 0 {
return fmt.Errorf("refusing to keep agent running as root (target uid 0)")
}
cur := uint32(os.Getuid())
if cur == targetUID {
return nil
}
if cur != 0 {
return fmt.Errorf("agent uid %d does not match expected %d and we lack root to fix it", cur, targetUID)
}
// Resolve the target user's real primary group rather than reusing
// targetUID as the gid: a user's primary group on macOS is typically
// staff(20), not gid==uid. Fail closed if the lookup fails.
targetGID, err := primaryGroupID(targetUID)
if err != nil {
return err
}
// Drop supplementary groups first: setgid alone doesn't touch the
// auxiliary group list, leaving root's groups attached would let the
// dropped process write to root-only group-writable files.
if err := syscall.Setgroups([]int{}); err != nil {
return fmt.Errorf("setgroups([]): %w", err)
}
if err := syscall.Setgid(targetGID); err != nil {
return fmt.Errorf("setgid(%d): %w", targetGID, err)
}
if err := syscall.Setuid(int(targetUID)); err != nil {
return fmt.Errorf("setuid(%d): %w", targetUID, err)
}
if uint32(os.Getuid()) != targetUID || uint32(os.Geteuid()) != targetUID {
return fmt.Errorf("setuid verification: uid=%d euid=%d, expected %d", os.Getuid(), os.Geteuid(), targetUID)
}
return nil
}
// primaryGroupID resolves the real primary group id of the user with the
// given uid. Fails closed: a lookup or parse error returns an error so the
// caller never falls back to using uid as the gid.
func primaryGroupID(targetUID uint32) (int, error) {
u, err := user.LookupId(strconv.Itoa(int(targetUID)))
if err != nil {
return 0, fmt.Errorf("look up uid %d: %w", targetUID, err)
}
gid, err := strconv.Atoi(u.Gid)
if err != nil {
return 0, fmt.Errorf("parse gid %q for uid %d: %w", u.Gid, targetUID, err)
}
return gid, nil
}

View File

@@ -1,55 +0,0 @@
//go:build darwin && !ios
package cmd
import (
"strings"
"testing"
)
// TestDropAgentPrivileges_RefusesRootTarget locks in the contract that
// dropAgentPrivileges must never be a no-op when asked to keep the
// agent as root (target uid 0). A future caller that passes 0 by
// mistake would otherwise leave the post-auth attack surface running
// with full root privileges.
func TestDropAgentPrivileges_RefusesRootTarget(t *testing.T) {
err := dropAgentPrivileges(0)
if err == nil {
t.Fatal("expected refusal for target uid 0, got nil")
}
if !strings.Contains(err.Error(), "root") {
t.Fatalf("error should mention root, got: %v", err)
}
}
// TestDropAgentPrivileges_NoOpWhenAlreadyTarget covers the dev path
// where the agent is launched by hand as the target user (no root
// available, no setuid needed). The helper must succeed silently
// instead of trying (and failing) a setuid to its current uid.
func TestDropAgentPrivileges_NoOpWhenAlreadyTarget(t *testing.T) {
// Skip when running as root: the early-return path we want to
// cover only fires when current uid == target uid.
uid := currentUIDForTest()
if uid == 0 {
t.Skip("test must not run as root; cannot exercise the no-op early-return")
}
if err := dropAgentPrivileges(uid); err != nil {
t.Fatalf("expected no-op when current uid == target, got: %v", err)
}
}
// TestDropAgentPrivileges_RefusesMismatchedNonRoot guards the "non-root
// caller tries to setuid to a different uid" path: setuid would fail
// with EPERM anyway, but the helper should surface a clear error
// before issuing the syscall so a misconfigured spawn (wrong --target-uid
// flag) is debuggable.
func TestDropAgentPrivileges_RefusesMismatchedNonRoot(t *testing.T) {
uid := currentUIDForTest()
if uid == 0 {
t.Skip("test must not run as root; covered case requires non-root caller")
}
err := dropAgentPrivileges(uid + 1)
if err == nil {
t.Fatal("expected refusal when non-root caller asks to setuid elsewhere")
}
}

View File

@@ -1,11 +0,0 @@
//go:build darwin && !ios
package cmd
import "os"
// currentUIDForTest exposes os.Getuid for the darwin dropprivs tests
// without leaking an os import into the test file itself.
func currentUIDForTest() uint32 {
return uint32(os.Getuid())
}

View File

@@ -1,14 +0,0 @@
//go:build windows
package cmd
// dropAgentPrivileges is a no-op on Windows: the agent and the daemon
// both run as SYSTEM (the daemon spawns the agent into the interactive
// session via CreateProcessAsUser with an impersonation token, but the
// resulting process still runs under SYSTEM, not under the user's
// account). The Windows path relies on the DACL-restricted socket
// directory, the unpredictable per-spawn socket name, the listen-readiness
// gate, and the per-spawn token for integrity instead.
func dropAgentPrivileges(_ uint32) error {
return nil
}

View File

@@ -1,15 +0,0 @@
//go:build windows
package cmd
import (
log "github.com/sirupsen/logrus"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
)
func newAgentResources() (vncserver.ScreenCapturer, vncserver.InputInjector, error) {
sessionID := vncserver.GetCurrentSessionID()
log.Infof("VNC agent running in Windows session %d", sessionID)
return vncserver.NewDesktopCapturer(), vncserver.NewWindowsInputInjector(), nil
}

View File

@@ -1,16 +0,0 @@
package cmd
const (
serverVNCAllowedFlag = "allow-server-vnc"
disableVNCApprovalFlag = "disable-vnc-approval"
)
var (
serverVNCAllowed bool
disableVNCApproval bool
)
func init() {
upCmd.PersistentFlags().BoolVar(&serverVNCAllowed, serverVNCAllowedFlag, false, "Allow embedded VNC server on peer")
upCmd.PersistentFlags().BoolVar(&disableVNCApproval, disableVNCApprovalFlag, false, "Disable per-connection user approval prompts for the embedded VNC server")
}

View File

@@ -6,30 +6,19 @@ import (
"runtime"
)
var (
// StateDir holds persistent state (config, profiles, install metadata).
StateDir string
// RuntimeDir holds ephemeral artifacts that should not survive reboot,
// such as Unix sockets for daemon and per-session IPC. Empty on
// platforms without a conventional /var/run-style location.
RuntimeDir string
)
var StateDir string
func init() {
StateDir = os.Getenv("NB_STATE_DIR")
if StateDir != "" {
return
}
switch runtime.GOOS {
case "windows":
StateDir = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird")
case "darwin", "linux":
StateDir = "/var/lib/netbird"
RuntimeDir = "/var/run/netbird"
case "freebsd", "openbsd", "netbsd", "dragonfly":
StateDir = "/var/db/netbird"
RuntimeDir = "/var/run/netbird"
}
if v := os.Getenv("NB_STATE_DIR"); v != "" {
StateDir = v
}
if v := os.Getenv("NB_RUNTIME_DIR"); v != "" {
RuntimeDir = v
}
}

View File

@@ -1,5 +1,3 @@
//go:build privileged
package iptables
import (

View File

@@ -1,4 +1,4 @@
//go:build !android && privileged
//go:build !android
package iptables

View File

@@ -1,5 +1,3 @@
//go:build privileged
package nftables
import (

View File

@@ -1,4 +1,4 @@
//go:build !android && privileged
//go:build !android
package nftables

View File

@@ -1,5 +1,3 @@
//go:build privileged
package iface
import (

View File

@@ -1,4 +1,4 @@
//go:build linux && !android && privileged
//go:build linux && !android
package wgproxy

View File

@@ -1,4 +1,4 @@
//go:build !linux || !privileged
//go:build !linux
package wgproxy

View File

@@ -1,4 +1,4 @@
//go:build linux && !android && privileged
//go:build linux && !android
package wgproxy
@@ -26,6 +26,64 @@ func compareUDPAddr(addr1, addr2 net.Addr) bool {
return udpAddr1.IP.Equal(udpAddr2.IP) && udpAddr1.Port == udpAddr2.Port
}
// TestRedirectAs_eBPF_IPv4 tests RedirectAs with eBPF proxy using IPv4 addresses
func TestRedirectAs_eBPF_IPv4(t *testing.T) {
wgPort := 51850
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %v", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %v", err)
}
}()
proxy := ebpf.NewProxyWrapper(ebpfProxy)
// NetBird UDP address of the remote peer
nbAddr := &net.UDPAddr{
IP: net.ParseIP("100.108.111.177"),
Port: 38746,
}
p2pEndpoint := &net.UDPAddr{
IP: net.ParseIP("192.168.0.56"),
Port: 51820,
}
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
}
// TestRedirectAs_eBPF_IPv6 tests RedirectAs with eBPF proxy using IPv6 addresses
func TestRedirectAs_eBPF_IPv6(t *testing.T) {
wgPort := 51851
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %v", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %v", err)
}
}()
proxy := ebpf.NewProxyWrapper(ebpfProxy)
// NetBird UDP address of the remote peer
nbAddr := &net.UDPAddr{
IP: net.ParseIP("100.108.111.177"),
Port: 38746,
}
p2pEndpoint := &net.UDPAddr{
IP: net.ParseIP("fe80::56"),
Port: 51820,
}
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
}
// TestRedirectAs_UDP_IPv4 tests RedirectAs with UDP proxy using IPv4 addresses
func TestRedirectAs_UDP_IPv4(t *testing.T) {
wgPort := 51852
@@ -198,64 +256,6 @@ func testRedirectAs(t *testing.T, proxy Proxy, wgPort int, nbAddr, p2pEndpoint *
}
}
// TestRedirectAs_eBPF_IPv4 tests RedirectAs with eBPF proxy using IPv4 addresses
func TestRedirectAs_eBPF_IPv4(t *testing.T) {
wgPort := 51850
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %v", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %v", err)
}
}()
proxy := ebpf.NewProxyWrapper(ebpfProxy)
// NetBird UDP address of the remote peer
nbAddr := &net.UDPAddr{
IP: net.ParseIP("100.108.111.177"),
Port: 38746,
}
p2pEndpoint := &net.UDPAddr{
IP: net.ParseIP("192.168.0.56"),
Port: 51820,
}
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
}
// TestRedirectAs_eBPF_IPv6 tests RedirectAs with eBPF proxy using IPv6 addresses
func TestRedirectAs_eBPF_IPv6(t *testing.T) {
wgPort := 51851
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %v", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %v", err)
}
}()
proxy := ebpf.NewProxyWrapper(ebpfProxy)
// NetBird UDP address of the remote peer
nbAddr := &net.UDPAddr{
IP: net.ParseIP("100.108.111.177"),
Port: 38746,
}
p2pEndpoint := &net.UDPAddr{
IP: net.ParseIP("fe80::56"),
Port: 51820,
}
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
}
// TestRedirectAs_Multiple_Switches tests switching between multiple endpoints
func TestRedirectAs_Multiple_Switches(t *testing.T) {
wgPort := 51856

View File

@@ -1,219 +0,0 @@
// Package approval brokers per-attempt user-accept prompts for inbound
// remote access (VNC today, SSH and others in the future). A caller pushes
// a Prompt; the broker emits a SystemEvent on the daemon→UI stream and
// blocks until the UI calls the daemon's RespondApproval RPC, the per-
// request timeout fires, or no subscriber is connected. The latter case
// fails closed so a backgrounded UI cannot silently bypass the gate.
package approval
import (
"context"
"errors"
"fmt"
"sync"
"time"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/proto"
)
// Metadata keys the broker reserves on the emitted SystemEvent. Callers
// should not set these themselves; values in Prompt.Metadata that collide
// are overwritten by the broker.
const (
MetaRequestID = "request_id"
MetaKind = "kind"
MetaExpiresAt = "expires_at"
)
// ShortKeyFingerprint formats a hex-encoded Noise_IK static pubkey as a
// short, eyeball-able fingerprint to display in the approval dialog.
// The dashboard-supplied display name attached to a SessionPubKey isn't
// cryptographically asserted by the connecting client, so the prompt
// must also show something that IS: the key fingerprint, a hash of
// the static public key the client just proved possession of during the
// Noise handshake. Returns the empty string when the input is too short
// to plausibly be a hex pubkey, so the row is omitted rather than
// rendered as a misleading partial.
//
// Output format: 16 hex chars grouped as XXXX-XXXX-XXXX-XXXX (64 bits of
// fingerprint, resistant to random-prefix collisions and easy for a human
// to compare with an out-of-band reference).
func ShortKeyFingerprint(hexKey string) string {
if len(hexKey) < 8 {
return ""
}
src := hexKey
if len(src) > 16 {
src = src[:16]
}
var out []byte
for i, c := range src {
if i > 0 && i%4 == 0 {
out = append(out, '-')
}
out = append(out, byte(c))
}
return string(out)
}
// Kind values for the well-known prompt subjects. New subsystems should
// add a constant here so the UI can dispatch on a known string.
const (
KindVNC = "vnc"
KindSSH = "ssh"
)
// DefaultTimeout is the wall-clock window the user has to accept or deny a
// pending approval before the broker fails closed and returns ErrTimeout.
// Kept well under typical VNC client and dashboard connection timeouts so
// the RFB rejection actually reaches the browser instead of racing the
// browser's own "connection timed out" message.
const DefaultTimeout = 15 * time.Second
// timeoutValue returns the active timeout. It's a var so tests in this
// package can shorten the wait without exposing a setter on the public
// API. Production code always sees DefaultTimeout.
var timeoutValue = func() time.Duration { return DefaultTimeout }
// ErrNoSubscriber indicates no UI is connected to consume the prompt.
// The caller must reject the underlying connection (fail-closed).
var ErrNoSubscriber = errors.New("no UI subscriber connected for approval")
// ErrTimeout indicates the user did not respond within DefaultTimeout.
var ErrTimeout = errors.New("approval timed out")
// ErrDenied indicates the user explicitly denied the connection.
var ErrDenied = errors.New("approval denied")
// EventPublisher is the subset of peer.Status used to emit prompts.
type EventPublisher interface {
PublishEvent(
severity proto.SystemEvent_Severity,
category proto.SystemEvent_Category,
msg string,
userMsg string,
metadata map[string]string,
)
HasEventSubscribers() bool
}
// Prompt describes the pending request shown to the user. Kind selects
// the UI dispatch path (e.g. "vnc", "ssh"). Subject is the human-readable
// one-liner the UI may show as a title or notification body. Metadata is
// passed through verbatim and is the subsystem-specific payload (peer
// name, source IP, mode, etc.).
type Prompt struct {
Kind string
Subject string
Metadata map[string]string
}
// Decision carries the user's response to an approval prompt. ViewOnly is
// only meaningful when Accept is true; it lets the host grant the
// connection but signal the requester that input control is withheld.
type Decision struct {
Accept bool
ViewOnly bool
}
// Broker holds in-flight approval requests keyed by request ID.
type Broker struct {
pub EventPublisher
mu sync.Mutex
pending map[string]chan Decision
}
// New returns a broker that publishes prompts via pub.
func New(pub EventPublisher) *Broker {
return &Broker{
pub: pub,
pending: make(map[string]chan Decision),
}
}
// Request emits a SystemEvent for p and blocks until the UI calls Respond,
// ctx is cancelled, or DefaultTimeout elapses. Returns a Decision when
// the user replied; ErrDenied / ErrTimeout / ErrNoSubscriber / ctx.Err
// otherwise. Callers must treat any non-nil error as a deny.
func (b *Broker) Request(ctx context.Context, p Prompt) (Decision, error) {
var zero Decision
if b == nil || b.pub == nil {
return zero, fmt.Errorf("approval broker not configured")
}
if !b.pub.HasEventSubscribers() {
return zero, ErrNoSubscriber
}
id := uuid.NewString()
resp := make(chan Decision, 1)
b.mu.Lock()
b.pending[id] = resp
b.mu.Unlock()
defer b.dropPending(id)
timeout := timeoutValue()
expiresAt := time.Now().Add(timeout)
meta := make(map[string]string, len(p.Metadata)+3)
for k, v := range p.Metadata {
meta[k] = v
}
meta[MetaRequestID] = id
meta[MetaKind] = p.Kind
meta[MetaExpiresAt] = expiresAt.UTC().Format(time.RFC3339)
subject := p.Subject
if subject == "" {
subject = fmt.Sprintf("%s connection requires approval", p.Kind)
}
b.pub.PublishEvent(proto.SystemEvent_INFO, proto.SystemEvent_APPROVAL, subject, subject, meta)
log.Debugf("approval request %s (%s) emitted: %s", id, p.Kind, subject)
timer := time.NewTimer(timeout)
defer timer.Stop()
select {
case d := <-resp:
if !d.Accept {
return zero, ErrDenied
}
return d, nil
case <-timer.C:
return zero, ErrTimeout
case <-ctx.Done():
return zero, ctx.Err()
}
}
// Respond delivers the user's decision for id. Returns true when a pending
// request matched and was woken, false when id was unknown or already done.
func (b *Broker) Respond(id string, d Decision) bool {
if b == nil {
return false
}
b.mu.Lock()
ch, ok := b.pending[id]
if ok {
delete(b.pending, id)
}
b.mu.Unlock()
if !ok {
return false
}
select {
case ch <- d:
default:
}
return true
}
func (b *Broker) dropPending(id string) {
b.mu.Lock()
delete(b.pending, id)
b.mu.Unlock()
}

View File

@@ -1,434 +0,0 @@
package approval
import (
"context"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/proto"
)
// fakePublisher records published events and reports whether subscribers
// are connected. The subscribers flag is the security-critical signal:
// when false the broker must refuse to emit and the gate must fail closed.
type fakePublisher struct {
mu sync.Mutex
subscribers bool
events []*proto.SystemEvent
}
func (p *fakePublisher) PublishEvent(
severity proto.SystemEvent_Severity,
category proto.SystemEvent_Category,
msg string,
userMsg string,
metadata map[string]string,
) {
p.mu.Lock()
p.events = append(p.events, &proto.SystemEvent{
Severity: severity,
Category: category,
Message: msg,
UserMessage: userMsg,
Metadata: metadata,
})
p.mu.Unlock()
}
func (p *fakePublisher) HasEventSubscribers() bool {
p.mu.Lock()
defer p.mu.Unlock()
return p.subscribers
}
func (p *fakePublisher) lastEvent(t *testing.T) *proto.SystemEvent {
t.Helper()
p.mu.Lock()
defer p.mu.Unlock()
require.NotEmpty(t, p.events, "publisher saw no events")
return p.events[len(p.events)-1]
}
func (p *fakePublisher) eventCount() int {
p.mu.Lock()
defer p.mu.Unlock()
return len(p.events)
}
// TestRequestNoSubscriberFailsClosed is the core fail-closed invariant:
// when the UI is not subscribed, the broker must refuse without emitting
// an event or arming a waiter. A regression here is a silent bypass.
func TestRequestNoSubscriberFailsClosed(t *testing.T) {
pub := &fakePublisher{subscribers: false}
b := New(pub)
_, err := b.Request(context.Background(), Prompt{Kind: KindVNC, Subject: "test"})
assert.ErrorIs(t, err, ErrNoSubscriber)
assert.Equal(t, 0, pub.eventCount(), "no event must be emitted when fail-closed")
b.mu.Lock()
pending := len(b.pending)
b.mu.Unlock()
assert.Equal(t, 0, pending, "no waiter must be registered on fail-closed")
}
// TestRequestTimeoutDenies verifies that a request without a UI response
// returns ErrTimeout (deny) rather than nil (silent accept). Uses a short
// per-test broker timeout via Respond after the fact to keep the test fast.
func TestRequestTimeoutDenies(t *testing.T) {
// Replace DefaultTimeout for the lifetime of this test.
orig := DefaultTimeout
defaultTimeout(t, 60*time.Millisecond)
defer defaultTimeout(t, orig)
pub := &fakePublisher{subscribers: true}
b := New(pub)
start := time.Now()
_, err := b.Request(context.Background(), Prompt{Kind: KindVNC, Subject: "test"})
assert.ErrorIs(t, err, ErrTimeout, "missing user response must yield ErrTimeout, not nil")
assert.GreaterOrEqual(t, time.Since(start), 50*time.Millisecond, "timeout fired prematurely")
}
// TestRequestDenied returns ErrDenied when the UI responds with false.
func TestRequestDenied(t *testing.T) {
pub := &fakePublisher{subscribers: true}
b := New(pub)
var requestID string
done := make(chan error, 1)
go func() {
done <- requestErr(b, context.Background(), Prompt{Kind: KindVNC, Subject: "test"})
}()
requestID = waitForRequestID(t, pub)
require.True(t, b.Respond(requestID, Decision{Accept: false}))
select {
case err := <-done:
assert.ErrorIs(t, err, ErrDenied)
case <-time.After(time.Second):
t.Fatal("Request did not return after Respond(false)")
}
}
// TestRequestAccepted is the happy path. Failure here doesn't bypass the
// gate but breaks the feature.
func TestRequestAccepted(t *testing.T) {
pub := &fakePublisher{subscribers: true}
b := New(pub)
done := make(chan error, 1)
go func() {
done <- requestErr(b, context.Background(), Prompt{Kind: KindVNC, Subject: "test"})
}()
id := waitForRequestID(t, pub)
require.True(t, b.Respond(id, Decision{Accept: true}))
select {
case err := <-done:
assert.NoError(t, err)
case <-time.After(time.Second):
t.Fatal("Request did not return after Respond(true)")
}
}
// TestRequestCtxCancelDenies verifies that an upstream cancel (e.g. the
// engine shutting down mid-prompt) returns the cancel error rather than
// nil. A nil here would be a silent bypass on shutdown races.
func TestRequestCtxCancelDenies(t *testing.T) {
pub := &fakePublisher{subscribers: true}
b := New(pub)
ctx, cancel := context.WithCancel(context.Background())
done := make(chan error, 1)
go func() {
done <- requestErr(b, ctx, Prompt{Kind: KindVNC, Subject: "test"})
}()
// Wait until the prompt is in flight so cancel races a live waiter.
_ = waitForRequestID(t, pub)
cancel()
select {
case err := <-done:
assert.ErrorIs(t, err, context.Canceled)
case <-time.After(time.Second):
t.Fatal("Request did not return after ctx cancel")
}
}
// TestRespondUnknownIsNoop ensures a stray RespondApproval RPC cannot
// affect or accidentally accept any in-flight request whose id it doesn't
// match. Also confirms it doesn't panic.
func TestRespondUnknownIsNoop(t *testing.T) {
pub := &fakePublisher{subscribers: true}
b := New(pub)
// No in-flight prompts: Respond returns false.
assert.False(t, b.Respond("does-not-exist", Decision{Accept: true}))
// With an in-flight prompt, a wrong id still returns false and the
// prompt remains armed (eventually timing out as a deny).
defaultTimeout(t, 60*time.Millisecond)
defer defaultTimeout(t, DefaultTimeout)
done := make(chan error, 1)
go func() {
done <- requestErr(b, context.Background(), Prompt{Kind: KindVNC})
}()
realID := waitForRequestID(t, pub)
assert.False(t, b.Respond("totally-bogus", Decision{Accept: true}), "unknown id must not match")
assert.NotEqual(t, "totally-bogus", realID)
select {
case err := <-done:
assert.ErrorIs(t, err, ErrTimeout, "armed prompt must still time out, not accept")
case <-time.After(time.Second):
t.Fatal("prompt did not resolve")
}
}
// TestRespondAfterTimeoutNoop confirms a late accept response can't
// retroactively flip a denied (timed-out) request. The dropPending defer
// in Request must have removed the entry by the time Respond races in.
func TestRespondAfterTimeoutNoop(t *testing.T) {
defaultTimeout(t, 30*time.Millisecond)
defer defaultTimeout(t, DefaultTimeout)
pub := &fakePublisher{subscribers: true}
b := New(pub)
done := make(chan error, 1)
go func() {
done <- requestErr(b, context.Background(), Prompt{Kind: KindVNC})
}()
id := waitForRequestID(t, pub)
select {
case err := <-done:
require.ErrorIs(t, err, ErrTimeout)
case <-time.After(time.Second):
t.Fatal("prompt did not time out")
}
assert.False(t, b.Respond(id, Decision{Accept: true}), "late respond must be no-op")
}
// TestRespondDoubleNoop ensures a duplicate ack from the UI doesn't leak
// past the matched waiter or panic on a closed/full channel.
func TestRespondDoubleNoop(t *testing.T) {
pub := &fakePublisher{subscribers: true}
b := New(pub)
done := make(chan error, 1)
go func() {
done <- requestErr(b, context.Background(), Prompt{Kind: KindVNC})
}()
id := waitForRequestID(t, pub)
require.True(t, b.Respond(id, Decision{Accept: true}))
assert.False(t, b.Respond(id, Decision{Accept: false}), "second response must be no-op")
select {
case err := <-done:
assert.NoError(t, err)
case <-time.After(time.Second):
t.Fatal("prompt did not resolve")
}
}
// TestNilBrokerRequestErrors guards the engine pre-init path where the
// broker may not yet exist (or its publisher is nil): Request must
// error, never silently accept.
func TestNilBrokerRequestErrors(t *testing.T) {
var b *Broker
_, err := b.Request(context.Background(), Prompt{Kind: KindVNC})
assert.Error(t, err, "nil broker must error, never silently accept")
b2 := New(nil)
_, err = b2.Request(context.Background(), Prompt{Kind: KindVNC})
assert.Error(t, err, "broker with nil publisher must error, never silently accept")
}
// TestPromptMetadataInjected confirms the broker stamps request_id, kind,
// and expires_at on the emitted event. The UI relies on these keys; if
// they are dropped, the user cannot route the prompt and the response
// path breaks (which fails closed via timeout).
func TestPromptMetadataInjected(t *testing.T) {
pub := &fakePublisher{subscribers: true}
b := New(pub)
done := make(chan error, 1)
go func() {
done <- requestErr(b, context.Background(), Prompt{
Kind: KindVNC,
Subject: "VNC connection from peerA",
Metadata: map[string]string{"peer_name": "peerA"},
})
}()
id := waitForRequestID(t, pub)
ev := pub.lastEvent(t)
assert.Equal(t, proto.SystemEvent_APPROVAL, ev.Category)
assert.Equal(t, KindVNC, ev.Metadata[MetaKind])
assert.Equal(t, id, ev.Metadata[MetaRequestID])
assert.NotEmpty(t, ev.Metadata[MetaExpiresAt])
assert.Equal(t, "peerA", ev.Metadata["peer_name"], "caller metadata must pass through")
require.True(t, b.Respond(id, Decision{Accept: true}))
<-done
}
// TestConcurrentRequests verifies that two concurrent prompts are tracked
// independently. A bug that aliases ids would let one Respond unblock
// the wrong waiter (a silent accept across prompts).
func TestConcurrentRequests(t *testing.T) {
pub := &fakePublisher{subscribers: true}
b := New(pub)
const n = 20
results := make(chan error, n)
for i := 0; i < n; i++ {
go func() {
results <- requestErr(b, context.Background(), Prompt{Kind: KindVNC})
}()
}
ids := waitForNRequestIDs(t, pub, n)
require.Len(t, ids, n)
// Deny exactly half, accept the rest. Track outcome per id so we can
// match each Request's return value against the response we sent.
denySet := make(map[string]bool, n)
for i, id := range ids {
deny := i%2 == 0
denySet[id] = deny
require.True(t, b.Respond(id, Decision{Accept: !deny}))
}
// Collect all returns and check no nil errors slipped past a deny.
var accepted, denied atomic.Int32
for i := 0; i < n; i++ {
select {
case err := <-results:
if err == nil {
accepted.Add(1)
} else {
assert.ErrorIs(t, err, ErrDenied)
denied.Add(1)
}
case <-time.After(2 * time.Second):
t.Fatalf("only got %d/%d responses", i, n)
}
}
assert.Equal(t, int32(n/2), denied.Load())
assert.Equal(t, int32(n/2), accepted.Load())
}
// waitForRequestID blocks until the publisher sees its next event and
// returns the request_id stamped on it.
func waitForRequestID(t *testing.T, pub *fakePublisher) string {
t.Helper()
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
pub.mu.Lock()
count := len(pub.events)
var id string
if count > 0 {
id = pub.events[count-1].Metadata[MetaRequestID]
}
pub.mu.Unlock()
if id != "" {
return id
}
time.Sleep(2 * time.Millisecond)
}
t.Fatal("timeout waiting for emitted event")
return ""
}
func waitForNRequestIDs(t *testing.T, pub *fakePublisher, n int) []string {
t.Helper()
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
pub.mu.Lock()
count := len(pub.events)
pub.mu.Unlock()
if count >= n {
break
}
time.Sleep(2 * time.Millisecond)
}
pub.mu.Lock()
defer pub.mu.Unlock()
out := make([]string, 0, len(pub.events))
seen := make(map[string]struct{}, len(pub.events))
for _, ev := range pub.events {
id := ev.Metadata[MetaRequestID]
if id == "" {
continue
}
if _, dup := seen[id]; dup {
continue
}
seen[id] = struct{}{}
out = append(out, id)
}
if len(out) < n {
t.Fatalf("only got %d/%d request ids", len(out), n)
}
return out
}
// defaultTimeout swaps the broker's per-request wall-clock window so the
// timeout tests run quickly. Restores the prior value on the next call.
func defaultTimeout(t *testing.T, d time.Duration) {
t.Helper()
if d <= 0 {
t.Fatal("defaultTimeout must be > 0")
}
timeoutValue = func() time.Duration { return d }
}
// requestErr wraps Broker.Request to drop the Decision when tests only
// care about the error path. Keeps the goroutine bodies tight.
func requestErr(b *Broker, ctx context.Context, p Prompt) error {
_, err := b.Request(ctx, p)
return err
}
// TestRequestViewOnly checks the view-only outcome flows through Request's
// Decision return without being silently swallowed.
func TestRequestViewOnly(t *testing.T) {
pub := &fakePublisher{subscribers: true}
b := New(pub)
type result struct {
d Decision
err error
}
done := make(chan result, 1)
go func() {
d, err := b.Request(context.Background(), Prompt{Kind: KindVNC})
done <- result{d, err}
}()
id := waitForRequestID(t, pub)
require.True(t, b.Respond(id, Decision{Accept: true, ViewOnly: true}))
select {
case r := <-done:
assert.NoError(t, r.err)
assert.True(t, r.d.Accept)
assert.True(t, r.d.ViewOnly, "ViewOnly must survive the round-trip")
case <-time.After(time.Second):
t.Fatal("view-only request did not resolve")
}
}

View File

@@ -1,62 +0,0 @@
package approval
import "testing"
// TestShortKeyFingerprint locks in the format the VNC approval prompt
// shows to the user. The fingerprint is the user's only cryptographic
// anchor against a malicious management server that pushes a spoofed
// display name, so accidental changes to its format would silently
// undermine that defence.
func TestShortKeyFingerprint(t *testing.T) {
cases := []struct {
name string
in string
want string
}{
{
name: "full_32_byte_pubkey",
in: "0123456789abcdeffedcba9876543210ffeeddccbbaa99887766554433221100",
want: "0123-4567-89ab-cdef",
},
{
name: "exactly_16_chars",
in: "0123456789abcdef",
want: "0123-4567-89ab-cdef",
},
{
name: "borderline_8_chars",
in: "01234567",
want: "0123-4567",
},
{
name: "too_short_returns_empty",
in: "0123",
want: "",
},
{
name: "empty_returns_empty",
in: "",
want: "",
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got := ShortKeyFingerprint(tc.in)
if got != tc.want {
t.Fatalf("ShortKeyFingerprint(%q) = %q, want %q", tc.in, got, tc.want)
}
})
}
}
// TestShortKeyFingerprint_DistinctKeysDistinctOutputs guards against a
// formatting bug that would collapse different prefixes onto the same
// displayed fingerprint and let an attacker substitute their pubkey for
// a victim's while keeping the prompt visually identical.
func TestShortKeyFingerprint_DistinctKeysDistinctOutputs(t *testing.T) {
a := ShortKeyFingerprint("0123456789abcdef" + "rest_of_pubkey_ignored")
b := ShortKeyFingerprint("0123456789abcde0" + "rest_of_pubkey_ignored")
if a == b {
t.Fatalf("expected distinct outputs for distinct prefixes, both = %q", a)
}
}

View File

@@ -315,7 +315,6 @@ func (a *Auth) setSystemInfoFlags(info *system.Info) {
a.config.RosenpassEnabled,
a.config.RosenpassPermissive,
a.config.ServerSSHAllowed,
a.config.ServerVNCAllowed,
a.config.DisableClientRoutes,
a.config.DisableServerRoutes,
a.config.DisableDNS,

View File

@@ -581,8 +581,6 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
RosenpassEnabled: config.RosenpassEnabled,
RosenpassPermissive: config.RosenpassPermissive,
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
ServerVNCAllowed: config.ServerVNCAllowed != nil && *config.ServerVNCAllowed,
DisableVNCApproval: config.DisableVNCApproval,
EnableSSHRoot: config.EnableSSHRoot,
EnableSSHSFTP: config.EnableSSHSFTP,
EnableSSHLocalPortForwarding: config.EnableSSHLocalPortForwarding,
@@ -665,7 +663,6 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
config.RosenpassEnabled,
config.RosenpassPermissive,
config.ServerSSHAllowed,
config.ServerVNCAllowed,
config.DisableClientRoutes,
config.DisableServerRoutes,
config.DisableDNS,

View File

@@ -655,12 +655,6 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
if g.internalConfig.SSHJWTCacheTTL != nil {
configContent.WriteString(fmt.Sprintf("SSHJWTCacheTTL: %d\n", *g.internalConfig.SSHJWTCacheTTL))
}
if g.internalConfig.ServerVNCAllowed != nil {
configContent.WriteString(fmt.Sprintf("ServerVNCAllowed: %v\n", *g.internalConfig.ServerVNCAllowed))
}
if g.internalConfig.DisableVNCApproval != nil {
configContent.WriteString(fmt.Sprintf("DisableVNCApproval: %v\n", *g.internalConfig.DisableVNCApproval))
}
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))

View File

@@ -864,8 +864,6 @@ func TestAddConfig_AllFieldsCovered(t *testing.T) {
RosenpassEnabled: true,
RosenpassPermissive: true,
ServerSSHAllowed: &bTrue,
ServerVNCAllowed: &bTrue,
DisableVNCApproval: &bTrue,
EnableSSHRoot: &bTrue,
EnableSSHSFTP: &bTrue,
EnableSSHLocalPortForwarding: &bTrue,

View File

@@ -1,485 +0,0 @@
//go:build privileged
package dns
import (
"context"
"fmt"
"net/netip"
"os"
"testing"
"github.com/golang/mock/gomock"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface"
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/dns/local"
"github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/stdnet"
nbdns "github.com/netbirdio/netbird/dns"
)
func TestUpdateDNSServer(t *testing.T) {
nameServers := []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
{
IP: netip.MustParseAddr("8.8.4.4"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
}
testCases := []struct {
name string
initUpstreamMap []handlerWrapper
initLocalZones []nbdns.CustomZone
initSerial uint64
inputSerial uint64
inputUpdate nbdns.Config
shouldFail bool
expectedUpstreamMap []handlerWrapper
expectedLocalQs []dns.Question
}{
{
name: "Initial Config Should Succeed",
initUpstreamMap: nil,
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"netbird.io"},
NameServers: nameServers,
},
{
NameServers: nameServers,
Primary: true,
},
},
},
expectedUpstreamMap: []handlerWrapper{
{
domain: "netbird.io",
priority: PriorityUpstream,
},
{
domain: "netbird.cloud",
priority: PriorityLocal,
},
{
domain: nbdns.RootZone,
priority: PriorityDefault,
},
},
expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}},
},
{
name: "New Config Should Succeed",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: []handlerWrapper{
{
domain: "netbird.cloud",
handler: &mockHandler{},
priority: PriorityUpstream,
},
},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"netbird.io"},
NameServers: nameServers,
},
},
},
expectedUpstreamMap: []handlerWrapper{
{
domain: "netbird.io",
priority: PriorityUpstream,
},
{
domain: "netbird.cloud",
priority: PriorityLocal,
},
},
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
},
{
name: "Smaller Config Serial Should Be Skipped",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: nil,
initSerial: 2,
inputSerial: 1,
shouldFail: true,
},
{
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: nil,
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: nameServers,
},
},
},
shouldFail: true,
},
{
name: "Invalid NS Group Nameservers list Should Fail",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: nil,
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: nameServers,
},
},
},
shouldFail: true,
},
{
name: "Invalid Custom Zone Records list Should Skip",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: nil,
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: nameServers,
Primary: true,
},
},
},
expectedUpstreamMap: []handlerWrapper{{
domain: ".",
priority: PriorityDefault,
}},
},
{
name: "Empty Config Should Succeed and Clean Maps",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: []handlerWrapper{
{
domain: zoneRecords[0].Name,
handler: &mockHandler{},
priority: PriorityUpstream,
},
},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{ServiceEnable: true},
expectedUpstreamMap: nil,
expectedLocalQs: []dns.Question{},
},
{
name: "Disabled Service Should clean map",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: []handlerWrapper{
{
domain: zoneRecords[0].Name,
handler: &mockHandler{},
priority: PriorityUpstream,
},
},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{ServiceEnable: false},
expectedUpstreamMap: nil,
expectedLocalQs: []dns.Question{},
},
}
for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
privKey, _ := wgtypes.GenerateKey()
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
opts := iface.WGIFaceOpts{
IFaceName: fmt.Sprintf("utun230%d", n),
Address: wgaddr.MustParseWGAddress(fmt.Sprintf("100.66.100.%d/32", n+1)),
WGPort: 33100,
WGPrivKey: privKey.String(),
MTU: iface.DefaultMTU,
TransportNet: newNet,
}
wgIface, err := iface.NewWGIFace(opts)
if err != nil {
t.Fatal(err)
}
err = wgIface.Create()
if err != nil {
t.Fatal(err)
}
defer func() {
err = wgIface.Close()
if err != nil {
t.Log(err)
}
}()
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
WgInterface: wgIface,
CustomAddress: "",
StatusRecorder: peer.NewRecorder("mgm"),
StateManager: nil,
DisableSys: false,
})
if err != nil {
t.Fatal(err)
}
err = dnsServer.Initialize()
if err != nil {
t.Fatal(err)
}
defer func() {
err = dnsServer.hostManager.restoreHostDNS()
if err != nil {
t.Log(err)
}
}()
dnsServer.dnsMuxHandlers = testCase.initUpstreamMap
dnsServer.localResolver.Update(testCase.initLocalZones)
dnsServer.updateSerial = testCase.initSerial
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
if err != nil {
if testCase.shouldFail {
return
}
t.Fatalf("update dns server should not fail, got error: %v", err)
}
if len(dnsServer.dnsMuxHandlers) != len(testCase.expectedUpstreamMap) {
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxHandlers))
}
for _, expected := range testCase.expectedUpstreamMap {
found := false
for _, got := range dnsServer.dnsMuxHandlers {
if got.domain == expected.domain && got.priority == expected.priority {
found = true
break
}
}
if !found {
t.Fatalf("update upstream failed, handler for domain=%s priority=%d not found in dnsMuxHandlers: %#v", expected.domain, expected.priority, dnsServer.dnsMuxHandlers)
}
}
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
for _, q := range testCase.expectedLocalQs {
dnsServer.localResolver.ServeDNS(responseWriter, &dns.Msg{
Question: []dns.Question{q},
})
}
if len(testCase.expectedLocalQs) > 0 {
assert.NotNil(t, responseMSG, "response message should not be nil")
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "response code should be success")
assert.NotEmpty(t, responseMSG.Answer, "response message should have answers")
}
})
}
}
func TestDNSFakeResolverHandleUpdates(t *testing.T) {
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"})
if err != nil {
t.Errorf("create stdnet: %v", err)
return
}
privKey, _ := wgtypes.GeneratePrivateKey()
opts := iface.WGIFaceOpts{
IFaceName: "utun2301",
Address: wgaddr.MustParseWGAddress("100.66.100.1/32"),
WGPort: 33100,
WGPrivKey: privKey.String(),
MTU: iface.DefaultMTU,
TransportNet: newNet,
}
wgIface, err := iface.NewWGIFace(opts)
if err != nil {
t.Errorf("build interface wireguard: %v", err)
return
}
err = wgIface.Create()
if err != nil {
t.Errorf("create and init wireguard interface: %v", err)
return
}
defer func() {
if err = wgIface.Close(); err != nil {
t.Logf("close wireguard interface: %v", err)
}
}()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
packetfilter := pfmock.NewMockPacketFilter(ctrl)
packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().SetUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().SetTCPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
if err := wgIface.SetFilter(packetfilter); err != nil {
t.Errorf("set packet filter: %v", err)
return
}
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
WgInterface: wgIface,
CustomAddress: "",
StatusRecorder: peer.NewRecorder("mgm"),
StateManager: nil,
DisableSys: false,
})
if err != nil {
t.Errorf("create DNS server: %v", err)
return
}
err = dnsServer.Initialize()
if err != nil {
t.Errorf("run DNS server: %v", err)
return
}
defer func() {
if err = dnsServer.hostManager.restoreHostDNS(); err != nil {
t.Logf("restore DNS settings on the host: %v", err)
return
}
}()
dnsServer.dnsMuxHandlers = []handlerWrapper{
{
domain: zoneRecords[0].Name,
handler: &local.Resolver{},
priority: PriorityUpstream,
},
}
dnsServer.localResolver.Update([]nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}})
dnsServer.updateSerial = 0
nameServers := []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
{
IP: netip.MustParseAddr("8.8.4.4"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
}
update := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"netbird.io"},
NameServers: nameServers,
},
{
NameServers: nameServers,
Primary: true,
},
},
}
// Start the server with regular configuration
if err := dnsServer.UpdateDNSServer(1, update); err != nil {
t.Fatalf("update dns server should not fail, got error: %v", err)
return
}
update2 := update
update2.ServiceEnable = false
// Disable the server, stop the listener
if err := dnsServer.UpdateDNSServer(2, update2); err != nil {
t.Fatalf("update dns server should not fail, got error: %v", err)
return
}
update3 := update2
update3.NameServerGroups = update3.NameServerGroups[:1]
// But service still get updates and we checking that we handle
// internal state in the right way
if err := dnsServer.UpdateDNSServer(3, update3); err != nil {
t.Fatalf("update dns server should not fail, got error: %v", err)
return
}
}

View File

@@ -10,6 +10,7 @@ import (
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
@@ -22,6 +23,7 @@ import (
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/dns/local"
"github.com/netbirdio/netbird/client/internal/dns/test"
@@ -102,6 +104,466 @@ func init() {
formatter.SetTextFormatter(log.StandardLogger())
}
func TestUpdateDNSServer(t *testing.T) {
nameServers := []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
{
IP: netip.MustParseAddr("8.8.4.4"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
}
testCases := []struct {
name string
initUpstreamMap []handlerWrapper
initLocalZones []nbdns.CustomZone
initSerial uint64
inputSerial uint64
inputUpdate nbdns.Config
shouldFail bool
expectedUpstreamMap []handlerWrapper
expectedLocalQs []dns.Question
}{
{
name: "Initial Config Should Succeed",
initUpstreamMap: nil,
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"netbird.io"},
NameServers: nameServers,
},
{
NameServers: nameServers,
Primary: true,
},
},
},
expectedUpstreamMap: []handlerWrapper{
{
domain: "netbird.io",
priority: PriorityUpstream,
},
{
domain: "netbird.cloud",
priority: PriorityLocal,
},
{
domain: nbdns.RootZone,
priority: PriorityDefault,
},
},
expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}},
},
{
name: "New Config Should Succeed",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: []handlerWrapper{
{
domain: "netbird.cloud",
handler: &mockHandler{},
priority: PriorityUpstream,
},
},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"netbird.io"},
NameServers: nameServers,
},
},
},
expectedUpstreamMap: []handlerWrapper{
{
domain: "netbird.io",
priority: PriorityUpstream,
},
{
domain: "netbird.cloud",
priority: PriorityLocal,
},
},
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
},
{
name: "Smaller Config Serial Should Be Skipped",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: nil,
initSerial: 2,
inputSerial: 1,
shouldFail: true,
},
{
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: nil,
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: nameServers,
},
},
},
shouldFail: true,
},
{
name: "Invalid NS Group Nameservers list Should Fail",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: nil,
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: nameServers,
},
},
},
shouldFail: true,
},
{
name: "Invalid Custom Zone Records list Should Skip",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: nil,
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: nameServers,
Primary: true,
},
},
},
expectedUpstreamMap: []handlerWrapper{{
domain: ".",
priority: PriorityDefault,
}},
},
{
name: "Empty Config Should Succeed and Clean Maps",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: []handlerWrapper{
{
domain: zoneRecords[0].Name,
handler: &mockHandler{},
priority: PriorityUpstream,
},
},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{ServiceEnable: true},
expectedUpstreamMap: nil,
expectedLocalQs: []dns.Question{},
},
{
name: "Disabled Service Should clean map",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: []handlerWrapper{
{
domain: zoneRecords[0].Name,
handler: &mockHandler{},
priority: PriorityUpstream,
},
},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{ServiceEnable: false},
expectedUpstreamMap: nil,
expectedLocalQs: []dns.Question{},
},
}
for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
privKey, _ := wgtypes.GenerateKey()
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
opts := iface.WGIFaceOpts{
IFaceName: fmt.Sprintf("utun230%d", n),
Address: wgaddr.MustParseWGAddress(fmt.Sprintf("100.66.100.%d/32", n+1)),
WGPort: 33100,
WGPrivKey: privKey.String(),
MTU: iface.DefaultMTU,
TransportNet: newNet,
}
wgIface, err := iface.NewWGIFace(opts)
if err != nil {
t.Fatal(err)
}
err = wgIface.Create()
if err != nil {
t.Fatal(err)
}
defer func() {
err = wgIface.Close()
if err != nil {
t.Log(err)
}
}()
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
WgInterface: wgIface,
CustomAddress: "",
StatusRecorder: peer.NewRecorder("mgm"),
StateManager: nil,
DisableSys: false,
})
if err != nil {
t.Fatal(err)
}
err = dnsServer.Initialize()
if err != nil {
t.Fatal(err)
}
defer func() {
err = dnsServer.hostManager.restoreHostDNS()
if err != nil {
t.Log(err)
}
}()
dnsServer.dnsMuxHandlers = testCase.initUpstreamMap
dnsServer.localResolver.Update(testCase.initLocalZones)
dnsServer.updateSerial = testCase.initSerial
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
if err != nil {
if testCase.shouldFail {
return
}
t.Fatalf("update dns server should not fail, got error: %v", err)
}
if len(dnsServer.dnsMuxHandlers) != len(testCase.expectedUpstreamMap) {
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxHandlers))
}
for _, expected := range testCase.expectedUpstreamMap {
found := false
for _, got := range dnsServer.dnsMuxHandlers {
if got.domain == expected.domain && got.priority == expected.priority {
found = true
break
}
}
if !found {
t.Fatalf("update upstream failed, handler for domain=%s priority=%d not found in dnsMuxHandlers: %#v", expected.domain, expected.priority, dnsServer.dnsMuxHandlers)
}
}
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
for _, q := range testCase.expectedLocalQs {
dnsServer.localResolver.ServeDNS(responseWriter, &dns.Msg{
Question: []dns.Question{q},
})
}
if len(testCase.expectedLocalQs) > 0 {
assert.NotNil(t, responseMSG, "response message should not be nil")
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "response code should be success")
assert.NotEmpty(t, responseMSG.Answer, "response message should have answers")
}
})
}
}
func TestDNSFakeResolverHandleUpdates(t *testing.T) {
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"})
if err != nil {
t.Errorf("create stdnet: %v", err)
return
}
privKey, _ := wgtypes.GeneratePrivateKey()
opts := iface.WGIFaceOpts{
IFaceName: "utun2301",
Address: wgaddr.MustParseWGAddress("100.66.100.1/32"),
WGPort: 33100,
WGPrivKey: privKey.String(),
MTU: iface.DefaultMTU,
TransportNet: newNet,
}
wgIface, err := iface.NewWGIFace(opts)
if err != nil {
t.Errorf("build interface wireguard: %v", err)
return
}
err = wgIface.Create()
if err != nil {
t.Errorf("create and init wireguard interface: %v", err)
return
}
defer func() {
if err = wgIface.Close(); err != nil {
t.Logf("close wireguard interface: %v", err)
}
}()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
packetfilter := pfmock.NewMockPacketFilter(ctrl)
packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().SetUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().SetTCPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
if err := wgIface.SetFilter(packetfilter); err != nil {
t.Errorf("set packet filter: %v", err)
return
}
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
WgInterface: wgIface,
CustomAddress: "",
StatusRecorder: peer.NewRecorder("mgm"),
StateManager: nil,
DisableSys: false,
})
if err != nil {
t.Errorf("create DNS server: %v", err)
return
}
err = dnsServer.Initialize()
if err != nil {
t.Errorf("run DNS server: %v", err)
return
}
defer func() {
if err = dnsServer.hostManager.restoreHostDNS(); err != nil {
t.Logf("restore DNS settings on the host: %v", err)
return
}
}()
dnsServer.dnsMuxHandlers = []handlerWrapper{
{
domain: zoneRecords[0].Name,
handler: &local.Resolver{},
priority: PriorityUpstream,
},
}
dnsServer.localResolver.Update([]nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}})
dnsServer.updateSerial = 0
nameServers := []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
{
IP: netip.MustParseAddr("8.8.4.4"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
}
update := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"netbird.io"},
NameServers: nameServers,
},
{
NameServers: nameServers,
Primary: true,
},
},
}
// Start the server with regular configuration
if err := dnsServer.UpdateDNSServer(1, update); err != nil {
t.Fatalf("update dns server should not fail, got error: %v", err)
return
}
update2 := update
update2.ServiceEnable = false
// Disable the server, stop the listener
if err := dnsServer.UpdateDNSServer(2, update2); err != nil {
t.Fatalf("update dns server should not fail, got error: %v", err)
return
}
update3 := update2
update3.NameServerGroups = update3.NameServerGroups[:1]
// But service still get updates and we checking that we handle
// internal state in the right way
if err := dnsServer.UpdateDNSServer(3, update3); err != nil {
t.Fatalf("update dns server should not fail, got error: %v", err)
return
}
}
func TestDNSServerStartStop(t *testing.T) {
testCases := []struct {
name string

View File

@@ -34,7 +34,6 @@ import (
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/approval"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/internal/dns"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
@@ -126,8 +125,6 @@ type EngineConfig struct {
RosenpassPermissive bool
ServerSSHAllowed bool
ServerVNCAllowed bool
DisableVNCApproval *bool
EnableSSHRoot *bool
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
@@ -215,9 +212,7 @@ type Engine struct {
networkMonitor *networkmonitor.NetworkMonitor
sshServer sshServer
vncSrv vncServer
approvalBroker *approval.Broker
sshServer sshServer
statusRecorder *peer.Status
@@ -309,7 +304,6 @@ func NewEngine(
TURNs: []*stun.URI{},
networkSerial: 0,
statusRecorder: services.StatusRecorder,
approvalBroker: approval.New(services.StatusRecorder),
stateManager: services.StateManager,
portForwardManager: portforward.NewManager(),
checks: services.Checks,
@@ -372,10 +366,6 @@ func (e *Engine) stopLocked() {
log.Warnf("failed to stop SSH server: %v", err)
}
if err := e.stopVNCServer(); err != nil {
log.Warnf("failed to stop VNC server: %v", err)
}
e.cleanupSSHConfig()
if e.ingressGatewayMgr != nil {
@@ -1085,7 +1075,6 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
e.config.RosenpassEnabled,
e.config.RosenpassPermissive,
&e.config.ServerSSHAllowed,
&e.config.ServerVNCAllowed,
e.config.DisableClientRoutes,
e.config.DisableServerRoutes,
e.config.DisableDNS,
@@ -1147,10 +1136,6 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
}
}
if err := e.updateVNC(); err != nil {
log.Warnf("failed handling VNC server setup: %v", err)
}
state := e.statusRecorder.GetLocalPeerState()
state.IP = e.wgInterface.Address().String()
state.IPv6 = e.wgInterface.Address().IPv6String()
@@ -1278,7 +1263,6 @@ func (e *Engine) receiveManagementEvents() {
e.config.RosenpassEnabled,
e.config.RosenpassPermissive,
&e.config.ServerSSHAllowed,
&e.config.ServerVNCAllowed,
e.config.DisableClientRoutes,
e.config.DisableServerRoutes,
e.config.DisableDNS,
@@ -1468,11 +1452,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
e.updateSSHServerAuth(networkMap.GetSshAuth())
}
// VNC auth: always sync, including nil so cleared auth on the management
// side is applied locally, and so it isn't skipped on the RemotePeersIsEmpty
// cleanup path.
e.updateVNCServerAuth(networkMap.GetVncAuth())
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, remotePeers)
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
@@ -1952,7 +1931,6 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
e.config.RosenpassEnabled,
e.config.RosenpassPermissive,
&e.config.ServerSSHAllowed,
&e.config.ServerVNCAllowed,
e.config.DisableClientRoutes,
e.config.DisableServerRoutes,
e.config.DisableDNS,
@@ -2721,16 +2699,3 @@ func decodeRelayIP(b []byte) netip.Addr {
}
return ip.Unmap()
}
// RespondApproval relays the user's decision for a pending approval to
// the broker. viewOnly is honoured only when accept is true. Returns
// true when the request_id matched a live prompt.
func (e *Engine) RespondApproval(requestID string, accept, viewOnly bool) bool {
if e == nil || e.approvalBroker == nil {
return false
}
return e.approvalBroker.Respond(requestID, approval.Decision{
Accept: accept,
ViewOnly: accept && viewOnly,
})
}

View File

@@ -1,565 +0,0 @@
//go:build privileged
package internal
import (
"context"
"fmt"
"net"
"runtime"
"strings"
"sync"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/server/config"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
mgmt "github.com/netbirdio/netbird/shared/management/client"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
signal "github.com/netbirdio/netbird/shared/signal/client"
"github.com/netbirdio/netbird/shared/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server"
"github.com/netbirdio/netbird/util"
)
func TestEngine_SSH(t *testing.T) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
return
}
sshKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
if err != nil {
t.Fatal(err)
return
}
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel()
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine(
ctx, cancel,
&EngineConfig{
WgIfaceName: "utun101",
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
WgPrivateKey: key,
WgPort: 33100,
ServerSSHAllowed: true,
MTU: iface.DefaultMTU,
SSHKey: sshKey,
},
EngineServices{
SignalClient: &signal.MockClient{},
MgmClient: &mgmt.MockClient{},
RelayManager: relayMgr,
StatusRecorder: peer.NewRecorder("https://mgm"),
},
MobileDependency{},
)
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
}
err = engine.Start(nil, nil)
require.NoError(t, err)
defer func() {
err := engine.Stop()
if err != nil {
return
}
}()
peerWithSSH := &mgmtProto.RemotePeerConfig{
WgPubKey: "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
AllowedIps: []string{"100.64.0.21/24"},
SshConfig: &mgmtProto.SSHConfig{
SshPubKey: []byte("ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFATYCqaQw/9id1Qkq3n16JYhDhXraI6Pc1fgB8ynEfQ"),
},
}
// SSH server is not enabled so SSH config of a remote peer should be ignored
networkMap := &mgmtProto.NetworkMap{
Serial: 6,
PeerConfig: nil,
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
RemotePeersIsEmpty: false,
}
err = engine.updateNetworkMap(networkMap)
require.NoError(t, err)
assert.Nil(t, engine.sshServer)
// SSH server is enabled, therefore SSH config should be applied
networkMap = &mgmtProto.NetworkMap{
Serial: 7,
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
SshConfig: &mgmtProto.SSHConfig{
SshEnabled: true,
JwtConfig: &mgmtProto.JWTConfig{
Issuer: "test-issuer",
Audience: "test-audience",
KeysLocation: "test-keys",
MaxTokenAge: 3600,
},
}},
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
RemotePeersIsEmpty: false,
}
err = engine.updateNetworkMap(networkMap)
require.NoError(t, err)
time.Sleep(250 * time.Millisecond)
assert.NotNil(t, engine.sshServer)
// now remove peer
networkMap = &mgmtProto.NetworkMap{
Serial: 8,
RemotePeers: []*mgmtProto.RemotePeerConfig{},
RemotePeersIsEmpty: false,
}
err = engine.updateNetworkMap(networkMap)
require.NoError(t, err)
// time.Sleep(250 * time.Millisecond)
assert.NotNil(t, engine.sshServer)
// now disable SSH server
networkMap = &mgmtProto.NetworkMap{
Serial: 9,
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
SshConfig: &mgmtProto.SSHConfig{SshEnabled: false}},
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
RemotePeersIsEmpty: false,
}
err = engine.updateNetworkMap(networkMap)
require.NoError(t, err)
assert.Nil(t, engine.sshServer)
}
func TestEngine_Sync(t *testing.T) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
return
}
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel()
// feed updates to Engine via mocked Management client
updates := make(chan *mgmtProto.SyncResponse)
defer close(updates)
syncFunc := func(ctx context.Context, info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
for msg := range updates {
err := msgHandler(msg)
if err != nil {
t.Fatal(err)
}
}
return nil
}
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine(ctx, cancel, &EngineConfig{
WgIfaceName: "utun103",
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, EngineServices{
SignalClient: &signal.MockClient{},
MgmClient: &mgmt.MockClient{SyncFunc: syncFunc},
RelayManager: relayMgr,
StatusRecorder: peer.NewRecorder("https://mgm"),
}, MobileDependency{})
engine.ctx = ctx
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
}
defer func() {
err := engine.Stop()
if err != nil {
return
}
}()
err = engine.Start(nil, nil)
if err != nil {
t.Fatal(err)
return
}
peer1 := &mgmtProto.RemotePeerConfig{
WgPubKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
AllowedIps: []string{"100.64.0.10/24"},
}
peer2 := &mgmtProto.RemotePeerConfig{
WgPubKey: "LLHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
AllowedIps: []string{"100.64.0.11/24"},
}
peer3 := &mgmtProto.RemotePeerConfig{
WgPubKey: "GGHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
AllowedIps: []string{"100.64.0.12/24"},
}
// 1st update with just 1 peer and serial larger than the current serial of the engine => apply update
updates <- &mgmtProto.SyncResponse{
NetworkMap: &mgmtProto.NetworkMap{
Serial: 10,
PeerConfig: nil,
RemotePeers: []*mgmtProto.RemotePeerConfig{peer1, peer2, peer3},
RemotePeersIsEmpty: false,
},
}
timeout := time.After(time.Second * 2)
for {
select {
case <-timeout:
t.Fatalf("timeout while waiting for test to finish")
return
default:
}
if getPeers(engine) == 3 && engine.networkSerial == 10 {
break
}
}
}
func TestEngine_MultiplePeers(t *testing.T) {
// log.SetLevel(log.DebugLevel)
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel()
sigServer, signalAddr, err := startSignal(t)
if err != nil {
t.Fatal(err)
return
}
defer sigServer.Stop()
mgmtServer, mgmtAddr, err := startManagement(t, t.TempDir(), "../testdata/store.sql")
if err != nil {
t.Fatal(err)
return
}
defer mgmtServer.GracefulStop()
setupKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
mu := sync.Mutex{}
engines := []*Engine{}
numPeers := 10
wg := sync.WaitGroup{}
wg.Add(numPeers)
// create and start peers
for i := 0; i < numPeers; i++ {
j := i
go func() {
engine, err := createEngine(ctx, cancel, setupKey, j, mgmtAddr, signalAddr)
if err != nil {
wg.Done()
t.Errorf("unable to create the engine for peer %d with error %v", j, err)
return
}
engine.dnsServer = &dns.MockServer{}
mu.Lock()
defer mu.Unlock()
guid := fmt.Sprintf("{%s}", uuid.New().String())
device.CustomWindowsGUIDString = strings.ToLower(guid)
err = engine.Start(nil, nil)
if err != nil {
t.Errorf("unable to start engine for peer %d with error %v", j, err)
wg.Done()
return
}
engines = append(engines, engine)
wg.Done()
}()
}
// wait until all have been created and started
wg.Wait()
if len(engines) != numPeers {
t.Fatal("not all peers were started")
}
// check whether all the peer have expected peers connected
expectedConnected := numPeers * (numPeers - 1)
// adjust according to timeouts
timeout := 50 * time.Second
timeoutChan := time.After(timeout)
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
loop:
for {
select {
case <-timeoutChan:
t.Fatalf("waiting for expected connections timeout after %s", timeout.String())
break loop
case <-ticker.C:
totalConnected := 0
for _, engine := range engines {
totalConnected += getConnectedPeers(engine)
}
if totalConnected == expectedConnected {
log.Infof("total connected=%d", totalConnected)
break loop
}
log.Infof("total connected=%d", totalConnected)
}
}
// cleanup test
for n, peerEngine := range engines {
t.Logf("stopping peer with interface %s from multipeer test, loopIndex %d", peerEngine.wgInterface.Name(), n)
errStop := peerEngine.mgmClient.Close()
if errStop != nil {
log.Infoln("got error trying to close management clients from engine: ", errStop)
}
errStop = peerEngine.Stop()
if errStop != nil {
log.Infoln("got error trying to close testing peers engine: ", errStop)
}
}
}
var (
kaep = keepalive.EnforcementPolicy{
MinTime: 15 * time.Second,
PermitWithoutStream: true,
}
kasp = keepalive.ServerParameters{
MaxConnectionIdle: 15 * time.Second,
MaxConnectionAgeGrace: 5 * time.Second,
Time: 5 * time.Second,
Timeout: 2 * time.Second,
}
)
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, err
}
mgmtClient, err := mgmt.NewClient(ctx, mgmtAddr, key, false)
if err != nil {
return nil, err
}
signalClient, err := signal.NewClient(ctx, signalAddr, key, false)
if err != nil {
return nil, err
}
info := system.GetInfo(ctx)
resp, err := mgmtClient.Register(setupKey, "", info, nil, nil)
if err != nil {
return nil, err
}
var ifaceName string
if runtime.GOOS == "darwin" {
ifaceName = fmt.Sprintf("utun1%d", i)
} else {
ifaceName = fmt.Sprintf("wt%d", i)
}
wgPort := 33100 + i
conf := &EngineConfig{
WgIfaceName: ifaceName,
WgAddr: wgaddr.MustParseWGAddress(resp.PeerConfig.Address),
WgPrivateKey: key,
WgPort: wgPort,
MTU: iface.DefaultMTU,
}
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
e, err := NewEngine(ctx, cancel, conf, EngineServices{
SignalClient: signalClient,
MgmClient: mgmtClient,
RelayManager: relayMgr,
StatusRecorder: peer.NewRecorder("https://mgm"),
}, MobileDependency{}), nil
e.ctx = ctx
return e, err
}
func startSignal(t *testing.T) (*grpc.Server, string, error) {
t.Helper()
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
log.Fatalf("failed to listen: %v", err)
}
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
require.NoError(t, err)
proto.RegisterSignalExchangeServer(s, srv)
go func() {
if err = s.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
}
}()
return s, lis.Addr().String(), nil
}
func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) {
t.Helper()
config := &config.Config{
Stuns: []*config.Host{},
TURNConfig: &config.TURNConfig{},
Relay: &config.Relay{
Addresses: []string{"127.0.0.1:1234"},
CredentialsTTL: util.Duration{Duration: time.Hour},
Secret: "222222222222222222",
},
Signal: &config.Host{
Proto: "http",
URI: "localhost:10000",
},
Datadir: dataDir,
HttpConfig: nil,
}
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
return nil, "", err
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir)
if err != nil {
return nil, "", err
}
t.Cleanup(cleanUp)
eventStore := &activity.InMemoryEventStore{}
if err != nil {
return nil, "", err
}
permissionsManager := permissions.NewManager(store)
peersManager := peers.NewManager(store, permissionsManager)
jobManager := job.NewJobManager(nil, store, peersManager)
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
if err != nil {
return nil, "", err
}
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
settingsMockManager.EXPECT().
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
Return(&types.Settings{}, nil).
AnyTimes()
settingsMockManager.EXPECT().
GetExtraSettings(gomock.Any(), gomock.Any()).
Return(&types.ExtraSettings{}, nil).
AnyTimes()
groupsManager := groups.NewManagerMock()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
if err != nil {
return nil, "", err
}
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
if err != nil {
return nil, "", err
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
if err != nil {
return nil, "", err
}
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
go func() {
if err = s.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
}
}()
return s, lis.Addr().String(), nil
}
// getConnectedPeers returns a connection Status or nil if peer connection wasn't found
func getConnectedPeers(e *Engine) int {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
i := 0
for _, id := range e.peerStore.PeersPubKey() {
conn, _ := e.peerStore.PeerConn(id)
if conn.IsConnected() {
i++
}
}
return i
}
func getPeers(e *Engine) int {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
return len(e.peerStore.PeersPubKey())
}

View File

@@ -12,10 +12,10 @@ import (
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/netstack"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
sshserver "github.com/netbirdio/netbird/client/ssh/server"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
sshauth "github.com/netbirdio/netbird/shared/sessionauth"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)
@@ -237,18 +237,22 @@ func (e *Engine) startSSHServer(jwtConfig *sshserver.JWTConfig) error {
return errors.New("wg interface not initialized")
}
wgAddr := e.wgInterface.Address()
serverConfig := &sshserver.Config{
HostKeyPEM: e.config.SSHKey,
JWT: jwtConfig,
NetstackNet: e.wgInterface.GetNet(),
NetworkValidation: wgAddr,
HostKeyPEM: e.config.SSHKey,
JWT: jwtConfig,
}
server := sshserver.New(serverConfig)
wgAddr := e.wgInterface.Address()
server.SetNetworkValidation(wgAddr)
netbirdIP := wgAddr.IP
listenAddr := netip.AddrPortFrom(netbirdIP, sshserver.InternalSSHPort)
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
server.SetNetstackNet(netstackNet)
}
e.configureSSHServer(server)
if err := server.Start(e.ctx, listenAddr); err != nil {

View File

@@ -6,18 +6,37 @@ import (
"net"
"net/netip"
"os"
"runtime"
"strings"
"sync"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
wgdevice "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/configurer"
@@ -31,7 +50,18 @@ import (
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/routemanager"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/monotime"
"github.com/netbirdio/netbird/route"
mgmt "github.com/netbirdio/netbird/shared/management/client"
@@ -39,9 +69,25 @@ import (
"github.com/netbirdio/netbird/shared/netiputil"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
signal "github.com/netbirdio/netbird/shared/signal/client"
"github.com/netbirdio/netbird/shared/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server"
"github.com/netbirdio/netbird/util"
)
var (
kaep = keepalive.EnforcementPolicy{
MinTime: 15 * time.Second,
PermitWithoutStream: true,
}
kasp = keepalive.ServerParameters{
MaxConnectionIdle: 15 * time.Second,
MaxConnectionAgeGrace: 5 * time.Second,
Time: 5 * time.Second,
Timeout: 2 * time.Second,
}
)
type MockWGIface struct {
CreateFunc func() error
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
@@ -188,6 +234,129 @@ func TestMain(m *testing.M) {
os.Exit(code)
}
func TestEngine_SSH(t *testing.T) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
return
}
sshKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
if err != nil {
t.Fatal(err)
return
}
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel()
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine(
ctx, cancel,
&EngineConfig{
WgIfaceName: "utun101",
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
WgPrivateKey: key,
WgPort: 33100,
ServerSSHAllowed: true,
MTU: iface.DefaultMTU,
SSHKey: sshKey,
},
EngineServices{
SignalClient: &signal.MockClient{},
MgmClient: &mgmt.MockClient{},
RelayManager: relayMgr,
StatusRecorder: peer.NewRecorder("https://mgm"),
},
MobileDependency{},
)
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
}
err = engine.Start(nil, nil)
require.NoError(t, err)
defer func() {
err := engine.Stop()
if err != nil {
return
}
}()
peerWithSSH := &mgmtProto.RemotePeerConfig{
WgPubKey: "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
AllowedIps: []string{"100.64.0.21/24"},
SshConfig: &mgmtProto.SSHConfig{
SshPubKey: []byte("ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFATYCqaQw/9id1Qkq3n16JYhDhXraI6Pc1fgB8ynEfQ"),
},
}
// SSH server is not enabled so SSH config of a remote peer should be ignored
networkMap := &mgmtProto.NetworkMap{
Serial: 6,
PeerConfig: nil,
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
RemotePeersIsEmpty: false,
}
err = engine.updateNetworkMap(networkMap)
require.NoError(t, err)
assert.Nil(t, engine.sshServer)
// SSH server is enabled, therefore SSH config should be applied
networkMap = &mgmtProto.NetworkMap{
Serial: 7,
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
SshConfig: &mgmtProto.SSHConfig{
SshEnabled: true,
JwtConfig: &mgmtProto.JWTConfig{
Issuer: "test-issuer",
Audience: "test-audience",
KeysLocation: "test-keys",
MaxTokenAge: 3600,
},
}},
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
RemotePeersIsEmpty: false,
}
err = engine.updateNetworkMap(networkMap)
require.NoError(t, err)
time.Sleep(250 * time.Millisecond)
assert.NotNil(t, engine.sshServer)
// now remove peer
networkMap = &mgmtProto.NetworkMap{
Serial: 8,
RemotePeers: []*mgmtProto.RemotePeerConfig{},
RemotePeersIsEmpty: false,
}
err = engine.updateNetworkMap(networkMap)
require.NoError(t, err)
// time.Sleep(250 * time.Millisecond)
assert.NotNil(t, engine.sshServer)
// now disable SSH server
networkMap = &mgmtProto.NetworkMap{
Serial: 9,
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
SshConfig: &mgmtProto.SSHConfig{SshEnabled: false}},
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
RemotePeersIsEmpty: false,
}
err = engine.updateNetworkMap(networkMap)
require.NoError(t, err)
assert.Nil(t, engine.sshServer)
}
func TestEngine_SSHUpdateLogic(t *testing.T) {
// Test that SSH server start/stop logic works based on config
engine := &Engine{
@@ -462,6 +631,97 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
}
}
func TestEngine_Sync(t *testing.T) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
return
}
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel()
// feed updates to Engine via mocked Management client
updates := make(chan *mgmtProto.SyncResponse)
defer close(updates)
syncFunc := func(ctx context.Context, info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
for msg := range updates {
err := msgHandler(msg)
if err != nil {
t.Fatal(err)
}
}
return nil
}
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine(ctx, cancel, &EngineConfig{
WgIfaceName: "utun103",
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, EngineServices{
SignalClient: &signal.MockClient{},
MgmClient: &mgmt.MockClient{SyncFunc: syncFunc},
RelayManager: relayMgr,
StatusRecorder: peer.NewRecorder("https://mgm"),
}, MobileDependency{})
engine.ctx = ctx
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
}
defer func() {
err := engine.Stop()
if err != nil {
return
}
}()
err = engine.Start(nil, nil)
if err != nil {
t.Fatal(err)
return
}
peer1 := &mgmtProto.RemotePeerConfig{
WgPubKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
AllowedIps: []string{"100.64.0.10/24"},
}
peer2 := &mgmtProto.RemotePeerConfig{
WgPubKey: "LLHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
AllowedIps: []string{"100.64.0.11/24"},
}
peer3 := &mgmtProto.RemotePeerConfig{
WgPubKey: "GGHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
AllowedIps: []string{"100.64.0.12/24"},
}
// 1st update with just 1 peer and serial larger than the current serial of the engine => apply update
updates <- &mgmtProto.SyncResponse{
NetworkMap: &mgmtProto.NetworkMap{
Serial: 10,
PeerConfig: nil,
RemotePeers: []*mgmtProto.RemotePeerConfig{peer1, peer2, peer3},
RemotePeersIsEmpty: false,
},
}
timeout := time.After(time.Second * 2)
for {
select {
case <-timeout:
t.Fatalf("timeout while waiting for test to finish")
return
default:
}
if getPeers(engine) == 3 && engine.networkSerial == 10 {
break
}
}
}
func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
testCases := []struct {
name string
@@ -845,6 +1105,104 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
}
}
func TestEngine_MultiplePeers(t *testing.T) {
// log.SetLevel(log.DebugLevel)
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel()
sigServer, signalAddr, err := startSignal(t)
if err != nil {
t.Fatal(err)
return
}
defer sigServer.Stop()
mgmtServer, mgmtAddr, err := startManagement(t, t.TempDir(), "../testdata/store.sql")
if err != nil {
t.Fatal(err)
return
}
defer mgmtServer.GracefulStop()
setupKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
mu := sync.Mutex{}
engines := []*Engine{}
numPeers := 10
wg := sync.WaitGroup{}
wg.Add(numPeers)
// create and start peers
for i := 0; i < numPeers; i++ {
j := i
go func() {
engine, err := createEngine(ctx, cancel, setupKey, j, mgmtAddr, signalAddr)
if err != nil {
wg.Done()
t.Errorf("unable to create the engine for peer %d with error %v", j, err)
return
}
engine.dnsServer = &dns.MockServer{}
mu.Lock()
defer mu.Unlock()
guid := fmt.Sprintf("{%s}", uuid.New().String())
device.CustomWindowsGUIDString = strings.ToLower(guid)
err = engine.Start(nil, nil)
if err != nil {
t.Errorf("unable to start engine for peer %d with error %v", j, err)
wg.Done()
return
}
engines = append(engines, engine)
wg.Done()
}()
}
// wait until all have been created and started
wg.Wait()
if len(engines) != numPeers {
t.Fatal("not all peers was started")
}
// check whether all the peer have expected peers connected
expectedConnected := numPeers * (numPeers - 1)
// adjust according to timeouts
timeout := 50 * time.Second
timeoutChan := time.After(timeout)
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
loop:
for {
select {
case <-timeoutChan:
t.Fatalf("waiting for expected connections timeout after %s", timeout.String())
break loop
case <-ticker.C:
totalConnected := 0
for _, engine := range engines {
totalConnected += getConnectedPeers(engine)
}
if totalConnected == expectedConnected {
log.Infof("total connected=%d", totalConnected)
break loop
}
log.Infof("total connected=%d", totalConnected)
}
}
// cleanup test
for n, peerEngine := range engines {
t.Logf("stopping peer with interface %s from multipeer test, loopIndex %d", peerEngine.wgInterface.Name(), n)
errStop := peerEngine.mgmClient.Close()
if errStop != nil {
log.Infoln("got error trying to close management clients from engine: ", errStop)
}
errStop = peerEngine.Stop()
if errStop != nil {
log.Infoln("got error trying to close testing peers engine: ", errStop)
}
}
}
func Test_ParseNATExternalIPMappings(t *testing.T) {
ifaceList, err := net.Interfaces()
if err != nil {
@@ -1168,6 +1526,187 @@ func TestCompareNetIPLists(t *testing.T) {
}
}
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, err
}
mgmtClient, err := mgmt.NewClient(ctx, mgmtAddr, key, false)
if err != nil {
return nil, err
}
signalClient, err := signal.NewClient(ctx, signalAddr, key, false)
if err != nil {
return nil, err
}
info := system.GetInfo(ctx)
resp, err := mgmtClient.Register(setupKey, "", info, nil, nil)
if err != nil {
return nil, err
}
var ifaceName string
if runtime.GOOS == "darwin" {
ifaceName = fmt.Sprintf("utun1%d", i)
} else {
ifaceName = fmt.Sprintf("wt%d", i)
}
wgPort := 33100 + i
conf := &EngineConfig{
WgIfaceName: ifaceName,
WgAddr: wgaddr.MustParseWGAddress(resp.PeerConfig.Address),
WgPrivateKey: key,
WgPort: wgPort,
MTU: iface.DefaultMTU,
}
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
e, err := NewEngine(ctx, cancel, conf, EngineServices{
SignalClient: signalClient,
MgmClient: mgmtClient,
RelayManager: relayMgr,
StatusRecorder: peer.NewRecorder("https://mgm"),
}, MobileDependency{}), nil
e.ctx = ctx
return e, err
}
func startSignal(t *testing.T) (*grpc.Server, string, error) {
t.Helper()
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
log.Fatalf("failed to listen: %v", err)
}
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
require.NoError(t, err)
proto.RegisterSignalExchangeServer(s, srv)
go func() {
if err = s.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
}
}()
return s, lis.Addr().String(), nil
}
func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) {
t.Helper()
config := &config.Config{
Stuns: []*config.Host{},
TURNConfig: &config.TURNConfig{},
Relay: &config.Relay{
Addresses: []string{"127.0.0.1:1234"},
CredentialsTTL: util.Duration{Duration: time.Hour},
Secret: "222222222222222222",
},
Signal: &config.Host{
Proto: "http",
URI: "localhost:10000",
},
Datadir: dataDir,
HttpConfig: nil,
}
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
return nil, "", err
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir)
if err != nil {
return nil, "", err
}
t.Cleanup(cleanUp)
eventStore := &activity.InMemoryEventStore{}
if err != nil {
return nil, "", err
}
permissionsManager := permissions.NewManager(store)
peersManager := peers.NewManager(store, permissionsManager)
jobManager := job.NewJobManager(nil, store, peersManager)
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
if err != nil {
return nil, "", err
}
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
settingsMockManager.EXPECT().
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
Return(&types.Settings{}, nil).
AnyTimes()
settingsMockManager.EXPECT().
GetExtraSettings(gomock.Any(), gomock.Any()).
Return(&types.ExtraSettings{}, nil).
AnyTimes()
groupsManager := groups.NewManagerMock()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
if err != nil {
return nil, "", err
}
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
if err != nil {
return nil, "", err
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
if err != nil {
return nil, "", err
}
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
go func() {
if err = s.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
}
}()
return s, lis.Addr().String(), nil
}
// getConnectedPeers returns a connection Status or nil if peer connection wasn't found
func getConnectedPeers(e *Engine) int {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
i := 0
for _, id := range e.peerStore.PeersPubKey() {
conn, _ := e.peerStore.PeerConn(id)
if conn.IsConnected() {
i++
}
}
return i
}
func getPeers(e *Engine) int {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
return len(e.peerStore.PeersPubKey())
}
func mustEncodePrefix(t *testing.T, p netip.Prefix) []byte {
t.Helper()
b, err := netiputil.EncodePrefix(p)

View File

@@ -1,302 +0,0 @@
//go:build !js && !ios && !android
package internal
import (
"context"
"errors"
"fmt"
"net/netip"
log "github.com/sirupsen/logrus"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/approval"
"github.com/netbirdio/netbird/client/internal/metrics"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/vnc"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
sshauth "github.com/netbirdio/netbird/shared/sessionauth"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)
type vncServer interface {
Start(ctx context.Context, addr netip.AddrPort, network netip.Prefix) error
Stop() error
ActiveSessions() []vncserver.ActiveSessionInfo
}
func (e *Engine) setupVNCPortRedirection() error {
if e.firewall == nil || e.wgInterface == nil {
return nil
}
localAddr := e.wgInterface.Address().IP
if !localAddr.IsValid() {
return errors.New("invalid local NetBird address")
}
if err := e.firewall.AddInboundDNAT(localAddr, firewallManager.ProtocolTCP, vnc.ExternalPort, vnc.InternalPort); err != nil {
return fmt.Errorf("add VNC port redirection: %w", err)
}
log.Infof("VNC port redirection: %s:%d -> %s:%d", localAddr, vnc.ExternalPort, localAddr, vnc.InternalPort)
return nil
}
func (e *Engine) cleanupVNCPortRedirection() error {
if e.firewall == nil || e.wgInterface == nil {
return nil
}
localAddr := e.wgInterface.Address().IP
if !localAddr.IsValid() {
return errors.New("invalid local NetBird address")
}
if err := e.firewall.RemoveInboundDNAT(localAddr, firewallManager.ProtocolTCP, vnc.ExternalPort, vnc.InternalPort); err != nil {
return fmt.Errorf("remove VNC port redirection: %w", err)
}
return nil
}
// updateVNC handles starting/stopping the VNC server based on the config flag.
func (e *Engine) updateVNC() error {
if !e.config.ServerVNCAllowed {
if e.vncSrv != nil {
log.Info("VNC server disabled, stopping")
}
return e.stopVNCServer()
}
if e.config.BlockInbound {
log.Info("VNC server disabled because inbound connections are blocked")
return e.stopVNCServer()
}
if e.vncSrv != nil {
return nil
}
return e.startVNCServer()
}
func (e *Engine) startVNCServer() error {
if e.wgInterface == nil {
return errors.New("wg interface not initialized")
}
capturer, injector, ok := newPlatformVNC()
if !ok {
log.Debug("VNC server not supported on this platform")
return nil
}
netbirdIP := e.wgInterface.Address().IP
var sessionRecorder func(vncserver.SessionTick)
if e.clientMetrics != nil {
sessionRecorder = func(t vncserver.SessionTick) {
e.clientMetrics.RecordVNCSessionTick(e.ctx, metrics.VNCSessionTick{
Period: t.Period,
BytesOut: t.BytesOut,
Writes: t.Writes,
FBUs: t.FBUs,
MaxFBUBytes: t.MaxFBUBytes,
MaxFBURects: t.MaxFBURects,
MaxWriteBytes: t.MaxWriteBytes,
WriteNanos: t.WriteNanos,
})
}
}
serviceMode := vncNeedsServiceMode()
if serviceMode {
log.Info("VNC: running as system service, enabling service mode (per-session agent proxy)")
}
requireApproval := e.config.DisableVNCApproval == nil || !*e.config.DisableVNCApproval
srv := vncserver.New(vncserver.Config{
Capturer: capturer,
Injector: injector,
IdentityKey: e.config.WgPrivateKey[:],
ServiceMode: serviceMode,
SessionRecorder: sessionRecorder,
NetstackNet: e.wgInterface.GetNet(),
RequireApproval: requireApproval,
Approver: &vncApprover{broker: e.approvalBroker, statusRecorder: e.statusRecorder},
})
listenAddr := netip.AddrPortFrom(netbirdIP, vnc.InternalPort)
network := e.wgInterface.Address().Network
if err := srv.Start(e.ctx, listenAddr, network); err != nil {
return fmt.Errorf("start VNC server: %w", err)
}
e.vncSrv = srv
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
if registrar, ok := e.firewall.(interface {
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
}); ok {
registrar.RegisterNetstackService(nftypes.TCP, vnc.InternalPort)
log.Debugf("registered VNC service with netstack for TCP:%d", vnc.InternalPort)
}
}
if err := e.setupVNCPortRedirection(); err != nil {
log.Warnf("setup VNC port redirection: %v", err)
}
log.Info("VNC server enabled")
return nil
}
// updateVNCServerAuth updates VNC fine-grained access control from management.
// A nil vncAuth clears all authorized users and session pubkeys so management
// can revoke access by omitting the field on the next sync.
func (e *Engine) updateVNCServerAuth(vncAuth *mgmProto.VNCAuth) {
if e.vncSrv == nil {
return
}
vncSrv, ok := e.vncSrv.(*vncserver.Server)
if !ok {
return
}
if vncAuth == nil {
vncSrv.UpdateVNCAuth(&sshauth.Config{})
return
}
protoUsers := vncAuth.GetAuthorizedUsers()
authorizedUsers := make([]sshuserhash.UserIDHash, len(protoUsers))
for i, hash := range protoUsers {
if len(hash) != 16 {
log.Warnf("invalid VNC auth hash length %d, expected 16", len(hash))
return
}
authorizedUsers[i] = sshuserhash.UserIDHash(hash)
}
machineUsers := make(map[string][]uint32)
for osUser, indexes := range vncAuth.GetMachineUsers() {
machineUsers[osUser] = indexes.GetIndexes()
}
sessionPubKeys := make([]sshauth.SessionPubKey, 0, len(vncAuth.GetSessionPubKeys()))
for _, pk := range vncAuth.GetSessionPubKeys() {
pub := pk.GetPubKey()
if len(pub) != 32 {
log.Warnf("VNC session pubkey wrong length %d", len(pub))
continue
}
hash := pk.GetUserIdHash()
if len(hash) != 16 {
log.Warnf("VNC session user id hash wrong length %d", len(hash))
continue
}
sessionPubKeys = append(sessionPubKeys, sshauth.SessionPubKey{
PubKey: pub,
UserIDHash: sshuserhash.UserIDHash(hash),
DisplayName: pk.GetDisplayName(),
})
}
vncSrv.UpdateVNCAuth(&sshauth.Config{
AuthorizedUsers: authorizedUsers,
MachineUsers: machineUsers,
SessionPubKeys: sessionPubKeys,
})
}
// GetVNCServerStatus returns whether the VNC server is running and the list
// of active VNC sessions. The pointer is captured under syncMsgMux so a
// concurrent updateVNC/stopVNCServer cannot swap it out between the nil
// check and the ActiveSessions call.
func (e *Engine) GetVNCServerStatus() (enabled bool, sessions []vncserver.ActiveSessionInfo) {
e.syncMsgMux.Lock()
vncSrv := e.vncSrv
e.syncMsgMux.Unlock()
if vncSrv == nil {
return false, nil
}
return true, vncSrv.ActiveSessions()
}
func (e *Engine) stopVNCServer() error {
if e.vncSrv == nil {
return nil
}
if err := e.cleanupVNCPortRedirection(); err != nil {
log.Warnf("cleanup VNC port redirection: %v", err)
}
if e.wgInterface != nil && e.wgInterface.GetNet() != nil {
if registrar, ok := e.firewall.(interface {
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
}); ok {
registrar.UnregisterNetstackService(nftypes.TCP, vnc.InternalPort)
}
}
log.Info("stopping VNC server")
err := e.vncSrv.Stop()
e.vncSrv = nil
if err != nil {
return fmt.Errorf("stop VNC server: %w", err)
}
return nil
}
// vncApprover adapts the generic approval.Broker for the VNC server.
type vncApprover struct {
broker *approval.Broker
statusRecorder *peer.Status
}
func (a *vncApprover) Request(ctx context.Context, info vncserver.ApprovalInfo) (vncserver.ApprovalDecision, error) {
// Resolve the source overlay IP to a peer FQDN for the prompt label.
if info.PeerName == "" && info.SourceIP != "" && a.statusRecorder != nil {
if fqdn, ok := a.statusRecorder.PeerByIP(info.SourceIP); ok {
info.PeerName = fqdn
}
}
subject := fmt.Sprintf("VNC connection from %s", displayPeer(info))
meta := map[string]string{
"peer_name": info.PeerName,
"peer_pubkey": info.PeerPubKey,
"source_ip": info.SourceIP,
"mode": info.Mode,
"username": info.Username,
"initiator": info.Initiator,
}
d, err := a.broker.Request(ctx, approval.Prompt{
Kind: approval.KindVNC,
Subject: subject,
Metadata: meta,
})
if err != nil {
return vncserver.ApprovalDecision{}, err
}
return vncserver.ApprovalDecision{ViewOnly: d.ViewOnly}, nil
}
func displayPeer(info vncserver.ApprovalInfo) string {
if info.Initiator != "" {
return info.Initiator
}
if info.PeerName != "" {
return info.PeerName
}
if info.SourceIP != "" {
return info.SourceIP
}
if info.PeerPubKey != "" {
return info.PeerPubKey
}
return "unknown peer"
}

View File

@@ -1,31 +0,0 @@
//go:build freebsd
package internal
import (
"fmt"
log "github.com/sirupsen/logrus"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
)
// newConsoleVNC builds the FreeBSD console fallback: vt(4) framebuffer
// for capture, /dev/uinput for input. The uinput device requires the
// `uinput` kernel module (`kldload uinput`); without it, input init
// fails and we drop to a stub injector so the user still gets a
// view-only screen mirror.
func newConsoleVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, error) {
poller := vncserver.NewFBPoller("")
w, h := poller.Width(), poller.Height()
if w == 0 || h == 0 {
poller.Close()
return nil, nil, fmt.Errorf("vt framebuffer init failed (vt may not allow mmap on this driver)")
}
if inj, err := vncserver.NewUInputInjector(w, h); err == nil {
return poller, inj, nil
} else {
log.Infof("VNC console: uinput unavailable (%v); view-only mode. Run `kldload uinput` to enable input.", err)
return poller, &vncserver.StubInputInjector{}, nil
}
}

View File

@@ -1,30 +0,0 @@
//go:build linux && !android
package internal
import (
"fmt"
log "github.com/sirupsen/logrus"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
)
// newConsoleVNC builds a framebuffer + uinput VNC backend for boxes
// without a running X server. Used as the auto-fallback when
// newPlatformVNC can't reach X. Returns an error when /dev/fb0 or
// /dev/uinput aren't usable so the caller can drop back to a stub.
func newConsoleVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, error) {
poller := vncserver.NewFBPoller("")
w, h := poller.Width(), poller.Height()
if w == 0 || h == 0 {
poller.Close()
return nil, nil, fmt.Errorf("framebuffer capturer init failed (is /dev/fb0 readable?)")
}
inj, err := vncserver.NewUInputInjector(w, h)
if err != nil {
log.Debugf("uinput unavailable, falling back to view-only VNC: %v", err)
return poller, &vncserver.StubInputInjector{}, nil
}
return poller, inj, nil
}

View File

@@ -1,34 +0,0 @@
//go:build darwin && !ios
package internal
import (
"os"
log "github.com/sirupsen/logrus"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
)
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, bool) {
capturer := vncserver.NewMacPoller()
// Prompt for Screen Recording at server-enable time rather than first
// client-connect. The native prompt is far easier for users to act on
// in the moment they toggled VNC on than later when "the screen looks
// like wallpaper" would otherwise be the only clue.
vncserver.PrimeScreenCapturePermission()
injector, err := vncserver.NewMacInputInjector()
if err != nil {
log.Debugf("VNC: macOS input injector: %v", err)
return capturer, &vncserver.StubInputInjector{}, true
}
return capturer, injector, true
}
// vncNeedsServiceMode reports whether the running process is a system
// LaunchDaemon (root, parented by launchd). Daemons sit in the global
// bootstrap namespace and cannot talk to WindowServer; we route capture
// through a per-user agent in that case.
func vncNeedsServiceMode() bool {
return os.Geteuid() == 0 && os.Getppid() == 1
}

View File

@@ -1,23 +0,0 @@
//go:build js || ios || android
package internal
import (
log "github.com/sirupsen/logrus"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
type vncServer interface{}
func (e *Engine) updateVNC() error { return nil }
func (e *Engine) updateVNCServerAuth(auth *mgmProto.VNCAuth) {
if auth == nil {
return
}
log.Debugf("ignoring VNC auth push on platform without a VNC server: %d session pubkeys, %d authorized users",
len(auth.GetSessionPubKeys()), len(auth.GetAuthorizedUsers()))
}
func (e *Engine) stopVNCServer() error { return nil }

View File

@@ -1,13 +0,0 @@
//go:build windows
package internal
import vncserver "github.com/netbirdio/netbird/client/vnc/server"
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, bool) {
return vncserver.NewDesktopCapturer(), vncserver.NewWindowsInputInjector(), true
}
func vncNeedsServiceMode() bool {
return vncserver.GetCurrentSessionID() == 0
}

View File

@@ -1,35 +0,0 @@
//go:build (linux && !android) || freebsd
package internal
import (
log "github.com/sirupsen/logrus"
vncserver "github.com/netbirdio/netbird/client/vnc/server"
)
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, bool) {
// Prefer X11 when an X server is reachable. NewX11InputInjector probes
// DISPLAY (and /proc) eagerly, so a non-nil error here means no X.
injector, err := vncserver.NewX11InputInjector("", "", "")
if err == nil {
return vncserver.NewX11Poller("", ""), injector, true
}
log.Debugf("VNC: X11 not available: %v", err)
// Fallback for headless / pre-X states (kernel console, login manager
// without X, physical server in recovery): stream the framebuffer and
// inject input via /dev/uinput.
consoleCap, consoleInj, err := newConsoleVNC()
if err == nil {
log.Infof("VNC: using framebuffer console capture (%dx%d)", consoleCap.Width(), consoleCap.Height())
return consoleCap, consoleInj, true
}
log.Debugf("VNC: framebuffer console fallback unavailable: %v", err)
return &vncserver.StubCapturer{}, &vncserver.StubInputInjector{}, false
}
func vncNeedsServiceMode() bool {
return false
}

View File

@@ -120,36 +120,6 @@ func (m *influxDBMetrics) RecordSyncDuration(_ context.Context, agentInfo AgentI
m.trimLocked()
}
func (m *influxDBMetrics) RecordVNCSessionTick(_ context.Context, agentInfo AgentInfo, tick VNCSessionTick) {
tags := fmt.Sprintf("deployment_type=%s,version=%s,os=%s,arch=%s,peer_id=%s",
agentInfo.DeploymentType.String(),
agentInfo.Version,
agentInfo.OS,
agentInfo.Arch,
agentInfo.peerID,
)
m.mu.Lock()
defer m.mu.Unlock()
m.samples = append(m.samples, influxSample{
measurement: "netbird_vnc_traffic",
tags: tags,
fields: map[string]float64{
"period_seconds": tick.Period.Seconds(),
"bytes_out": float64(tick.BytesOut),
"writes": float64(tick.Writes),
"fbus": float64(tick.FBUs),
"max_fbu_bytes": float64(tick.MaxFBUBytes),
"max_fbu_rects": float64(tick.MaxFBURects),
"max_write_bytes": float64(tick.MaxWriteBytes),
"write_time_seconds": float64(tick.WriteNanos) / 1e9,
},
timestamp: time.Now(),
})
m.trimLocked()
}
func (m *influxDBMetrics) RecordLoginDuration(_ context.Context, agentInfo AgentInfo, duration time.Duration, success bool) {
result := "success"
if !success {

View File

@@ -59,11 +59,6 @@ type metricsImplementation interface {
// RecordLoginDuration records how long the login to management took
RecordLoginDuration(ctx context.Context, agentInfo AgentInfo, duration time.Duration, success bool)
// RecordVNCSessionTick records a periodic snapshot of one VNC
// session's wire activity. Called once per metricsConn tick interval
// (and once at session close), only when the tick saw activity.
RecordVNCSessionTick(ctx context.Context, agentInfo AgentInfo, tick VNCSessionTick)
// Export exports metrics in InfluxDB line protocol format
Export(w io.Writer) error
@@ -83,21 +78,6 @@ type ClientMetrics struct {
pushCancel context.CancelFunc
}
// VNCSessionTick is one sampling slice of a VNC session's wire activity.
// BytesOut / Writes / FBUs / WriteNanos are deltas observed during this
// tick; Max* fields are the high-water marks observed during the tick.
// Period is the wall-clock duration the deltas cover.
type VNCSessionTick struct {
Period time.Duration
BytesOut uint64
Writes uint64
FBUs uint64
MaxFBUBytes uint64
MaxFBURects uint64
MaxWriteBytes uint64
WriteNanos uint64
}
// ConnectionStageTimestamps holds timestamps for each connection stage
type ConnectionStageTimestamps struct {
SignalingReceived time.Time // First signal received from remote peer (both initial and reconnection)
@@ -147,17 +127,6 @@ func (c *ClientMetrics) RecordSyncDuration(ctx context.Context, duration time.Du
c.impl.RecordSyncDuration(ctx, agentInfo, duration)
}
// RecordVNCSessionTick records a periodic snapshot of one VNC session.
func (c *ClientMetrics) RecordVNCSessionTick(ctx context.Context, tick VNCSessionTick) {
if c == nil {
return
}
c.mu.RLock()
agentInfo := c.agentInfo
c.mu.RUnlock()
c.impl.RecordVNCSessionTick(ctx, agentInfo, tick)
}
// RecordLoginDuration records how long the login to management server took
func (c *ClientMetrics) RecordLoginDuration(ctx context.Context, duration time.Duration, success bool) {
if c == nil {

View File

@@ -73,9 +73,6 @@ func (m *mockMetrics) RecordSyncDuration(_ context.Context, _ AgentInfo, _ time.
func (m *mockMetrics) RecordLoginDuration(_ context.Context, _ AgentInfo, _ time.Duration, _ bool) {
}
func (m *mockMetrics) RecordVNCSessionTick(_ context.Context, _ AgentInfo, _ VNCSessionTick) {
}
func (m *mockMetrics) Export(w io.Writer) error {
if m.exportData != "" {
_, err := w.Write([]byte(m.exportData))

View File

@@ -1241,15 +1241,6 @@ func (d *Status) SubscribeToEvents() *EventSubscription {
}
}
// HasEventSubscribers reports whether any client is currently subscribed
// to the daemon's SystemEvent stream. Used by the VNC approval broker to
// fail closed when no UI is connected to prompt the user.
func (d *Status) HasEventSubscribers() bool {
d.eventMux.Lock()
defer d.eventMux.Unlock()
return len(d.eventStreams) > 0
}
// UnsubscribeFromEvents removes an event subscription
func (d *Status) UnsubscribeFromEvents(sub *EventSubscription) {
if sub == nil {

View File

@@ -70,8 +70,6 @@ type ConfigInput struct {
StateFilePath string
PreSharedKey *string
ServerSSHAllowed *bool
ServerVNCAllowed *bool
DisableVNCApproval *bool
EnableSSHRoot *bool
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
@@ -127,8 +125,6 @@ type Config struct {
RosenpassEnabled bool
RosenpassPermissive bool
ServerSSHAllowed *bool
ServerVNCAllowed *bool
DisableVNCApproval *bool
EnableSSHRoot *bool
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
@@ -458,33 +454,6 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
if input.ServerVNCAllowed != nil {
if config.ServerVNCAllowed == nil || *input.ServerVNCAllowed != *config.ServerVNCAllowed {
if *input.ServerVNCAllowed {
log.Infof("enabling VNC server")
} else {
log.Infof("disabling VNC server")
}
config.ServerVNCAllowed = input.ServerVNCAllowed
updated = true
}
} else if config.ServerVNCAllowed == nil {
config.ServerVNCAllowed = util.False()
updated = true
}
if input.DisableVNCApproval != nil {
if config.DisableVNCApproval == nil || *input.DisableVNCApproval != *config.DisableVNCApproval {
if *input.DisableVNCApproval {
log.Infof("disabling VNC connection approval prompt")
} else {
log.Infof("enabling VNC connection approval prompt")
}
config.DisableVNCApproval = input.DisableVNCApproval
updated = true
}
}
if input.EnableSSHRoot != nil && input.EnableSSHRoot != config.EnableSSHRoot {
if *input.EnableSSHRoot {
log.Infof("enabling SSH root login")

View File

@@ -1,5 +1,3 @@
//go:build privileged
package routemanager
import (

View File

@@ -1,69 +0,0 @@
//go:build linux && !android
package systemops
import (
"fmt"
"os"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestEntryExists(t *testing.T) {
tempDir := t.TempDir()
tempFilePath := fmt.Sprintf("%s/rt_tables", tempDir)
content := []string{
"1000 reserved",
fmt.Sprintf("%d %s", NetbirdVPNTableID, NetbirdVPNTableName),
"9999 other_table",
}
require.NoError(t, os.WriteFile(tempFilePath, []byte(strings.Join(content, "\n")), 0644))
file, err := os.Open(tempFilePath)
require.NoError(t, err)
defer func() {
assert.NoError(t, file.Close())
}()
tests := []struct {
name string
id int
shouldExist bool
err error
}{
{
name: "ExistsWithNetbirdPrefix",
id: 7120,
shouldExist: true,
err: nil,
},
{
name: "ExistsWithDifferentName",
id: 1000,
shouldExist: true,
err: ErrTableIDExists,
},
{
name: "DoesNotExist",
id: 1234,
shouldExist: false,
err: nil,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
exists, err := entryExists(file, tc.id)
if tc.err != nil {
assert.ErrorIs(t, err, tc.err)
} else {
assert.NoError(t, err)
}
assert.Equal(t, tc.shouldExist, exists)
})
}
}

View File

@@ -1,191 +0,0 @@
//go:build (darwin || dragonfly || freebsd || netbsd || openbsd) && privileged
package systemops
import (
"fmt"
"net"
"net/netip"
"os/exec"
"regexp"
"runtime"
"strings"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func init() {
testCases = append(testCases, []testCase{
{
name: "To more specific route without custom dialer via vpn",
expectedInterface: expectedVPNint,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "10.10.0.2", 53),
},
}...)
}
func TestConcurrentRoutes(t *testing.T) {
baseIP := netip.MustParseAddr("192.0.2.0")
var intf *net.Interface
var nexthop Nexthop
_, intf = setupDummyInterface(t)
nexthop = Nexthop{netip.Addr{}, intf}
r := New(nil, nil)
var wg sync.WaitGroup
for i := 0; i < 1024; i++ {
wg.Add(1)
go func(ip netip.Addr) {
defer wg.Done()
prefix := netip.PrefixFrom(ip, 32)
if err := r.addToRouteTable(prefix, nexthop); err != nil {
t.Errorf("Failed to add route for %s: %v", prefix, err)
}
}(baseIP)
baseIP = baseIP.Next()
}
wg.Wait()
baseIP = netip.MustParseAddr("192.0.2.0")
for i := 0; i < 1024; i++ {
wg.Add(1)
go func(ip netip.Addr) {
defer wg.Done()
prefix := netip.PrefixFrom(ip, 32)
if err := r.removeFromRouteTable(prefix, nexthop); err != nil {
t.Errorf("Failed to remove route for %s: %v", prefix, err)
}
}(baseIP)
baseIP = baseIP.Next()
}
wg.Wait()
}
func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string {
t.Helper()
if runtime.GOOS == "darwin" {
err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run()
require.NoError(t, err, "Failed to create loopback alias")
t.Cleanup(func() {
err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run()
assert.NoError(t, err, "Failed to remove loopback alias")
})
return intf
}
prefix, err := netip.ParsePrefix(ipAddressCIDR)
require.NoError(t, err, "Failed to parse prefix")
netIntf, err := net.InterfaceByName(intf)
require.NoError(t, err, "Failed to get interface by name")
nexthop := Nexthop{netip.Addr{}, netIntf}
r := New(nil, nil)
err = r.addToRouteTable(prefix, nexthop)
require.NoError(t, err, "Failed to add route to table")
t.Cleanup(func() {
err := r.removeFromRouteTable(prefix, nexthop)
assert.NoError(t, err, "Failed to remove route from table")
})
return intf
}
func addDummyRoute(t *testing.T, dstCIDR string, gw netip.Addr, _ string) {
t.Helper()
var originalNexthop net.IP
if dstCIDR == "0.0.0.0/0" {
var err error
originalNexthop, err = fetchOriginalGateway()
if err != nil {
t.Logf("Failed to fetch original gateway: %v", err)
}
if output, err := exec.Command("route", "delete", "-net", dstCIDR).CombinedOutput(); err != nil {
t.Logf("Failed to delete route: %v, output: %s", err, output)
}
}
t.Cleanup(func() {
if originalNexthop != nil {
err := exec.Command("route", "add", "-net", dstCIDR, originalNexthop.String()).Run()
assert.NoError(t, err, "Failed to restore original route")
}
})
err := exec.Command("route", "add", "-net", dstCIDR, gw.String()).Run()
require.NoError(t, err, "Failed to add route")
t.Cleanup(func() {
err := exec.Command("route", "delete", "-net", dstCIDR).Run()
assert.NoError(t, err, "Failed to remove route")
})
}
func fetchOriginalGateway() (net.IP, error) {
output, err := exec.Command("route", "-n", "get", "default").CombinedOutput()
if err != nil {
return nil, err
}
matches := regexp.MustCompile(`gateway: (\S+)`).FindStringSubmatch(string(output))
if len(matches) == 0 {
return nil, fmt.Errorf("gateway not found")
}
return net.ParseIP(matches[1]), nil
}
// setupDummyInterface creates a dummy tun interface for FreeBSD route testing
func setupDummyInterface(t *testing.T) (netip.Addr, *net.Interface) {
t.Helper()
if runtime.GOOS == "darwin" {
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), &net.Interface{Name: "lo0"}
}
output, err := exec.Command("ifconfig", "tun", "create").CombinedOutput()
require.NoError(t, err, "Failed to create tun interface: %s", string(output))
tunName := strings.TrimSpace(string(output))
output, err = exec.Command("ifconfig", tunName, "192.168.1.1", "netmask", "255.255.0.0", "192.168.1.2", "up").CombinedOutput()
require.NoError(t, err, "Failed to configure tun interface: %s", string(output))
intf, err := net.InterfaceByName(tunName)
require.NoError(t, err, "Failed to get interface by name")
t.Cleanup(func() {
if err := exec.Command("ifconfig", tunName, "destroy").Run(); err != nil {
t.Logf("Failed to destroy tun interface %s: %v", tunName, err)
}
})
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), intf
}
func setupDummyInterfacesAndRoutes(t *testing.T) {
t.Helper()
defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24")
addDummyRoute(t, "0.0.0.0/0", netip.AddrFrom4([4]byte{192, 168, 0, 1}), defaultDummy)
otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24")
addDummyRoute(t, "10.0.0.0/8", netip.AddrFrom4([4]byte{192, 168, 1, 1}), otherDummy)
}

View File

@@ -3,24 +3,79 @@
package systemops
import (
"fmt"
"net"
"net/netip"
"os/exec"
"regexp"
"runtime"
"strings"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/route"
)
// Interface names used by the shared routing test fixtures. Kept untagged (no
// privileged build tag) so the non-privileged test files in this package compile.
//
//nolint:unused // consumed by the privileged-tagged routing tests
var expectedVPNint = "utun100"
//nolint:unused // consumed by the privileged-tagged routing tests
var expectedExternalInt = "lo0"
//nolint:unused // consumed by the privileged-tagged routing tests
var expectedInternalInt = "lo0"
func init() {
testCases = append(testCases, []testCase{
{
name: "To more specific route without custom dialer via vpn",
expectedInterface: expectedVPNint,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "10.10.0.2", 53),
},
}...)
}
func TestConcurrentRoutes(t *testing.T) {
baseIP := netip.MustParseAddr("192.0.2.0")
var intf *net.Interface
var nexthop Nexthop
_, intf = setupDummyInterface(t)
nexthop = Nexthop{netip.Addr{}, intf}
r := New(nil, nil)
var wg sync.WaitGroup
for i := 0; i < 1024; i++ {
wg.Add(1)
go func(ip netip.Addr) {
defer wg.Done()
prefix := netip.PrefixFrom(ip, 32)
if err := r.addToRouteTable(prefix, nexthop); err != nil {
t.Errorf("Failed to add route for %s: %v", prefix, err)
}
}(baseIP)
baseIP = baseIP.Next()
}
wg.Wait()
baseIP = netip.MustParseAddr("192.0.2.0")
for i := 0; i < 1024; i++ {
wg.Add(1)
go func(ip netip.Addr) {
defer wg.Done()
prefix := netip.PrefixFrom(ip, 32)
if err := r.removeFromRouteTable(prefix, nexthop); err != nil {
t.Errorf("Failed to remove route for %s: %v", prefix, err)
}
}(baseIP)
baseIP = baseIP.Next()
}
wg.Wait()
}
func TestBits(t *testing.T) {
tests := []struct {
name string
@@ -67,3 +122,122 @@ func TestBits(t *testing.T) {
})
}
}
func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string {
t.Helper()
if runtime.GOOS == "darwin" {
err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run()
require.NoError(t, err, "Failed to create loopback alias")
t.Cleanup(func() {
err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run()
assert.NoError(t, err, "Failed to remove loopback alias")
})
return intf
}
prefix, err := netip.ParsePrefix(ipAddressCIDR)
require.NoError(t, err, "Failed to parse prefix")
netIntf, err := net.InterfaceByName(intf)
require.NoError(t, err, "Failed to get interface by name")
nexthop := Nexthop{netip.Addr{}, netIntf}
r := New(nil, nil)
err = r.addToRouteTable(prefix, nexthop)
require.NoError(t, err, "Failed to add route to table")
t.Cleanup(func() {
err := r.removeFromRouteTable(prefix, nexthop)
assert.NoError(t, err, "Failed to remove route from table")
})
return intf
}
func addDummyRoute(t *testing.T, dstCIDR string, gw netip.Addr, _ string) {
t.Helper()
var originalNexthop net.IP
if dstCIDR == "0.0.0.0/0" {
var err error
originalNexthop, err = fetchOriginalGateway()
if err != nil {
t.Logf("Failed to fetch original gateway: %v", err)
}
if output, err := exec.Command("route", "delete", "-net", dstCIDR).CombinedOutput(); err != nil {
t.Logf("Failed to delete route: %v, output: %s", err, output)
}
}
t.Cleanup(func() {
if originalNexthop != nil {
err := exec.Command("route", "add", "-net", dstCIDR, originalNexthop.String()).Run()
assert.NoError(t, err, "Failed to restore original route")
}
})
err := exec.Command("route", "add", "-net", dstCIDR, gw.String()).Run()
require.NoError(t, err, "Failed to add route")
t.Cleanup(func() {
err := exec.Command("route", "delete", "-net", dstCIDR).Run()
assert.NoError(t, err, "Failed to remove route")
})
}
func fetchOriginalGateway() (net.IP, error) {
output, err := exec.Command("route", "-n", "get", "default").CombinedOutput()
if err != nil {
return nil, err
}
matches := regexp.MustCompile(`gateway: (\S+)`).FindStringSubmatch(string(output))
if len(matches) == 0 {
return nil, fmt.Errorf("gateway not found")
}
return net.ParseIP(matches[1]), nil
}
// setupDummyInterface creates a dummy tun interface for FreeBSD route testing
func setupDummyInterface(t *testing.T) (netip.Addr, *net.Interface) {
t.Helper()
if runtime.GOOS == "darwin" {
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), &net.Interface{Name: "lo0"}
}
output, err := exec.Command("ifconfig", "tun", "create").CombinedOutput()
require.NoError(t, err, "Failed to create tun interface: %s", string(output))
tunName := strings.TrimSpace(string(output))
output, err = exec.Command("ifconfig", tunName, "192.168.1.1", "netmask", "255.255.0.0", "192.168.1.2", "up").CombinedOutput()
require.NoError(t, err, "Failed to configure tun interface: %s", string(output))
intf, err := net.InterfaceByName(tunName)
require.NoError(t, err, "Failed to get interface by name")
t.Cleanup(func() {
if err := exec.Command("ifconfig", tunName, "destroy").Run(); err != nil {
t.Logf("Failed to destroy tun interface %s: %v", tunName, err)
}
})
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), intf
}
func setupDummyInterfacesAndRoutes(t *testing.T) {
t.Helper()
defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24")
addDummyRoute(t, "0.0.0.0/0", netip.AddrFrom4([4]byte{192, 168, 0, 1}), defaultDummy)
otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24")
addDummyRoute(t, "10.0.0.0/8", netip.AddrFrom4([4]byte{192, 168, 1, 1}), otherDummy)
}

View File

@@ -1,17 +0,0 @@
//go:build !android && !ios
package systemops
import (
"context"
"net"
)
// dialer is shared by the per-platform routing test cases. Kept untagged (no
// privileged build tag) so the non-privileged test files compile on every platform.
//
//nolint:unused // consumed by the privileged-tagged routing tests
type dialer interface {
Dial(network, address string) (net.Conn, error)
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}

View File

@@ -1,4 +1,4 @@
//go:build !android && !ios && privileged
//go:build !android && !ios
package systemops
@@ -26,6 +26,11 @@ import (
nbnet "github.com/netbirdio/netbird/client/net"
)
type dialer interface {
Dial(network, address string) (net.Conn, error)
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
func TestAddVPNRoute(t *testing.T) {
testCases := []struct {
name string
@@ -510,3 +515,125 @@ func setupTestEnv(t *testing.T) {
// unique route in vpn table
setupRouteAndCleanup(t, r, netip.MustParsePrefix("172.16.0.0/12"), intf)
}
func TestIsVpnRoute(t *testing.T) {
tests := []struct {
name string
addr string
vpnRoutes []string
localRoutes []string
expectedVpn bool
expectedPrefix netip.Prefix
}{
{
name: "Match in VPN routes",
addr: "192.168.1.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: true,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
{
name: "Match in local routes",
addr: "10.1.1.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: false,
expectedPrefix: netip.MustParsePrefix("10.0.0.0/8"),
},
{
name: "No match",
addr: "172.16.0.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: false,
expectedPrefix: netip.Prefix{},
},
{
name: "Default route ignored",
addr: "192.168.1.1",
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: true,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
{
name: "Default route matches but ignored",
addr: "172.16.1.1",
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: false,
expectedPrefix: netip.Prefix{},
},
{
name: "Longest prefix match local",
addr: "192.168.1.1",
vpnRoutes: []string{"192.168.0.0/16"},
localRoutes: []string{"192.168.1.0/24"},
expectedVpn: false,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
{
name: "Longest prefix match local multiple",
addr: "192.168.0.1",
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26", "192.168.0.0/28"},
expectedVpn: false,
expectedPrefix: netip.MustParsePrefix("192.168.0.0/28"),
},
{
name: "Longest prefix match vpn",
addr: "192.168.1.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"192.168.0.0/16"},
expectedVpn: true,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
{
name: "Longest prefix match vpn multiple",
addr: "192.168.0.1",
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26"},
expectedVpn: true,
expectedPrefix: netip.MustParsePrefix("192.168.0.0/27"),
},
{
name: "Duplicate prefix in both",
addr: "192.168.1.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"192.168.1.0/24"},
expectedVpn: false,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
addr, err := netip.ParseAddr(tt.addr)
if err != nil {
t.Fatalf("Failed to parse address %s: %v", tt.addr, err)
}
var vpnRoutes, localRoutes []netip.Prefix
for _, route := range tt.vpnRoutes {
prefix, err := netip.ParsePrefix(route)
if err != nil {
t.Fatalf("Failed to parse VPN route %s: %v", route, err)
}
vpnRoutes = append(vpnRoutes, prefix)
}
for _, route := range tt.localRoutes {
prefix, err := netip.ParsePrefix(route)
if err != nil {
t.Fatalf("Failed to parse local route %s: %v", route, err)
}
localRoutes = append(localRoutes, prefix)
}
isVpn, matchedPrefix := isVpnRoute(addr, vpnRoutes, localRoutes)
assert.Equal(t, tt.expectedVpn, isVpn, "isVpnRoute should return expectedVpn value")
assert.Equal(t, tt.expectedPrefix, matchedPrefix, "isVpnRoute should return expectedVpn prefix")
})
}
}

View File

@@ -1,132 +0,0 @@
//go:build !android && !ios
package systemops
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
)
func TestIsVpnRoute(t *testing.T) {
tests := []struct {
name string
addr string
vpnRoutes []string
localRoutes []string
expectedVpn bool
expectedPrefix netip.Prefix
}{
{
name: "Match in VPN routes",
addr: "192.168.1.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: true,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
{
name: "Match in local routes",
addr: "10.1.1.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: false,
expectedPrefix: netip.MustParsePrefix("10.0.0.0/8"),
},
{
name: "No match",
addr: "172.16.0.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: false,
expectedPrefix: netip.Prefix{},
},
{
name: "Default route ignored",
addr: "192.168.1.1",
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: true,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
{
name: "Default route matches but ignored",
addr: "172.16.1.1",
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: false,
expectedPrefix: netip.Prefix{},
},
{
name: "Longest prefix match local",
addr: "192.168.1.1",
vpnRoutes: []string{"192.168.0.0/16"},
localRoutes: []string{"192.168.1.0/24"},
expectedVpn: false,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
{
name: "Longest prefix match local multiple",
addr: "192.168.0.1",
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26", "192.168.0.0/28"},
expectedVpn: false,
expectedPrefix: netip.MustParsePrefix("192.168.0.0/28"),
},
{
name: "Longest prefix match vpn",
addr: "192.168.1.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"192.168.0.0/16"},
expectedVpn: true,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
{
name: "Longest prefix match vpn multiple",
addr: "192.168.0.1",
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26"},
expectedVpn: true,
expectedPrefix: netip.MustParsePrefix("192.168.0.0/27"),
},
{
name: "Duplicate prefix in both",
addr: "192.168.1.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"192.168.1.0/24"},
expectedVpn: false,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
addr, err := netip.ParseAddr(tt.addr)
if err != nil {
t.Fatalf("Failed to parse address %s: %v", tt.addr, err)
}
var vpnRoutes, localRoutes []netip.Prefix
for _, route := range tt.vpnRoutes {
prefix, err := netip.ParsePrefix(route)
if err != nil {
t.Fatalf("Failed to parse VPN route %s: %v", route, err)
}
vpnRoutes = append(vpnRoutes, prefix)
}
for _, route := range tt.localRoutes {
prefix, err := netip.ParsePrefix(route)
if err != nil {
t.Fatalf("Failed to parse local route %s: %v", route, err)
}
localRoutes = append(localRoutes, prefix)
}
isVpn, matchedPrefix := isVpnRoute(addr, vpnRoutes, localRoutes)
assert.Equal(t, tt.expectedVpn, isVpn, "isVpnRoute should return expectedVpn value")
assert.Equal(t, tt.expectedPrefix, matchedPrefix, "isVpnRoute should return expectedVpn prefix")
})
}
}

View File

@@ -1,10 +1,13 @@
//go:build linux && !android && privileged
//go:build !android
package systemops
import (
"errors"
"fmt"
"net"
"os"
"strings"
"syscall"
"testing"
@@ -15,6 +18,10 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
)
var expectedVPNint = "wgtest0"
var expectedExternalInt = "dummyext0"
var expectedInternalInt = "dummyint0"
func init() {
testCases = append(testCases, []testCase{
{
@@ -26,6 +33,62 @@ func init() {
}...)
}
func TestEntryExists(t *testing.T) {
tempDir := t.TempDir()
tempFilePath := fmt.Sprintf("%s/rt_tables", tempDir)
content := []string{
"1000 reserved",
fmt.Sprintf("%d %s", NetbirdVPNTableID, NetbirdVPNTableName),
"9999 other_table",
}
require.NoError(t, os.WriteFile(tempFilePath, []byte(strings.Join(content, "\n")), 0644))
file, err := os.Open(tempFilePath)
require.NoError(t, err)
defer func() {
assert.NoError(t, file.Close())
}()
tests := []struct {
name string
id int
shouldExist bool
err error
}{
{
name: "ExistsWithNetbirdPrefix",
id: 7120,
shouldExist: true,
err: nil,
},
{
name: "ExistsWithDifferentName",
id: 1000,
shouldExist: true,
err: ErrTableIDExists,
},
{
name: "DoesNotExist",
id: 1234,
shouldExist: false,
err: nil,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
exists, err := entryExists(file, tc.id)
if tc.err != nil {
assert.ErrorIs(t, err, tc.err)
} else {
assert.NoError(t, err)
}
assert.Equal(t, tc.shouldExist, exists)
})
}
}
func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string {
t.Helper()

View File

@@ -1,15 +0,0 @@
//go:build linux && !android
package systemops
// Interface names used by the shared routing test fixtures. Kept untagged (no
// privileged build tag) so the non-privileged test files in this package compile.
//
//nolint:unused // consumed by the privileged-tagged routing tests
var expectedVPNint = "wgtest0"
//nolint:unused // consumed by the privileged-tagged routing tests
var expectedExternalInt = "dummyext0"
//nolint:unused // consumed by the privileged-tagged routing tests
var expectedInternalInt = "dummyint0"

View File

@@ -1,83 +0,0 @@
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly
package systemops
import (
"net"
nbnet "github.com/netbirdio/netbird/client/net"
)
// Shared, non-privileged routing test fixtures. The privileged TestRouting (and its
// per-platform init() appenders) consume these; they live here so the unprivileged
// BSD/darwin test files compile without the privileged build tag.
type PacketExpectation struct {
SrcIP net.IP
DstIP net.IP
SrcPort int
DstPort int
UDP bool
TCP bool
}
//nolint:unused // consumed by the privileged-tagged routing tests
type testCase struct {
name string
expectedInterface string
dialer dialer
expectedPacket PacketExpectation
}
//nolint:unused // consumed by the privileged-tagged routing tests
var testCases = []testCase{
{
name: "To external host without custom dialer via vpn",
expectedInterface: expectedVPNint,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53),
},
{
name: "To external host with custom dialer via physical interface",
expectedInterface: expectedExternalInt,
dialer: nbnet.NewDialer(),
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53),
},
{
name: "To duplicate internal route with custom dialer via physical interface",
expectedInterface: expectedInternalInt,
dialer: nbnet.NewDialer(),
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
},
{
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
expectedInterface: expectedInternalInt,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
},
{
name: "To unique vpn route with custom dialer via physical interface",
expectedInterface: expectedExternalInt,
dialer: nbnet.NewDialer(),
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "172.16.0.2", 53),
},
{
name: "To unique vpn route without custom dialer via vpn",
expectedInterface: expectedVPNint,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "172.16.0.2", 53),
},
}
//nolint:unused // consumed by the privileged-tagged routing tests
func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation {
return PacketExpectation{
SrcIP: net.ParseIP(srcIP),
DstIP: net.ParseIP(dstIP),
SrcPort: srcPort,
DstPort: dstPort,
UDP: true,
}
}

View File

@@ -1,4 +1,4 @@
//go:build ((linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly) && privileged
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly
package systemops
@@ -20,6 +20,63 @@ import (
nbnet "github.com/netbirdio/netbird/client/net"
)
type PacketExpectation struct {
SrcIP net.IP
DstIP net.IP
SrcPort int
DstPort int
UDP bool
TCP bool
}
type testCase struct {
name string
expectedInterface string
dialer dialer
expectedPacket PacketExpectation
}
var testCases = []testCase{
{
name: "To external host without custom dialer via vpn",
expectedInterface: expectedVPNint,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53),
},
{
name: "To external host with custom dialer via physical interface",
expectedInterface: expectedExternalInt,
dialer: nbnet.NewDialer(),
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53),
},
{
name: "To duplicate internal route with custom dialer via physical interface",
expectedInterface: expectedInternalInt,
dialer: nbnet.NewDialer(),
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
},
{
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
expectedInterface: expectedInternalInt,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
},
{
name: "To unique vpn route with custom dialer via physical interface",
expectedInterface: expectedExternalInt,
dialer: nbnet.NewDialer(),
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "172.16.0.2", 53),
},
{
name: "To unique vpn route without custom dialer via vpn",
expectedInterface: expectedVPNint,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "172.16.0.2", 53),
},
}
func TestRouting(t *testing.T) {
nbnet.Init()
for _, tc := range testCases {
@@ -45,6 +102,16 @@ func TestRouting(t *testing.T) {
}
}
func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation {
return PacketExpectation{
SrcIP: net.ParseIP(srcIP),
DstIP: net.ParseIP(dstIP),
SrcPort: srcPort,
DstPort: dstPort,
UDP: true,
}
}
func startPacketCapture(t *testing.T, intf, filter string) *pcap.Handle {
t.Helper()

View File

@@ -1,5 +1,3 @@
//go:build windows && privileged
package systemops
import (

View File

@@ -11,8 +11,6 @@ import (
// ensureIPv6DefaultRoute installs an IPv6 default route via the loopback
// interface so route lookups for global IPv6 prefixes resolve in environments
// without v6 connectivity. If a default already exists it is left alone.
//
//nolint:unused // consumed by the privileged-tagged routing tests
func ensureIPv6DefaultRoute(t *testing.T) {
t.Helper()

View File

@@ -1,4 +1,4 @@
//go:build linux && !android && privileged
//go:build linux && !android
package systemops

View File

@@ -8,14 +8,11 @@ import (
"testing"
)
//nolint:unused // consumed by the privileged-tagged routing tests
const loopbackIfaceWindows = "Loopback Pseudo-Interface 1"
// ensureIPv6DefaultRoute installs an IPv6 default route via the loopback
// interface so route lookups for global IPv6 prefixes resolve in environments
// without v6 connectivity. If a default already exists it is left alone.
//
//nolint:unused // consumed by the privileged-tagged routing tests
func ensureIPv6DefaultRoute(t *testing.T) {
t.Helper()

View File

@@ -74,14 +74,6 @@ func New(filePath string) *Manager {
}
}
// FilePath returns the path of the underlying state file.
func (m *Manager) FilePath() string {
if m == nil {
return ""
}
return m.filePath
}
// Start starts the state manager periodic save routine
func (m *Manager) Start() {
if m == nil {

File diff suppressed because it is too large Load Diff

View File

@@ -121,14 +121,6 @@ service DaemonService {
// ExposeService exposes a local port via the NetBird reverse proxy
rpc ExposeService(ExposeServiceRequest) returns (stream ExposeServiceEvent) {}
// RespondApproval delivers the user's accept/deny decision for a
// pending user-approval prompt. The daemon pushes the prompt as a
// SystemEvent with category APPROVAL and metadata key "request_id";
// the UI calls this RPC with the same request_id to unblock whichever
// subsystem (VNC, SSH, ...) is waiting. The "kind" metadata key tells
// the UI which subsystem the prompt belongs to.
rpc RespondApproval(RespondApprovalRequest) returns (RespondApprovalResponse) {}
}
@@ -215,10 +207,6 @@ message LoginRequest {
optional bool disableSSHAuth = 38;
optional int32 sshJWTCacheTTL = 39;
optional bool disable_ipv6 = 40;
optional bool serverVNCAllowed = 41;
optional bool disableVNCApproval = 42;
}
message LoginResponse {
@@ -329,16 +317,12 @@ message GetConfigResponse {
bool disable_ipv6 = 27;
bool serverVNCAllowed = 28;
bool disableVNCApproval = 29;
// mDMManagedFields lists the names of configuration keys whose value is
// currently enforced by an MDM policy. Names match mdm.Key* constants
// (e.g. "managementURL", "disableClientRoutes"). UI/CLI clients should
// render the corresponding inputs as read-only and display a "managed
// by MDM" indicator.
repeated string mDMManagedFields = 30;
repeated string mDMManagedFields = 28;
}
// PeerState contains the latest state of a peer
@@ -423,25 +407,6 @@ message SSHServerState {
repeated SSHSessionInfo sessions = 2;
}
// VNCSessionInfo contains information about an active VNC session
message VNCSessionInfo {
string remoteAddress = 1;
string mode = 2;
string username = 3;
// userID is the Noise-verified session identity (hashed user ID from
// the ACL session-key entry), empty when auth is disabled.
string userID = 4;
// initiator is the human-readable display name of the dashboard user
// who minted the SessionPubKey, when known.
string initiator = 5;
}
// VNCServerState contains the latest state of the VNC server
message VNCServerState {
bool enabled = 1;
repeated VNCSessionInfo sessions = 2;
}
// FullStatus contains the full state held by the Status instance
message FullStatus {
ManagementState managementState = 1;
@@ -456,7 +421,6 @@ message FullStatus {
bool lazyConnectionEnabled = 9;
SSHServerState sshServerState = 10;
VNCServerState vncServerState = 11;
}
// Networks
@@ -645,7 +609,6 @@ message SystemEvent {
AUTHENTICATION = 2;
CONNECTIVITY = 3;
SYSTEM = 4;
APPROVAL = 5;
}
string id = 1;
@@ -736,10 +699,6 @@ message SetConfigRequest {
optional bool disableSSHAuth = 33;
optional int32 sshJWTCacheTTL = 34;
optional bool disable_ipv6 = 35;
optional bool serverVNCAllowed = 36;
optional bool disableVNCApproval = 37;
}
message SetConfigResponse{}
@@ -970,18 +929,3 @@ message StartBundleCaptureRequest {
message StartBundleCaptureResponse {}
message StopBundleCaptureRequest {}
message StopBundleCaptureResponse {}
message RespondApprovalRequest {
// request_id matches the SystemEvent metadata key emitted by the daemon
// when a subsystem awaits user approval for an inbound connection.
string request_id = 1;
// accept is true if the user approved the request, false if they
// denied it. A missing or unknown request_id is treated as a no-op.
bool accept = 2;
// view_only signals that the user granted the connection but withheld
// input control. Only meaningful when accept is true; ignored when
// accept is false.
bool view_only = 3;
}
message RespondApprovalResponse {}

View File

@@ -59,7 +59,6 @@ const (
DaemonService_StopCPUProfile_FullMethodName = "/daemon.DaemonService/StopCPUProfile"
DaemonService_GetInstallerResult_FullMethodName = "/daemon.DaemonService/GetInstallerResult"
DaemonService_ExposeService_FullMethodName = "/daemon.DaemonService/ExposeService"
DaemonService_RespondApproval_FullMethodName = "/daemon.DaemonService/RespondApproval"
)
// DaemonServiceClient is the client API for DaemonService service.
@@ -137,13 +136,6 @@ type DaemonServiceClient interface {
GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error)
// ExposeService exposes a local port via the NetBird reverse proxy
ExposeService(ctx context.Context, in *ExposeServiceRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[ExposeServiceEvent], error)
// RespondApproval delivers the user's accept/deny decision for a
// pending user-approval prompt. The daemon pushes the prompt as a
// SystemEvent with category APPROVAL and metadata key "request_id";
// the UI calls this RPC with the same request_id to unblock whichever
// subsystem (VNC, SSH, ...) is waiting. The "kind" metadata key tells
// the UI which subsystem the prompt belongs to.
RespondApproval(ctx context.Context, in *RespondApprovalRequest, opts ...grpc.CallOption) (*RespondApprovalResponse, error)
}
type daemonServiceClient struct {
@@ -581,16 +573,6 @@ func (c *daemonServiceClient) ExposeService(ctx context.Context, in *ExposeServi
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type DaemonService_ExposeServiceClient = grpc.ServerStreamingClient[ExposeServiceEvent]
func (c *daemonServiceClient) RespondApproval(ctx context.Context, in *RespondApprovalRequest, opts ...grpc.CallOption) (*RespondApprovalResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(RespondApprovalResponse)
err := c.cc.Invoke(ctx, DaemonService_RespondApproval_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
// DaemonServiceServer is the server API for DaemonService service.
// All implementations must embed UnimplementedDaemonServiceServer
// for forward compatibility.
@@ -666,13 +648,6 @@ type DaemonServiceServer interface {
GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error)
// ExposeService exposes a local port via the NetBird reverse proxy
ExposeService(*ExposeServiceRequest, grpc.ServerStreamingServer[ExposeServiceEvent]) error
// RespondApproval delivers the user's accept/deny decision for a
// pending user-approval prompt. The daemon pushes the prompt as a
// SystemEvent with category APPROVAL and metadata key "request_id";
// the UI calls this RPC with the same request_id to unblock whichever
// subsystem (VNC, SSH, ...) is waiting. The "kind" metadata key tells
// the UI which subsystem the prompt belongs to.
RespondApproval(context.Context, *RespondApprovalRequest) (*RespondApprovalResponse, error)
mustEmbedUnimplementedDaemonServiceServer()
}
@@ -803,9 +778,6 @@ func (UnimplementedDaemonServiceServer) GetInstallerResult(context.Context, *Ins
func (UnimplementedDaemonServiceServer) ExposeService(*ExposeServiceRequest, grpc.ServerStreamingServer[ExposeServiceEvent]) error {
return status.Error(codes.Unimplemented, "method ExposeService not implemented")
}
func (UnimplementedDaemonServiceServer) RespondApproval(context.Context, *RespondApprovalRequest) (*RespondApprovalResponse, error) {
return nil, status.Error(codes.Unimplemented, "method RespondApproval not implemented")
}
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
func (UnimplementedDaemonServiceServer) testEmbeddedByValue() {}
@@ -1526,24 +1498,6 @@ func _DaemonService_ExposeService_Handler(srv interface{}, stream grpc.ServerStr
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
type DaemonService_ExposeServiceServer = grpc.ServerStreamingServer[ExposeServiceEvent]
func _DaemonService_RespondApproval_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RespondApprovalRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).RespondApproval(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: DaemonService_RespondApproval_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).RespondApproval(ctx, req.(*RespondApprovalRequest))
}
return interceptor(ctx, in, info, handler)
}
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
@@ -1699,10 +1653,6 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
MethodName: "GetInstallerResult",
Handler: _DaemonService_GetInstallerResult_Handler,
},
{
MethodName: "RespondApproval",
Handler: _DaemonService_RespondApproval_Handler,
},
},
Streams: []grpc.StreamDesc{
{

View File

@@ -111,7 +111,7 @@ func (s *Server) StartCapture(req *proto.StartCaptureRequest, stream proto.Daemo
return status.Errorf(codes.Internal, "create capture session: %v", err)
}
engine, err := s.claimCapture(sess, func() { pw.Close() })
engine, err := s.claimCapture(sess)
if err != nil {
sess.Stop()
pw.Close()
@@ -190,7 +190,10 @@ func (s *Server) StartBundleCapture(_ context.Context, req *proto.StartBundleCap
s.stopBundleCaptureLocked()
s.cleanupBundleCapture()
s.evictActiveCaptureLocked()
if s.activeCapture != nil {
return nil, status.Error(codes.FailedPrecondition, "another capture is already running")
}
engine, err := s.getCaptureEngineLocked()
if err != nil {
@@ -301,58 +304,29 @@ func (s *Server) cleanupBundleCapture() {
s.bundleCapture = nil
}
// claimCapture reserves the engine's capture slot for sess. If another
// capture is already running it is evicted: a previous streaming session
// whose gRPC client died and never freed the slot stays stuck otherwise,
// and a bundle capture is just informational state.
func (s *Server) claimCapture(sess *capture.Session, cancel func()) (*internal.Engine, error) {
// claimCapture reserves the engine's capture slot for sess. Returns
// FailedPrecondition if another capture is already active.
func (s *Server) claimCapture(sess *capture.Session) (*internal.Engine, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
s.evictActiveCaptureLocked()
if s.activeCapture != nil {
return nil, status.Error(codes.FailedPrecondition, "another capture is already running")
}
engine, err := s.getCaptureEngineLocked()
if err != nil {
return nil, err
}
s.activeCapture = sess
s.activeCaptureCancel = cancel
return engine, nil
}
// evictActiveCaptureLocked tears down whatever capture currently owns
// the engine slot so a fresh claim can succeed. Caller must hold mutex.
func (s *Server) evictActiveCaptureLocked() {
if s.activeCapture == nil {
return
}
if s.bundleCapture != nil && s.bundleCapture.sess == s.activeCapture {
log.Infof("evicting running bundle capture to start a new capture")
s.stopBundleCaptureLocked()
return
}
log.Infof("evicting previous streaming capture to start a new one")
prev := s.activeCapture
cancel := s.activeCaptureCancel
if engine, err := s.getCaptureEngineLocked(); err == nil {
if err := engine.SetCapture(nil); err != nil {
log.Debugf("clear previous capture: %v", err)
}
}
s.activeCapture = nil
s.activeCaptureCancel = nil
prev.Stop()
if cancel != nil {
cancel()
}
}
// releaseCapture clears the active-capture owner if it still matches sess.
func (s *Server) releaseCapture(sess *capture.Session) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.activeCapture == sess {
s.activeCapture = nil
s.activeCaptureCancel = nil
}
}
@@ -367,7 +341,6 @@ func (s *Server) clearCaptureIfOwner(sess *capture.Session, engine *internal.Eng
log.Debugf("clear capture: %v", err)
}
s.activeCapture = nil
s.activeCaptureCancel = nil
}
func (s *Server) getCaptureEngineLocked() (*internal.Engine, error) {

View File

@@ -100,12 +100,8 @@ type Server struct {
captureEnabled bool
bundleCapture *bundleCapture
// activeCapture is the session currently installed on the engine; guarded by s.mutex.
activeCapture *capture.Session
// activeCaptureCancel tears down the streaming pipe/cancel for the
// active streaming capture so eviction unblocks the StartCapture RPC
// handler. Nil for bundle captures (they own their own context).
activeCaptureCancel func()
networksDisabled bool
activeCapture *capture.Session
networksDisabled bool
sleepHandler *sleephandler.SleepHandler
@@ -460,8 +456,6 @@ func (s *Server) setConfigInputFromRequest(msg *proto.SetConfigRequest) (profile
config.RosenpassPermissive = msg.RosenpassPermissive
config.DisableAutoConnect = msg.DisableAutoConnect
config.ServerSSHAllowed = msg.ServerSSHAllowed
config.ServerVNCAllowed = msg.ServerVNCAllowed
config.DisableVNCApproval = msg.DisableVNCApproval
config.NetworkMonitor = msg.NetworkMonitor
config.DisableClientRoutes = msg.DisableClientRoutes
config.DisableServerRoutes = msg.DisableServerRoutes
@@ -1257,7 +1251,6 @@ func (s *Server) Status(
pbFullStatus := fullStatus.ToProto()
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
pbFullStatus.SshServerState = s.getSSHServerState()
pbFullStatus.VncServerState = s.getVNCServerState()
statusResponse.FullStatus = pbFullStatus
}
@@ -1297,38 +1290,6 @@ func (s *Server) getSSHServerState() *proto.SSHServerState {
return sshServerState
}
// getVNCServerState retrieves the current VNC server state.
func (s *Server) getVNCServerState() *proto.VNCServerState {
s.mutex.Lock()
connectClient := s.connectClient
s.mutex.Unlock()
if connectClient == nil {
return nil
}
engine := connectClient.Engine()
if engine == nil {
return nil
}
enabled, sessions := engine.GetVNCServerStatus()
pbSessions := make([]*proto.VNCSessionInfo, 0, len(sessions))
for _, sess := range sessions {
pbSessions = append(pbSessions, &proto.VNCSessionInfo{
RemoteAddress: sess.RemoteAddress,
Mode: sess.Mode,
Username: sess.Username,
UserID: sess.UserID,
Initiator: sess.Initiator,
})
}
return &proto.VNCServerState{
Enabled: enabled,
Sessions: pbSessions,
}
}
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
func (s *Server) GetPeerSSHHostKey(
ctx context.Context,
@@ -1569,27 +1530,6 @@ func (s *Server) ExposeService(req *proto.ExposeServiceRequest, srv proto.Daemon
return nil
}
// RespondApproval relays the user's accept/deny decision for a pending
// approval prompt to the engine's broker. Unknown or already-resolved
// request_ids are silently no-op'd so a slow UI cannot deny a prompt the
// user already handled (or that already timed out).
func (s *Server) RespondApproval(_ context.Context, msg *proto.RespondApprovalRequest) (*proto.RespondApprovalResponse, error) {
s.mutex.Lock()
connectClient := s.connectClient
s.mutex.Unlock()
if connectClient == nil {
return nil, gstatus.Errorf(codes.FailedPrecondition, "client not initialized")
}
engine := connectClient.Engine()
if engine == nil {
return nil, gstatus.Errorf(codes.FailedPrecondition, "engine not running")
}
if !engine.RespondApproval(msg.GetRequestId(), msg.GetAccept(), msg.GetViewOnly()) {
log.Debugf("approval response for unknown request_id %s", msg.GetRequestId())
}
return &proto.RespondApprovalResponse{}, nil
}
func isUnixRunningDesktop() bool {
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
return false
@@ -1705,8 +1645,6 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
Mtu: int64(cfg.MTU),
DisableAutoConnect: cfg.DisableAutoConnect,
ServerSSHAllowed: *cfg.ServerSSHAllowed,
ServerVNCAllowed: cfg.ServerVNCAllowed != nil && *cfg.ServerVNCAllowed,
DisableVNCApproval: cfg.DisableVNCApproval != nil && *cfg.DisableVNCApproval,
RosenpassEnabled: cfg.RosenpassEnabled,
RosenpassPermissive: cfg.RosenpassPermissive,
LazyConnectionEnabled: cfg.LazyConnectionEnabled,

View File

@@ -1,235 +0,0 @@
//go:build privileged
package server
import (
"context"
"net"
"os/user"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server"
)
var (
kaep = keepalive.EnforcementPolicy{
MinTime: 15 * time.Second,
PermitWithoutStream: true,
}
kasp = keepalive.ServerParameters{
MaxConnectionIdle: 15 * time.Second,
MaxConnectionAgeGrace: 5 * time.Second,
Time: 5 * time.Second,
Timeout: 2 * time.Second,
}
)
// TestConnectWithRetryRuns checks that the connectWithRetry function runs and runs the retries according to the times specified via environment variables
// we will use a management server started via to simulate the server and capture the number of retries
func TestConnectWithRetryRuns(t *testing.T) {
// start the signal server
_, signalAddr, err := startSignal(t)
if err != nil {
t.Fatalf("failed to start signal server: %v", err)
}
counter := 0
// start the management server
_, mgmtAddr, err := startManagement(t, signalAddr, &counter)
if err != nil {
t.Fatalf("failed to start management server: %v", err)
}
ctx := internal.CtxInitState(context.Background())
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(30*time.Second))
defer cancel()
// create new server
ic := profilemanager.ConfigInput{
ManagementURL: "http://" + mgmtAddr,
ConfigPath: t.TempDir() + "/test-profile.json",
}
config, err := profilemanager.UpdateOrCreateConfig(ic)
if err != nil {
t.Fatalf("failed to create config: %v", err)
}
currUser, err := user.Current()
require.NoError(t, err)
pm := profilemanager.ServiceManager{}
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
ID: "test-profile",
Username: currUser.Username,
})
if err != nil {
t.Fatalf("failed to set active profile state: %v", err)
}
s := New(ctx, "debug", "", false, false, false, false)
s.config = config
s.statusRecorder = peer.NewRecorder(config.ManagementURL.String())
t.Setenv(retryInitialIntervalVar, "1s")
t.Setenv(maxRetryIntervalVar, "2s")
t.Setenv(maxRetryTimeVar, "5s")
t.Setenv(retryMultiplierVar, "1")
s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil)
if counter < 3 {
t.Fatalf("expected counter > 2, got %d", counter)
}
}
type mockServer struct {
mgmtProto.ManagementServiceServer
counter *int
}
func (m *mockServer) Login(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
*m.counter++
return m.ManagementServiceServer.Login(ctx, req)
}
func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Server, string, error) {
t.Helper()
dataDir := t.TempDir()
config := &config.Config{
Stuns: []*config.Host{},
TURNConfig: &config.TURNConfig{},
Signal: &config.Host{
Proto: "http",
URI: signalAddr,
},
Datadir: dataDir,
HttpConfig: nil,
}
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
return nil, "", err
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", config.Datadir)
if err != nil {
return nil, "", err
}
t.Cleanup(cleanUp)
eventStore := &activity.InMemoryEventStore{}
if err != nil {
return nil, "", err
}
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
permissionsManagerMock := permissions.NewMockManager(ctrl)
peersManager := peers.NewManager(store, permissionsManagerMock)
settingsManagerMock := settings.NewMockManager(ctrl)
jobManager := job.NewJobManager(nil, store, peersManager)
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
if err != nil {
return nil, "", err
}
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
settingsMockManager := settings.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore)
if err != nil {
return nil, "", err
}
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
if err != nil {
return nil, "", err
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
if err != nil {
return nil, "", err
}
mock := &mockServer{
ManagementServiceServer: mgmtServer,
counter: counter,
}
mgmtProto.RegisterManagementServiceServer(s, mock)
go func() {
if err = s.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
}
}()
return s, lis.Addr().String(), nil
}
func startSignal(t *testing.T) (*grpc.Server, string, error) {
t.Helper()
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
return nil, "", err
}
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
require.NoError(t, err)
proto.RegisterSignalExchangeServer(s, srv)
go func() {
if err = s.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
}
}()
return s, lis.Addr().String(), nil
}

View File

@@ -2,22 +2,124 @@ package server
import (
"context"
"net"
"net/url"
"os/user"
"path/filepath"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
daemonProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server"
)
var (
kaep = keepalive.EnforcementPolicy{
MinTime: 15 * time.Second,
PermitWithoutStream: true,
}
kasp = keepalive.ServerParameters{
MaxConnectionIdle: 15 * time.Second,
MaxConnectionAgeGrace: 5 * time.Second,
Time: 5 * time.Second,
Timeout: 2 * time.Second,
}
)
// TestConnectWithRetryRuns checks that the connectWithRetry function runs and runs the retries according to the times specified via environment variables
// we will use a management server started via to simulate the server and capture the number of retries
func TestConnectWithRetryRuns(t *testing.T) {
// start the signal server
_, signalAddr, err := startSignal(t)
if err != nil {
t.Fatalf("failed to start signal server: %v", err)
}
counter := 0
// start the management server
_, mgmtAddr, err := startManagement(t, signalAddr, &counter)
if err != nil {
t.Fatalf("failed to start management server: %v", err)
}
ctx := internal.CtxInitState(context.Background())
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(30*time.Second))
defer cancel()
// create new server
ic := profilemanager.ConfigInput{
ManagementURL: "http://" + mgmtAddr,
ConfigPath: t.TempDir() + "/test-profile.json",
}
config, err := profilemanager.UpdateOrCreateConfig(ic)
if err != nil {
t.Fatalf("failed to create config: %v", err)
}
currUser, err := user.Current()
require.NoError(t, err)
pm := profilemanager.ServiceManager{}
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
ID: "test-profile",
Username: currUser.Username,
})
if err != nil {
t.Fatalf("failed to set active profile state: %v", err)
}
s := New(ctx, "debug", "", false, false, false, false)
s.config = config
s.statusRecorder = peer.NewRecorder(config.ManagementURL.String())
t.Setenv(retryInitialIntervalVar, "1s")
t.Setenv(maxRetryIntervalVar, "2s")
t.Setenv(maxRetryTimeVar, "5s")
t.Setenv(retryMultiplierVar, "1")
s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil)
if counter < 3 {
t.Fatalf("expected counter > 2, got %d", counter)
}
}
func TestServer_Up(t *testing.T) {
tempDir := t.TempDir()
origDefaultProfileDir := profilemanager.DefaultConfigPathDir
@@ -157,3 +259,119 @@ func TestServer_SubcribeEvents(t *testing.T) {
assert.NoError(t, err)
}
type mockServer struct {
mgmtProto.ManagementServiceServer
counter *int
}
func (m *mockServer) Login(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
*m.counter++
return m.ManagementServiceServer.Login(ctx, req)
}
func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Server, string, error) {
t.Helper()
dataDir := t.TempDir()
config := &config.Config{
Stuns: []*config.Host{},
TURNConfig: &config.TURNConfig{},
Signal: &config.Host{
Proto: "http",
URI: signalAddr,
},
Datadir: dataDir,
HttpConfig: nil,
}
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
return nil, "", err
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", config.Datadir)
if err != nil {
return nil, "", err
}
t.Cleanup(cleanUp)
eventStore := &activity.InMemoryEventStore{}
if err != nil {
return nil, "", err
}
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
permissionsManagerMock := permissions.NewMockManager(ctrl)
peersManager := peers.NewManager(store, permissionsManagerMock)
settingsManagerMock := settings.NewMockManager(ctrl)
jobManager := job.NewJobManager(nil, store, peersManager)
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
if err != nil {
return nil, "", err
}
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
settingsMockManager := settings.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore)
if err != nil {
return nil, "", err
}
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
if err != nil {
return nil, "", err
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
if err != nil {
return nil, "", err
}
mock := &mockServer{
ManagementServiceServer: mgmtServer,
counter: counter,
}
mgmtProto.RegisterManagementServiceServer(s, mock)
go func() {
if err = s.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
}
}()
return s, lis.Addr().String(), nil
}
func startSignal(t *testing.T) (*grpc.Server, string, error) {
t.Helper()
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
log.Fatalf("failed to listen: %v", err)
}
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
require.NoError(t, err)
proto.RegisterSignalExchangeServer(s, srv)
go func() {
if err = s.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
}
}()
return s, lis.Addr().String(), nil
}

View File

@@ -58,8 +58,6 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
rosenpassEnabled := true
rosenpassPermissive := true
serverSSHAllowed := true
serverVNCAllowed := true
disableVNCApproval := true
interfaceName := "utun100"
wireguardPort := int64(51820)
preSharedKey := "test-psk"
@@ -85,8 +83,6 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
RosenpassEnabled: &rosenpassEnabled,
RosenpassPermissive: &rosenpassPermissive,
ServerSSHAllowed: &serverSSHAllowed,
ServerVNCAllowed: &serverVNCAllowed,
DisableVNCApproval: &disableVNCApproval,
InterfaceName: &interfaceName,
WireguardPort: &wireguardPort,
OptionalPreSharedKey: &preSharedKey,
@@ -131,10 +127,6 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
require.Equal(t, rosenpassPermissive, cfg.RosenpassPermissive)
require.NotNil(t, cfg.ServerSSHAllowed)
require.Equal(t, serverSSHAllowed, *cfg.ServerSSHAllowed)
require.NotNil(t, cfg.ServerVNCAllowed)
require.Equal(t, serverVNCAllowed, *cfg.ServerVNCAllowed)
require.NotNil(t, cfg.DisableVNCApproval)
require.Equal(t, disableVNCApproval, *cfg.DisableVNCApproval)
require.Equal(t, interfaceName, cfg.WgIface)
require.Equal(t, int(wireguardPort), cfg.WgPort)
require.Equal(t, preSharedKey, cfg.PreSharedKey)
@@ -187,8 +179,6 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
"RosenpassEnabled": true,
"RosenpassPermissive": true,
"ServerSSHAllowed": true,
"ServerVNCAllowed": true,
"DisableVNCApproval": true,
"InterfaceName": true,
"WireguardPort": true,
"OptionalPreSharedKey": true,
@@ -250,8 +240,6 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) {
"enable-rosenpass": "RosenpassEnabled",
"rosenpass-permissive": "RosenpassPermissive",
"allow-server-ssh": "ServerSSHAllowed",
"allow-server-vnc": "ServerVNCAllowed",
"disable-vnc-approval": "DisableVNCApproval",
"interface-name": "InterfaceName",
"wireguard-port": "WireguardPort",
"preshared-key": "OptionalPreSharedKey",

View File

@@ -1,4 +1,4 @@
package sessionauth
package auth
import (
"errors"
@@ -15,8 +15,6 @@ const (
DefaultUserIDClaim = "sub"
// Wildcard is a special user ID that matches all users
Wildcard = "*"
// sessionPubKeyLen is the size of an X25519 static public key in bytes.
sessionPubKeyLen = 32
)
var (
@@ -24,7 +22,6 @@ var (
ErrUserNotAuthorized = errors.New("user is not authorized to access this peer")
ErrNoMachineUserMapping = errors.New("no authorization mapping for OS user")
ErrUserNotMappedToOSUser = errors.New("user is not authorized to login as OS user")
ErrSessionKeyNotKnown = errors.New("session pubkey not registered")
)
// Authorizer handles SSH fine-grained access control authorization
@@ -38,17 +35,6 @@ type Authorizer struct {
// machineUsers maps OS login usernames to lists of authorized user indexes
machineUsers map[string][]uint32
// sessionPubKeys maps an X25519 static public key (as map-safe
// array) to the hashed user identity that key authenticates as.
// Populated from management's temporary-access flow; used by VNC to
// authenticate via the Noise_IK handshake.
sessionPubKeys map[[sessionPubKeyLen]byte]sshuserhash.UserIDHash
// sessionDisplayNames mirrors sessionPubKeys with the optional
// human-readable display name management associated with each
// session key. Used by the per-connection UI approval prompt; not
// consulted by any authorization decision.
sessionDisplayNames map[[sessionPubKeyLen]byte]string
// mu protects the list of users
mu sync.RWMutex
}
@@ -64,29 +50,13 @@ type Config struct {
// MachineUsers maps OS login usernames to indexes in AuthorizedUsers
// If a user wants to login as a specific OS user, their index must be in the corresponding list
MachineUsers map[string][]uint32
// SessionPubKeys binds ephemeral X25519 static public keys to hashed
// user identities. Populated for VNC; ignored on the SSH side.
SessionPubKeys []SessionPubKey
}
// SessionPubKey is a single ephemeral-key entry: the 32-byte X25519
// static public key plus the hashed user identity it authenticates as,
// optionally plus a human-readable display name for the UI approval
// prompt to identify the requester.
type SessionPubKey struct {
PubKey []byte
UserIDHash sshuserhash.UserIDHash
DisplayName string
}
// NewAuthorizer creates a new SSH authorizer with empty configuration
func NewAuthorizer() *Authorizer {
a := &Authorizer{
userIDClaim: DefaultUserIDClaim,
machineUsers: make(map[string][]uint32),
sessionPubKeys: make(map[[sessionPubKeyLen]byte]sshuserhash.UserIDHash),
sessionDisplayNames: make(map[[sessionPubKeyLen]byte]string),
userIDClaim: DefaultUserIDClaim,
machineUsers: make(map[string][]uint32),
}
return a
@@ -102,8 +72,6 @@ func (a *Authorizer) Update(config *Config) {
a.userIDClaim = DefaultUserIDClaim
a.authorizedUsers = []sshuserhash.UserIDHash{}
a.machineUsers = make(map[string][]uint32)
a.sessionPubKeys = make(map[[sessionPubKeyLen]byte]sshuserhash.UserIDHash)
a.sessionDisplayNames = make(map[[sessionPubKeyLen]byte]string)
log.Info("SSH authorization cleared")
return
}
@@ -126,35 +94,8 @@ func (a *Authorizer) Update(config *Config) {
}
a.machineUsers = machineUsers
sessionPubKeys := make(map[[sessionPubKeyLen]byte]sshuserhash.UserIDHash, len(config.SessionPubKeys))
sessionDisplayNames := make(map[[sessionPubKeyLen]byte]string, len(config.SessionPubKeys))
conflicted := make(map[[sessionPubKeyLen]byte]struct{})
for _, e := range config.SessionPubKeys {
if len(e.PubKey) != sessionPubKeyLen {
continue
}
var key [sessionPubKeyLen]byte
copy(key[:], e.PubKey)
if _, bad := conflicted[key]; bad {
continue
}
if existing, ok := sessionPubKeys[key]; ok && existing != e.UserIDHash {
log.Warnf("SSH auth: session pubkey bound to conflicting user hashes; dropping binding")
delete(sessionPubKeys, key)
delete(sessionDisplayNames, key)
conflicted[key] = struct{}{}
continue
}
sessionPubKeys[key] = e.UserIDHash
if e.DisplayName != "" {
sessionDisplayNames[key] = e.DisplayName
}
}
a.sessionPubKeys = sessionPubKeys
a.sessionDisplayNames = sessionDisplayNames
log.Debugf("SSH auth: updated with %d authorized users, %d machine user mappings, %d session pubkeys",
len(config.AuthorizedUsers), len(machineUsers), len(sessionPubKeys))
log.Debugf("SSH auth: updated with %d authorized users, %d machine user mappings",
len(config.AuthorizedUsers), len(machineUsers))
}
// Authorize validates if a user is authorized to login as the specified OS user.
@@ -214,54 +155,6 @@ func (a *Authorizer) GetUserIDClaim() string {
return a.userIDClaim
}
// LookupSessionKey resolves a Noise-verified static public key to the
// hashed user identity registered with it. Fails closed when the key is
// unknown.
func (a *Authorizer) LookupSessionKey(pubKey []byte) (sshuserhash.UserIDHash, error) {
var zero sshuserhash.UserIDHash
if len(pubKey) != sessionPubKeyLen {
return zero, fmt.Errorf("session pubkey wrong length: %d", len(pubKey))
}
var key [sessionPubKeyLen]byte
copy(key[:], pubKey)
a.mu.RLock()
hash, ok := a.sessionPubKeys[key]
a.mu.RUnlock()
if !ok {
return zero, ErrSessionKeyNotKnown
}
return hash, nil
}
// LookupSessionDisplayName returns the human-readable display name
// management associated with a session pubkey, or empty string when none
// is recorded. Never returns an error: a missing/unknown key reports as
// "" and the caller falls back to other identifiers.
func (a *Authorizer) LookupSessionDisplayName(pubKey []byte) string {
if len(pubKey) != sessionPubKeyLen {
return ""
}
var key [sessionPubKeyLen]byte
copy(key[:], pubKey)
a.mu.RLock()
name := a.sessionDisplayNames[key]
a.mu.RUnlock()
return name
}
// AuthorizeOSUserBySessionKey resolves the OS-user mapping for a session
// key. Mirrors Authorize but skips the JWT-hash step since the key has
// already been verified and the user identity hash is in hand.
func (a *Authorizer) AuthorizeOSUserBySessionKey(userIDHash sshuserhash.UserIDHash, osUsername string) (string, error) {
a.mu.RLock()
defer a.mu.RUnlock()
userIndex, found := a.findUserIndex(userIDHash)
if !found {
return "", fmt.Errorf("session user (hash: %s) not in authorized list for OS user %q: %w", userIDHash, osUsername, ErrUserNotAuthorized)
}
return a.checkMachineUserMapping("session", osUsername, userIndex)
}
// findUserIndex finds the index of a hashed user ID in the authorized users list
// Returns the index and true if found, 0 and false if not found
func (a *Authorizer) findUserIndex(hashedUserID sshuserhash.UserIDHash) (int, bool) {

View File

@@ -1,7 +1,6 @@
package sessionauth
package auth
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
@@ -611,61 +610,3 @@ func TestAuthorizer_Wildcard_WithPartialIndexes_AllowsAllUsers(t *testing.T) {
assert.Error(t, err)
assert.ErrorIs(t, err, ErrUserNotAuthorized, "unauthorized user should be denied")
}
func TestAuthorizer_LookupSessionKey_Valid(t *testing.T) {
pub := bytesRepeat(0x11, sessionPubKeyLen)
userHash, err := sshauth.HashUserID("alice")
require.NoError(t, err)
a := NewAuthorizer()
a.Update(&Config{
AuthorizedUsers: []sshauth.UserIDHash{userHash},
MachineUsers: map[string][]uint32{Wildcard: {0}},
SessionPubKeys: []SessionPubKey{{PubKey: pub, UserIDHash: userHash}},
})
got, err := a.LookupSessionKey(pub)
require.NoError(t, err)
assert.Equal(t, userHash, got)
if _, err := a.AuthorizeOSUserBySessionKey(got, "alice"); err != nil {
t.Fatalf("AuthorizeOSUserBySessionKey: %v", err)
}
}
func TestAuthorizer_LookupSessionKey_UnknownPub(t *testing.T) {
a := NewAuthorizer()
a.Update(&Config{})
_, err := a.LookupSessionKey(bytesRepeat(0x22, sessionPubKeyLen))
require.ErrorIs(t, err, ErrSessionKeyNotKnown)
}
func TestAuthorizer_LookupSessionKey_WrongLength(t *testing.T) {
a := NewAuthorizer()
_, err := a.LookupSessionKey([]byte("short"))
require.Error(t, err)
}
func TestAuthorizer_LookupSessionKey_UpdateClears(t *testing.T) {
pub := bytesRepeat(0x33, sessionPubKeyLen)
userHash, err := sshauth.HashUserID("alice")
require.NoError(t, err)
a := NewAuthorizer()
a.Update(&Config{SessionPubKeys: []SessionPubKey{{PubKey: pub, UserIDHash: userHash}}})
if _, err := a.LookupSessionKey(pub); err != nil {
t.Fatalf("setup lookup: %v", err)
}
a.Update(&Config{})
if _, err := a.LookupSessionKey(pub); !errors.Is(err, ErrSessionKeyNotKnown) {
t.Fatalf("expected ErrSessionKeyNotKnown, got %v", err)
}
}
func bytesRepeat(b byte, n int) []byte {
out := make([]byte, n)
for i := range out {
out[i] = b
}
return out
}

View File

@@ -1,118 +0,0 @@
//go:build privileged
package client
import (
"context"
"errors"
"runtime"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
cryptossh "golang.org/x/crypto/ssh"
"github.com/netbirdio/netbird/client/ssh/testutil"
)
func TestSSHClient_CommandExecution(t *testing.T) {
if runtime.GOOS == "windows" && testutil.IsCI() {
t.Skip("Skipping Windows command execution tests in CI due to S4U authentication issues")
}
server, _, client := setupTestSSHServerAndClient(t)
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
defer func() {
err := client.Close()
assert.NoError(t, err)
}()
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
t.Run("ExecuteCommand captures output", func(t *testing.T) {
output, err := client.ExecuteCommand(ctx, "echo hello")
assert.NoError(t, err)
assert.Contains(t, string(output), "hello")
})
t.Run("ExecuteCommandWithIO streams output", func(t *testing.T) {
err := client.ExecuteCommandWithIO(ctx, "echo world")
assert.NoError(t, err)
})
t.Run("commands with flags work", func(t *testing.T) {
output, err := client.ExecuteCommand(ctx, "echo -n test_flag")
assert.NoError(t, err)
assert.Equal(t, "test_flag", strings.TrimSpace(string(output)))
})
t.Run("non-zero exit codes don't return errors", func(t *testing.T) {
var testCmd string
if runtime.GOOS == "windows" {
testCmd = "echo hello | Select-String notfound"
} else {
testCmd = "echo 'hello' | grep 'notfound'"
}
_, err := client.ExecuteCommand(ctx, testCmd)
assert.NoError(t, err)
})
}
func TestSSHClient_ContextCancellation(t *testing.T) {
server, serverAddr, _ := setupTestSSHServerAndClient(t)
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
t.Run("connection with short timeout", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
defer cancel()
currentUser := testutil.GetTestUsername(t)
_, err := Dial(ctx, serverAddr, currentUser, DialOptions{
InsecureSkipVerify: true,
})
if err != nil {
// Check for actual timeout-related errors rather than string matching
assert.True(t,
errors.Is(err, context.DeadlineExceeded) ||
errors.Is(err, context.Canceled) ||
strings.Contains(err.Error(), "timeout"),
"Expected timeout-related error, got: %v", err)
}
})
t.Run("command execution cancellation", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
currentUser := testutil.GetTestUsername(t)
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
InsecureSkipVerify: true,
})
require.NoError(t, err)
defer func() {
if err := client.Close(); err != nil {
t.Logf("client close error: %v", err)
}
}()
cmdCtx, cmdCancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cmdCancel()
err = client.ExecuteCommandWithPTY(cmdCtx, "sleep 10")
if err != nil {
var exitMissingErr *cryptossh.ExitMissingError
isValidCancellation := errors.Is(err, context.DeadlineExceeded) ||
errors.Is(err, context.Canceled) ||
errors.As(err, &exitMissingErr)
assert.True(t, isValidCancellation, "Should handle command cancellation properly")
}
})
}

View File

@@ -15,6 +15,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
cryptossh "golang.org/x/crypto/ssh"
"github.com/netbirdio/netbird/client/ssh"
sshserver "github.com/netbirdio/netbird/client/ssh/server"
@@ -77,6 +78,53 @@ func TestSSHClient_DialWithKey(t *testing.T) {
assert.NotNil(t, client.client)
}
func TestSSHClient_CommandExecution(t *testing.T) {
if runtime.GOOS == "windows" && testutil.IsCI() {
t.Skip("Skipping Windows command execution tests in CI due to S4U authentication issues")
}
server, _, client := setupTestSSHServerAndClient(t)
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
defer func() {
err := client.Close()
assert.NoError(t, err)
}()
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
t.Run("ExecuteCommand captures output", func(t *testing.T) {
output, err := client.ExecuteCommand(ctx, "echo hello")
assert.NoError(t, err)
assert.Contains(t, string(output), "hello")
})
t.Run("ExecuteCommandWithIO streams output", func(t *testing.T) {
err := client.ExecuteCommandWithIO(ctx, "echo world")
assert.NoError(t, err)
})
t.Run("commands with flags work", func(t *testing.T) {
output, err := client.ExecuteCommand(ctx, "echo -n test_flag")
assert.NoError(t, err)
assert.Equal(t, "test_flag", strings.TrimSpace(string(output)))
})
t.Run("non-zero exit codes don't return errors", func(t *testing.T) {
var testCmd string
if runtime.GOOS == "windows" {
testCmd = "echo hello | Select-String notfound"
} else {
testCmd = "echo 'hello' | grep 'notfound'"
}
_, err := client.ExecuteCommand(ctx, testCmd)
assert.NoError(t, err)
})
}
func TestSSHClient_ConnectionHandling(t *testing.T) {
server, serverAddr, _ := setupTestSSHServerAndClient(t)
defer func() {
@@ -106,6 +154,59 @@ func TestSSHClient_ConnectionHandling(t *testing.T) {
}
}
func TestSSHClient_ContextCancellation(t *testing.T) {
server, serverAddr, _ := setupTestSSHServerAndClient(t)
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
t.Run("connection with short timeout", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
defer cancel()
currentUser := testutil.GetTestUsername(t)
_, err := Dial(ctx, serverAddr, currentUser, DialOptions{
InsecureSkipVerify: true,
})
if err != nil {
// Check for actual timeout-related errors rather than string matching
assert.True(t,
errors.Is(err, context.DeadlineExceeded) ||
errors.Is(err, context.Canceled) ||
strings.Contains(err.Error(), "timeout"),
"Expected timeout-related error, got: %v", err)
}
})
t.Run("command execution cancellation", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
currentUser := testutil.GetTestUsername(t)
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
InsecureSkipVerify: true,
})
require.NoError(t, err)
defer func() {
if err := client.Close(); err != nil {
t.Logf("client close error: %v", err)
}
}()
cmdCtx, cmdCancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cmdCancel()
err = client.ExecuteCommandWithPTY(cmdCtx, "sleep 10")
if err != nil {
var exitMissingErr *cryptossh.ExitMissingError
isValidCancellation := errors.Is(err, context.DeadlineExceeded) ||
errors.Is(err, context.Canceled) ||
errors.As(err, &exitMissingErr)
assert.True(t, isValidCancellation, "Should handle command cancellation properly")
}
})
}
func TestSSHClient_NoAuthMode(t *testing.T) {
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)

View File

@@ -1,423 +0,0 @@
//go:build privileged
package proxy
import (
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"io"
"math/big"
"net"
"net/http"
"net/http/httptest"
"os"
"runtime"
"strconv"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
cryptossh "golang.org/x/crypto/ssh"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/server"
"github.com/netbirdio/netbird/client/ssh/testutil"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
sshauth "github.com/netbirdio/netbird/shared/sessionauth"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)
func (m *mockDaemon) setJWTToken(token string) {
m.impl.jwtToken = token
}
func TestSSHProxy_Connect(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
// TODO: Windows test times out - user switching and command execution tested on Linux
if runtime.GOOS == "windows" {
t.Skip("Skipping on Windows - covered by Linux tests")
}
const (
issuer = "https://test-issuer.example.com"
audience = "test-audience"
)
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
defer jwksServer.Close()
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
require.NoError(t, err)
serverConfig := &server.Config{
HostKeyPEM: hostKey,
JWT: &server.JWTConfig{
Issuer: issuer,
Audiences: []string{audience},
KeysLocation: jwksURL,
},
}
sshServer := server.New(serverConfig)
sshServer.SetAllowRootLogin(true)
// Configure SSH authorization for the test user
testUsername := testutil.GetTestUsername(t)
testJWTUser := "test-username"
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
require.NoError(t, err)
authConfig := &sshauth.Config{
UserIDClaim: sshauth.DefaultUserIDClaim,
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
MachineUsers: map[string][]uint32{
testUsername: {0}, // Index 0 in AuthorizedUsers
},
}
sshServer.UpdateSSHAuth(authConfig)
sshServerAddr := server.StartTestServer(t, sshServer)
defer func() { _ = sshServer.Stop() }()
mockDaemon := startMockDaemon(t)
defer mockDaemon.stop()
host, portStr, err := net.SplitHostPort(sshServerAddr)
require.NoError(t, err)
port, err := strconv.Atoi(portStr)
require.NoError(t, err)
mockDaemon.setHostKey(host, hostPubKey)
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
mockDaemon.setJWTToken(validToken)
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
require.NoError(t, err)
clientConn, proxyConn := net.Pipe()
defer func() { _ = clientConn.Close() }()
origStdin := os.Stdin
origStdout := os.Stdout
defer func() {
os.Stdin = origStdin
os.Stdout = origStdout
}()
stdinReader, stdinWriter, err := os.Pipe()
require.NoError(t, err)
stdoutReader, stdoutWriter, err := os.Pipe()
require.NoError(t, err)
os.Stdin = stdinReader
os.Stdout = stdoutWriter
go func() {
_, _ = io.Copy(stdinWriter, proxyConn)
}()
go func() {
_, _ = io.Copy(proxyConn, stdoutReader)
}()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
connectErrCh := make(chan error, 1)
go func() {
connectErrCh <- proxyInstance.Connect(ctx)
}()
sshConfig := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: []cryptossh.AuthMethod{},
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 3 * time.Second,
}
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
require.NoError(t, err, "Should connect to proxy server")
defer func() { _ = sshClientConn.Close() }()
sshClient := cryptossh.NewClient(sshClientConn, chans, reqs)
session, err := sshClient.NewSession()
require.NoError(t, err, "Should create session through full proxy to backend")
outputCh := make(chan []byte, 1)
errCh := make(chan error, 1)
go func() {
output, err := session.Output("echo hello-from-proxy")
outputCh <- output
errCh <- err
}()
select {
case output := <-outputCh:
err := <-errCh
require.NoError(t, err, "Command should execute successfully through proxy")
assert.Contains(t, string(output), "hello-from-proxy", "Should receive command output through proxy")
case <-time.After(3 * time.Second):
t.Fatal("Command execution timed out")
}
_ = session.Close()
_ = sshClient.Close()
_ = clientConn.Close()
cancel()
}
// TestSSHProxy_CommandQuoting verifies that the proxy preserves shell quoting
// when forwarding commands to the backend. This is critical for tools like
// Ansible that send commands such as:
//
// /bin/sh -c '( umask 77 && mkdir -p ... ) && sleep 0'
//
// The single quotes must be preserved so the backend shell receives the
// subshell expression as a single argument to -c.
func TestSSHProxy_CommandQuoting(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
sshClient, cleanup := setupProxySSHClient(t)
defer cleanup()
// These commands simulate what the SSH protocol delivers as exec payloads.
// When a user types: ssh host '/bin/sh -c "( echo hello )"'
// the local shell strips the outer single quotes, and the SSH exec request
// contains the raw string: /bin/sh -c "( echo hello )"
//
// The proxy must forward this string verbatim. Using session.Command()
// (shlex.Split + strings.Join) strips the inner double quotes, breaking
// the command on the backend.
tests := []struct {
name string
command string
expect string
}{
{
name: "subshell_in_double_quotes",
command: `/bin/sh -c "( echo from-subshell ) && echo outer"`,
expect: "from-subshell\nouter\n",
},
{
name: "printf_with_special_chars",
command: `/bin/sh -c "printf '%s\n' 'hello world'"`,
expect: "hello world\n",
},
{
name: "nested_command_substitution",
command: `/bin/sh -c "echo $(echo nested)"`,
expect: "nested\n",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
session, err := sshClient.NewSession()
require.NoError(t, err)
defer func() { _ = session.Close() }()
var stderrBuf bytes.Buffer
session.Stderr = &stderrBuf
outputCh := make(chan []byte, 1)
errCh := make(chan error, 1)
go func() {
output, err := session.Output(tc.command)
outputCh <- output
errCh <- err
}()
select {
case output := <-outputCh:
err := <-errCh
if stderrBuf.Len() > 0 {
t.Logf("stderr: %s", stderrBuf.String())
}
require.NoError(t, err, "command should succeed: %s", tc.command)
assert.Equal(t, tc.expect, string(output), "output mismatch for: %s", tc.command)
case <-time.After(5 * time.Second):
t.Fatalf("command timed out: %s", tc.command)
}
})
}
}
// setupProxySSHClient creates a full proxy test environment and returns
// an SSH client connected through the proxy to a backend NetBird SSH server.
func setupProxySSHClient(t *testing.T) (*cryptossh.Client, func()) {
t.Helper()
const (
issuer = "https://test-issuer.example.com"
audience = "test-audience"
)
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
require.NoError(t, err)
serverConfig := &server.Config{
HostKeyPEM: hostKey,
JWT: &server.JWTConfig{
Issuer: issuer,
Audiences: []string{audience},
KeysLocation: jwksURL,
},
}
sshServer := server.New(serverConfig)
sshServer.SetAllowRootLogin(true)
testUsername := testutil.GetTestUsername(t)
testJWTUser := "test-username"
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
require.NoError(t, err)
authConfig := &sshauth.Config{
UserIDClaim: sshauth.DefaultUserIDClaim,
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
MachineUsers: map[string][]uint32{
testUsername: {0},
},
}
sshServer.UpdateSSHAuth(authConfig)
sshServerAddr := server.StartTestServer(t, sshServer)
mockDaemon := startMockDaemon(t)
host, portStr, err := net.SplitHostPort(sshServerAddr)
require.NoError(t, err)
port, err := strconv.Atoi(portStr)
require.NoError(t, err)
mockDaemon.setHostKey(host, hostPubKey)
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
mockDaemon.setJWTToken(validToken)
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
require.NoError(t, err)
origStdin := os.Stdin
origStdout := os.Stdout
stdinReader, stdinWriter, err := os.Pipe()
require.NoError(t, err)
stdoutReader, stdoutWriter, err := os.Pipe()
require.NoError(t, err)
os.Stdin = stdinReader
os.Stdout = stdoutWriter
clientConn, proxyConn := net.Pipe()
go func() { _, _ = io.Copy(stdinWriter, proxyConn) }()
go func() { _, _ = io.Copy(proxyConn, stdoutReader) }()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
go func() {
_ = proxyInstance.Connect(ctx)
}()
sshConfig := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: []cryptossh.AuthMethod{},
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 5 * time.Second,
}
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
require.NoError(t, err)
client := cryptossh.NewClient(sshClientConn, chans, reqs)
cleanupFn := func() {
_ = client.Close()
_ = clientConn.Close()
cancel()
os.Stdin = origStdin
os.Stdout = origStdout
_ = sshServer.Stop()
mockDaemon.stop()
jwksServer.Close()
}
return client, cleanupFn
}
func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) {
t.Helper()
privateKey, jwksJSON := generateTestJWKS(t)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if _, err := w.Write(jwksJSON); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}))
return server, privateKey, server.URL
}
func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
t.Helper()
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
publicKey := &privateKey.PublicKey
n := publicKey.N.Bytes()
e := publicKey.E
jwk := nbjwt.JSONWebKey{
Kty: "RSA",
Kid: "test-key-id",
Use: "sig",
N: base64.RawURLEncoding.EncodeToString(n),
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(e)).Bytes()),
}
jwks := nbjwt.Jwks{
Keys: []nbjwt.JSONWebKey{jwk},
}
jwksJSON, err := json.Marshal(jwks)
require.NoError(t, err)
return privateKey, jwksJSON
}
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string, user string) string {
t.Helper()
claims := jwt.MapClaims{
"iss": issuer,
"aud": audience,
"sub": user,
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
token.Header["kid"] = "test-key-id"
tokenString, err := token.SignedString(privateKey)
require.NoError(t, err)
return tokenString
}

View File

@@ -1,12 +1,25 @@
package proxy
import (
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"math/big"
"net"
"net/http"
"net/http/httptest"
"os"
"runtime"
"strconv"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
cryptossh "golang.org/x/crypto/ssh"
@@ -15,7 +28,11 @@ import (
"github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
"github.com/netbirdio/netbird/client/ssh/server"
"github.com/netbirdio/netbird/client/ssh/testutil"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)
func TestMain(m *testing.M) {
@@ -89,6 +106,331 @@ func TestSSHProxy_verifyHostKey(t *testing.T) {
})
}
func TestSSHProxy_Connect(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
// TODO: Windows test times out - user switching and command execution tested on Linux
if runtime.GOOS == "windows" {
t.Skip("Skipping on Windows - covered by Linux tests")
}
const (
issuer = "https://test-issuer.example.com"
audience = "test-audience"
)
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
defer jwksServer.Close()
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
require.NoError(t, err)
serverConfig := &server.Config{
HostKeyPEM: hostKey,
JWT: &server.JWTConfig{
Issuer: issuer,
Audiences: []string{audience},
KeysLocation: jwksURL,
},
}
sshServer := server.New(serverConfig)
sshServer.SetAllowRootLogin(true)
// Configure SSH authorization for the test user
testUsername := testutil.GetTestUsername(t)
testJWTUser := "test-username"
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
require.NoError(t, err)
authConfig := &sshauth.Config{
UserIDClaim: sshauth.DefaultUserIDClaim,
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
MachineUsers: map[string][]uint32{
testUsername: {0}, // Index 0 in AuthorizedUsers
},
}
sshServer.UpdateSSHAuth(authConfig)
sshServerAddr := server.StartTestServer(t, sshServer)
defer func() { _ = sshServer.Stop() }()
mockDaemon := startMockDaemon(t)
defer mockDaemon.stop()
host, portStr, err := net.SplitHostPort(sshServerAddr)
require.NoError(t, err)
port, err := strconv.Atoi(portStr)
require.NoError(t, err)
mockDaemon.setHostKey(host, hostPubKey)
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
mockDaemon.setJWTToken(validToken)
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
require.NoError(t, err)
clientConn, proxyConn := net.Pipe()
defer func() { _ = clientConn.Close() }()
origStdin := os.Stdin
origStdout := os.Stdout
defer func() {
os.Stdin = origStdin
os.Stdout = origStdout
}()
stdinReader, stdinWriter, err := os.Pipe()
require.NoError(t, err)
stdoutReader, stdoutWriter, err := os.Pipe()
require.NoError(t, err)
os.Stdin = stdinReader
os.Stdout = stdoutWriter
go func() {
_, _ = io.Copy(stdinWriter, proxyConn)
}()
go func() {
_, _ = io.Copy(proxyConn, stdoutReader)
}()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
connectErrCh := make(chan error, 1)
go func() {
connectErrCh <- proxyInstance.Connect(ctx)
}()
sshConfig := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: []cryptossh.AuthMethod{},
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 3 * time.Second,
}
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
require.NoError(t, err, "Should connect to proxy server")
defer func() { _ = sshClientConn.Close() }()
sshClient := cryptossh.NewClient(sshClientConn, chans, reqs)
session, err := sshClient.NewSession()
require.NoError(t, err, "Should create session through full proxy to backend")
outputCh := make(chan []byte, 1)
errCh := make(chan error, 1)
go func() {
output, err := session.Output("echo hello-from-proxy")
outputCh <- output
errCh <- err
}()
select {
case output := <-outputCh:
err := <-errCh
require.NoError(t, err, "Command should execute successfully through proxy")
assert.Contains(t, string(output), "hello-from-proxy", "Should receive command output through proxy")
case <-time.After(3 * time.Second):
t.Fatal("Command execution timed out")
}
_ = session.Close()
_ = sshClient.Close()
_ = clientConn.Close()
cancel()
}
// TestSSHProxy_CommandQuoting verifies that the proxy preserves shell quoting
// when forwarding commands to the backend. This is critical for tools like
// Ansible that send commands such as:
//
// /bin/sh -c '( umask 77 && mkdir -p ... ) && sleep 0'
//
// The single quotes must be preserved so the backend shell receives the
// subshell expression as a single argument to -c.
func TestSSHProxy_CommandQuoting(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
sshClient, cleanup := setupProxySSHClient(t)
defer cleanup()
// These commands simulate what the SSH protocol delivers as exec payloads.
// When a user types: ssh host '/bin/sh -c "( echo hello )"'
// the local shell strips the outer single quotes, and the SSH exec request
// contains the raw string: /bin/sh -c "( echo hello )"
//
// The proxy must forward this string verbatim. Using session.Command()
// (shlex.Split + strings.Join) strips the inner double quotes, breaking
// the command on the backend.
tests := []struct {
name string
command string
expect string
}{
{
name: "subshell_in_double_quotes",
command: `/bin/sh -c "( echo from-subshell ) && echo outer"`,
expect: "from-subshell\nouter\n",
},
{
name: "printf_with_special_chars",
command: `/bin/sh -c "printf '%s\n' 'hello world'"`,
expect: "hello world\n",
},
{
name: "nested_command_substitution",
command: `/bin/sh -c "echo $(echo nested)"`,
expect: "nested\n",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
session, err := sshClient.NewSession()
require.NoError(t, err)
defer func() { _ = session.Close() }()
var stderrBuf bytes.Buffer
session.Stderr = &stderrBuf
outputCh := make(chan []byte, 1)
errCh := make(chan error, 1)
go func() {
output, err := session.Output(tc.command)
outputCh <- output
errCh <- err
}()
select {
case output := <-outputCh:
err := <-errCh
if stderrBuf.Len() > 0 {
t.Logf("stderr: %s", stderrBuf.String())
}
require.NoError(t, err, "command should succeed: %s", tc.command)
assert.Equal(t, tc.expect, string(output), "output mismatch for: %s", tc.command)
case <-time.After(5 * time.Second):
t.Fatalf("command timed out: %s", tc.command)
}
})
}
}
// setupProxySSHClient creates a full proxy test environment and returns
// an SSH client connected through the proxy to a backend NetBird SSH server.
func setupProxySSHClient(t *testing.T) (*cryptossh.Client, func()) {
t.Helper()
const (
issuer = "https://test-issuer.example.com"
audience = "test-audience"
)
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
require.NoError(t, err)
serverConfig := &server.Config{
HostKeyPEM: hostKey,
JWT: &server.JWTConfig{
Issuer: issuer,
Audiences: []string{audience},
KeysLocation: jwksURL,
},
}
sshServer := server.New(serverConfig)
sshServer.SetAllowRootLogin(true)
testUsername := testutil.GetTestUsername(t)
testJWTUser := "test-username"
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
require.NoError(t, err)
authConfig := &sshauth.Config{
UserIDClaim: sshauth.DefaultUserIDClaim,
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
MachineUsers: map[string][]uint32{
testUsername: {0},
},
}
sshServer.UpdateSSHAuth(authConfig)
sshServerAddr := server.StartTestServer(t, sshServer)
mockDaemon := startMockDaemon(t)
host, portStr, err := net.SplitHostPort(sshServerAddr)
require.NoError(t, err)
port, err := strconv.Atoi(portStr)
require.NoError(t, err)
mockDaemon.setHostKey(host, hostPubKey)
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
mockDaemon.setJWTToken(validToken)
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
require.NoError(t, err)
origStdin := os.Stdin
origStdout := os.Stdout
stdinReader, stdinWriter, err := os.Pipe()
require.NoError(t, err)
stdoutReader, stdoutWriter, err := os.Pipe()
require.NoError(t, err)
os.Stdin = stdinReader
os.Stdout = stdoutWriter
clientConn, proxyConn := net.Pipe()
go func() { _, _ = io.Copy(stdinWriter, proxyConn) }()
go func() { _, _ = io.Copy(proxyConn, stdoutReader) }()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
go func() {
_ = proxyInstance.Connect(ctx)
}()
sshConfig := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: []cryptossh.AuthMethod{},
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 5 * time.Second,
}
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
require.NoError(t, err)
client := cryptossh.NewClient(sshClientConn, chans, reqs)
cleanupFn := func() {
_ = client.Close()
_ = clientConn.Close()
cancel()
os.Stdin = origStdin
os.Stdout = origStdout
_ = sshServer.Stop()
mockDaemon.stop()
jwksServer.Close()
}
return client, cleanupFn
}
type mockDaemonServer struct {
proto.UnimplementedDaemonServiceServer
hostKeys map[string][]byte
@@ -150,6 +492,10 @@ func (m *mockDaemon) setHostKey(addr string, pubKey []byte) {
m.impl.hostKeys[addr] = pubKey
}
func (m *mockDaemon) setJWTToken(token string) {
m.impl.jwtToken = token
}
func (m *mockDaemon) stop() {
if m.server != nil {
m.server.Stop()
@@ -162,3 +508,63 @@ func mustParsePublicKey(t *testing.T, pubKeyBytes []byte) cryptossh.PublicKey {
require.NoError(t, err)
return pubKey
}
func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) {
t.Helper()
privateKey, jwksJSON := generateTestJWKS(t)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if _, err := w.Write(jwksJSON); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}))
return server, privateKey, server.URL
}
func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
t.Helper()
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
publicKey := &privateKey.PublicKey
n := publicKey.N.Bytes()
e := publicKey.E
jwk := nbjwt.JSONWebKey{
Kty: "RSA",
Kid: "test-key-id",
Use: "sig",
N: base64.RawURLEncoding.EncodeToString(n),
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(e)).Bytes()),
}
jwks := nbjwt.Jwks{
Keys: []nbjwt.JSONWebKey{jwk},
}
jwksJSON, err := json.Marshal(jwks)
require.NoError(t, err)
return privateKey, jwksJSON
}
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string, user string) string {
t.Helper()
claims := jwt.MapClaims{
"iss": issuer,
"aud": audience,
"sub": user,
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
token.Header["kid"] = "test-key-id"
tokenString, err := token.SignedString(privateKey)
require.NoError(t, err)
return tokenString
}

View File

@@ -1,66 +0,0 @@
//go:build unix && privileged
package server
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPrivilegeDropper_CreateExecutorCommand(t *testing.T) {
pd := NewPrivilegeDropper()
config := ExecutorConfig{
UID: 1000,
GID: 1000,
Groups: []uint32{1000, 1001},
WorkingDir: "/home/testuser",
Shell: "/bin/bash",
Command: "ls -la",
}
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
require.NoError(t, err)
require.NotNil(t, cmd)
// Verify the command is calling netbird ssh exec
assert.Contains(t, cmd.Args, "ssh")
assert.Contains(t, cmd.Args, "exec")
assert.Contains(t, cmd.Args, "--uid")
assert.Contains(t, cmd.Args, "1000")
assert.Contains(t, cmd.Args, "--gid")
assert.Contains(t, cmd.Args, "1000")
assert.Contains(t, cmd.Args, "--groups")
assert.Contains(t, cmd.Args, "1000")
assert.Contains(t, cmd.Args, "1001")
assert.Contains(t, cmd.Args, "--working-dir")
assert.Contains(t, cmd.Args, "/home/testuser")
assert.Contains(t, cmd.Args, "--shell")
assert.Contains(t, cmd.Args, "/bin/bash")
assert.Contains(t, cmd.Args, "--cmd")
assert.Contains(t, cmd.Args, "ls -la")
}
func TestPrivilegeDropper_CreateExecutorCommandInteractive(t *testing.T) {
pd := NewPrivilegeDropper()
config := ExecutorConfig{
UID: 1000,
GID: 1000,
Groups: []uint32{1000},
WorkingDir: "/home/testuser",
Shell: "/bin/bash",
Command: "",
}
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
require.NoError(t, err)
require.NotNil(t, cmd)
// Verify no command mode (command is empty so no --cmd flag)
assert.NotContains(t, cmd.Args, "--cmd")
assert.NotContains(t, cmd.Args, "--interactive")
}

View File

@@ -73,6 +73,61 @@ func TestPrivilegeDropper_ValidatePrivileges(t *testing.T) {
}
}
func TestPrivilegeDropper_CreateExecutorCommand(t *testing.T) {
pd := NewPrivilegeDropper()
config := ExecutorConfig{
UID: 1000,
GID: 1000,
Groups: []uint32{1000, 1001},
WorkingDir: "/home/testuser",
Shell: "/bin/bash",
Command: "ls -la",
}
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
require.NoError(t, err)
require.NotNil(t, cmd)
// Verify the command is calling netbird ssh exec
assert.Contains(t, cmd.Args, "ssh")
assert.Contains(t, cmd.Args, "exec")
assert.Contains(t, cmd.Args, "--uid")
assert.Contains(t, cmd.Args, "1000")
assert.Contains(t, cmd.Args, "--gid")
assert.Contains(t, cmd.Args, "1000")
assert.Contains(t, cmd.Args, "--groups")
assert.Contains(t, cmd.Args, "1000")
assert.Contains(t, cmd.Args, "1001")
assert.Contains(t, cmd.Args, "--working-dir")
assert.Contains(t, cmd.Args, "/home/testuser")
assert.Contains(t, cmd.Args, "--shell")
assert.Contains(t, cmd.Args, "/bin/bash")
assert.Contains(t, cmd.Args, "--cmd")
assert.Contains(t, cmd.Args, "ls -la")
}
func TestPrivilegeDropper_CreateExecutorCommandInteractive(t *testing.T) {
pd := NewPrivilegeDropper()
config := ExecutorConfig{
UID: 1000,
GID: 1000,
Groups: []uint32{1000},
WorkingDir: "/home/testuser",
Shell: "/bin/bash",
Command: "",
}
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
require.NoError(t, err)
require.NotNil(t, cmd)
// Verify no command mode (command is empty so no --cmd flag)
assert.NotContains(t, cmd.Args, "--cmd")
assert.NotContains(t, cmd.Args, "--interactive")
}
// TestPrivilegeDropper_ActualPrivilegeDrop tests actual privilege dropping
// This test requires root privileges and will be skipped if not running as root
func TestPrivilegeDropper_ActualPrivilegeDrop(t *testing.T) {

View File

@@ -23,11 +23,11 @@ import (
"github.com/stretchr/testify/require"
nbssh "github.com/netbirdio/netbird/client/ssh"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
"github.com/netbirdio/netbird/client/ssh/client"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/client/ssh/testutil"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
sshauth "github.com/netbirdio/netbird/shared/sessionauth"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)

View File

@@ -23,10 +23,10 @@ import (
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/auth/jwt"
sshauth "github.com/netbirdio/netbird/shared/sessionauth"
"github.com/netbirdio/netbird/util/netrelay"
"github.com/netbirdio/netbird/version"
)
@@ -197,14 +197,6 @@ type Config struct {
// HostKey is the SSH server host key in PEM format
HostKeyPEM []byte
// NetstackNet, when non-nil, makes the SSH server listen via the
// supplied userspace network stack instead of an OS socket.
NetstackNet *netstack.Net
// NetworkValidation, when non-zero, restricts inbound connections to
// peers inside the NetBird overlay defined by this WireGuard address.
NetworkValidation wgaddr.Address
}
// SessionInfo contains information about an active SSH session
@@ -216,15 +208,12 @@ type SessionInfo struct {
PortForwards []string
}
// New creates an SSH server instance from the supplied Config. Fields are
// read once at construction; mutating Config afterwards has no effect.
// JWT == nil disables JWT authentication.
// New creates an SSH server instance with the provided host key and optional JWT configuration
// If jwtConfig is nil, JWT authentication is disabled
func New(config *Config) *Server {
s := &Server{
mu: sync.RWMutex{},
hostKeyPEM: config.HostKeyPEM,
netstackNet: config.NetstackNet,
wgAddress: config.NetworkValidation,
sessions: make(map[sessionKey]*sessionState),
pendingAuthJWT: make(map[authKey]string),
remoteForwardListeners: make(map[forwardKey]net.Listener),
@@ -445,6 +434,20 @@ func (s *Server) buildSessionInfo(state *sessionState) SessionInfo {
return info
}
// SetNetstackNet sets the netstack network for userspace networking
func (s *Server) SetNetstackNet(net *netstack.Net) {
s.mu.Lock()
defer s.mu.Unlock()
s.netstackNet = net
}
// SetNetworkValidation configures network-based connection filtering
func (s *Server) SetNetworkValidation(addr wgaddr.Address) {
s.mu.Lock()
defer s.mu.Unlock()
s.wgAddress = addr
}
// UpdateSSHAuth updates the SSH fine-grained access control configuration
// This should be called when network map updates include new SSH auth configuration
func (s *Server) UpdateSSHAuth(config *sshauth.Config) {

View File

@@ -132,19 +132,6 @@ type SSHServerStateOutput struct {
Sessions []SSHSessionOutput `json:"sessions" yaml:"sessions"`
}
type VNCSessionOutput struct {
RemoteAddress string `json:"remoteAddress" yaml:"remoteAddress"`
Mode string `json:"mode" yaml:"mode"`
Username string `json:"username,omitempty" yaml:"username,omitempty"`
UserID string `json:"userID,omitempty" yaml:"userID,omitempty"`
Initiator string `json:"initiator,omitempty" yaml:"initiator,omitempty"`
}
type VNCServerStateOutput struct {
Enabled bool `json:"enabled" yaml:"enabled"`
Sessions []VNCSessionOutput `json:"sessions" yaml:"sessions"`
}
type OutputOverview struct {
Peers PeersStateOutput `json:"peers" yaml:"peers"`
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
@@ -168,7 +155,6 @@ type OutputOverview struct {
LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"`
ProfileName string `json:"profileName" yaml:"profileName"`
SSHServerState SSHServerStateOutput `json:"sshServer" yaml:"sshServer"`
VNCServerState VNCServerStateOutput `json:"vncServer" yaml:"vncServer"`
}
// ConvertToStatusOutputOverview converts protobuf status to the output overview.
@@ -189,7 +175,6 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, opts ConvertO
relayOverview := mapRelays(pbFullStatus.GetRelays())
sshServerOverview := mapSSHServer(pbFullStatus.GetSshServerState())
vncServerOverview := mapVNCServer(pbFullStatus.GetVncServerState())
peersOverview := mapPeers(pbFullStatus.GetPeers(), opts.StatusFilter, opts.PrefixNamesFilter, opts.PrefixNamesFilterMap, opts.IPsFilter, opts.ConnectionTypeFilter)
overview := OutputOverview{
@@ -215,7 +200,6 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, opts ConvertO
LazyConnectionEnabled: pbFullStatus.GetLazyConnectionEnabled(),
ProfileName: opts.ProfileName,
SSHServerState: sshServerOverview,
VNCServerState: vncServerOverview,
}
if opts.Anonymize {
@@ -297,26 +281,6 @@ func mapSSHServer(sshServerState *proto.SSHServerState) SSHServerStateOutput {
}
}
func mapVNCServer(state *proto.VNCServerState) VNCServerStateOutput {
if state == nil {
return VNCServerStateOutput{Sessions: []VNCSessionOutput{}}
}
sessions := make([]VNCSessionOutput, 0, len(state.GetSessions()))
for _, sess := range state.GetSessions() {
sessions = append(sessions, VNCSessionOutput{
RemoteAddress: sess.GetRemoteAddress(),
Mode: sess.GetMode(),
Username: sess.GetUsername(),
UserID: sess.GetUserID(),
Initiator: sess.GetInitiator(),
})
}
return VNCServerStateOutput{
Enabled: state.GetEnabled(),
Sessions: sessions,
}
}
func mapPeers(
peers []*proto.PeerState,
statusFilter string,
@@ -581,26 +545,6 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
}
}
vncServerStatus := "Disabled"
if o.VNCServerState.Enabled {
vncSessionCount := len(o.VNCServerState.Sessions)
if vncSessionCount > 0 {
sessionWord := "session"
if vncSessionCount > 1 {
sessionWord = "sessions"
}
vncServerStatus = fmt.Sprintf("Enabled (%d active %s)", vncSessionCount, sessionWord)
} else {
vncServerStatus = "Enabled"
}
if showSSHSessions && vncSessionCount > 0 {
for _, sess := range o.VNCServerState.Sessions {
vncServerStatus += "\n " + formatVNCSessionLine(sess)
}
}
}
peersCountString := fmt.Sprintf("%d/%d Connected", o.Peers.Connected, o.Peers.Total)
var forwardingRulesString string
@@ -647,7 +591,6 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
"Quantum resistance: %s\n"+
"Lazy connection: %s\n"+
"SSH Server: %s\n"+
"VNC Server: %s\n"+
"Networks: %s\n"+
"%s"+
"Peers count: %s\n",
@@ -667,7 +610,6 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
rosenpassEnabledStatus,
lazyConnectionEnabledStatus,
sshServerStatus,
vncServerStatus,
networks,
forwardingRulesString,
peersCountString,
@@ -1027,26 +969,6 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *PeerStateDetailOutput) {
}
}
// formatVNCSessionLine renders a single VNC session row for the detailed
// status output. The leading slot identifies the initiator (display name
// when known, hashed UserID otherwise); the post-arrow slot is the OS
// user the session targets and is omitted in attach mode where the
// destination is the current console user (unknown to the daemon).
func formatVNCSessionLine(sess VNCSessionOutput) string {
who := sess.Initiator
if who == "" {
who = sess.UserID
}
prefix := sess.RemoteAddress
if who != "" {
prefix = fmt.Sprintf("%s@%s", who, sess.RemoteAddress)
}
if sess.Username != "" {
return fmt.Sprintf("[%s -> %s] mode=%s", prefix, sess.Username, sess.Mode)
}
return fmt.Sprintf("[%s] mode=%s", prefix, sess.Mode)
}
func anonymizeOverview(a *anonymize.Anonymizer, overview *OutputOverview) {
for i, peer := range overview.Peers.Details {
peer := peer
@@ -1067,19 +989,6 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *OutputOverview) {
overview.Relays.Details[i] = detail
}
anonymizeNSServerGroups(a, overview)
for i, route := range overview.Networks {
overview.Networks[i] = a.AnonymizeRoute(route)
}
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
anonymizeEvents(a, overview)
anonymizeServerSessions(a, overview)
}
func anonymizeNSServerGroups(a *anonymize.Anonymizer, overview *OutputOverview) {
for i, nsGroup := range overview.NSServerGroups {
for j, domain := range nsGroup.Domains {
overview.NSServerGroups[i].Domains[j] = a.AnonymizeDomain(domain)
@@ -1091,9 +1000,13 @@ func anonymizeNSServerGroups(a *anonymize.Anonymizer, overview *OutputOverview)
}
}
}
}
func anonymizeEvents(a *anonymize.Anonymizer, overview *OutputOverview) {
for i, route := range overview.Networks {
overview.Networks[i] = a.AnonymizeRoute(route)
}
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
for i, event := range overview.Events {
overview.Events[i].Message = a.AnonymizeString(event.Message)
overview.Events[i].UserMessage = a.AnonymizeString(event.UserMessage)
@@ -1102,24 +1015,13 @@ func anonymizeEvents(a *anonymize.Anonymizer, overview *OutputOverview) {
event.Metadata[k] = a.AnonymizeString(v)
}
}
}
func anonymizeRemoteAddress(a *anonymize.Anonymizer, addr string) string {
if host, port, err := net.SplitHostPort(addr); err == nil {
return fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port)
}
return a.AnonymizeIPString(addr)
}
func anonymizeServerSessions(a *anonymize.Anonymizer, overview *OutputOverview) {
for i, session := range overview.SSHServerState.Sessions {
overview.SSHServerState.Sessions[i].RemoteAddress = anonymizeRemoteAddress(a, session.RemoteAddress)
if host, port, err := net.SplitHostPort(session.RemoteAddress); err == nil {
overview.SSHServerState.Sessions[i].RemoteAddress = fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port)
} else {
overview.SSHServerState.Sessions[i].RemoteAddress = a.AnonymizeIPString(session.RemoteAddress)
}
overview.SSHServerState.Sessions[i].Command = a.AnonymizeString(session.Command)
}
for i, sess := range overview.VNCServerState.Sessions {
overview.VNCServerState.Sessions[i].RemoteAddress = anonymizeRemoteAddress(a, sess.RemoteAddress)
overview.VNCServerState.Sessions[i].Username = a.AnonymizeString(sess.Username)
overview.VNCServerState.Sessions[i].UserID = a.AnonymizeString(sess.UserID)
overview.VNCServerState.Sessions[i].Initiator = a.AnonymizeString(sess.Initiator)
}
}

View File

@@ -242,10 +242,6 @@ var overview = OutputOverview{
Enabled: false,
Sessions: []SSHSessionOutput{},
},
VNCServerState: VNCServerStateOutput{
Enabled: false,
Sessions: []VNCSessionOutput{},
},
}
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
@@ -411,10 +407,6 @@ func TestParsingToJSON(t *testing.T) {
"sshServer":{
"enabled":false,
"sessions":[]
},
"vncServer":{
"enabled":false,
"sessions":[]
}
}`
// @formatter:on
@@ -525,9 +517,6 @@ profileName: ""
sshServer:
enabled: false
sessions: []
vncServer:
enabled: false
sessions: []
`
assert.Equal(t, expectedYAML, yaml)
@@ -598,7 +587,6 @@ Wireguard port: %d
Quantum resistance: false
Lazy connection: false
SSH Server: Disabled
VNC Server: Disabled
Networks: 10.10.0.0/24
Peers count: 2/2 Connected
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion, overview.WgPort)
@@ -625,7 +613,6 @@ Wireguard port: 51820
Quantum resistance: false
Lazy connection: false
SSH Server: Disabled
VNC Server: Disabled
Networks: 10.10.0.0/24
Peers count: 2/2 Connected
`

View File

@@ -63,7 +63,6 @@ type Info struct {
RosenpassEnabled bool
RosenpassPermissive bool
ServerSSHAllowed bool
ServerVNCAllowed bool
DisableClientRoutes bool
DisableServerRoutes bool
@@ -85,7 +84,6 @@ type Info struct {
func (i *Info) SetFlags(
rosenpassEnabled, rosenpassPermissive bool,
serverSSHAllowed *bool,
serverVNCAllowed *bool,
disableClientRoutes, disableServerRoutes,
disableDNS, disableFirewall, blockLANAccess, blockInbound, disableIPv6, lazyConnectionEnabled bool,
enableSSHRoot, enableSSHSFTP, enableSSHLocalPortForwarding, enableSSHRemotePortForwarding *bool,
@@ -96,9 +94,6 @@ func (i *Info) SetFlags(
if serverSSHAllowed != nil {
i.ServerSSHAllowed = *serverSSHAllowed
}
if serverVNCAllowed != nil {
i.ServerVNCAllowed = *serverVNCAllowed
}
i.DisableClientRoutes = disableClientRoutes
i.DisableServerRoutes = disableServerRoutes

View File

@@ -1,196 +0,0 @@
//go:build privileged && (linux || darwin)
// Package privileged provides a self-hosting harness that runs the repo's
// privileged-tagged test suite inside a --privileged --cap-add=NET_ADMIN
// container, so developers can exercise the root/system-mutating tests on a
// non-root host with a single `go test` invocation.
package privileged
import (
"bytes"
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"testing"
"time"
"github.com/moby/moby/api/types/container"
"github.com/ory/dockertest/v4"
)
// containerImage / containerTag match the image used by the CI privileged job
// (.github/workflows/golang-test-linux.yml, test_client_on_docker).
const (
containerImage = "golang"
containerTag = "1.25-alpine"
)
const (
containerWorkdir = "/app"
containerGoCache = "/root/.cache/go-build"
containerGoModCache = "/go/pkg/mod"
)
// alpinePackages are the build/runtime deps the privileged tests need, mirroring
// the CI container setup.
const alpinePackages = "ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base"
// privilegedTestPackages is the package list the suite runs, excluding the
// server-side trees and UI/upload helpers, matching the CI Docker job's filter.
const privilegedTestPackages = `go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server`
// testWriter forwards container output to the test log line by line.
type testWriter struct{ t *testing.T }
func (w testWriter) Write(p []byte) (int, error) {
for _, line := range strings.Split(strings.TrimRight(string(p), "\n"), "\n") {
w.t.Log(line)
}
return len(p), nil
}
// TestRunPrivilegedSuiteInDocker spins up a privileged container, mounts the repo,
// and runs `go test -tags 'devcert privileged'` inside it. When already running
// inside that container (DOCKER_CI=true) it returns immediately so the real
// privileged tests in the suite execute in place instead of recursing.
func TestRunPrivilegedSuiteInDocker(t *testing.T) {
if os.Getenv("DOCKER_CI") == "true" {
t.Skip("inside privileged container, skipping container spawn; privileged tests run in place")
}
repoRoot, err := findRepoRoot()
if err != nil {
t.Fatalf("locate repo root: %v", err)
}
goCache, goModCache := hostGoCaches(t)
// dockertest reads DOCKER_HOST; point it at the active context's socket when
// the default one is absent (macOS Docker Desktop, Colima, OrbStack).
if host := dockerHost(); host != "" {
t.Setenv("DOCKER_HOST", host)
}
// NewPoolT registers container cleanup via t.Cleanup automatically.
pool := dockertest.NewPoolT(t, "", dockertest.WithMaxWait(30*time.Minute))
// Keep the container alive so the suite runs via Exec, which yields a clean
// exit code (the v4 Resource API exposes no container wait/exit-code).
resource := pool.RunT(t, containerImage,
dockertest.WithTag(containerTag),
dockertest.WithWorkingDir(containerWorkdir),
dockertest.WithMounts([]string{
repoRoot + ":" + containerWorkdir,
goCache + ":" + containerGoCache,
goModCache + ":" + containerGoModCache,
}),
dockertest.WithEnv([]string{
"CGO_ENABLED=1",
"CI=true",
"DOCKER_CI=true",
"CONTAINER=true",
"GOCACHE=" + containerGoCache,
"GOMODCACHE=" + containerGoModCache,
}),
dockertest.WithCmd([]string{"sleep", "infinity"}),
dockertest.WithHostConfig(func(hc *container.HostConfig) {
hc.Privileged = true
hc.CapAdd = []string{"NET_ADMIN"}
}),
dockertest.WithoutReuse(),
)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
defer cancel()
result, err := resource.Exec(ctx, []string{"sh", "-c", buildTestScript()})
if err != nil {
t.Fatalf("run privileged suite in container: %v", err)
}
w := testWriter{t}
_, _ = w.Write([]byte(result.StdOut))
_, _ = w.Write([]byte(result.StdErr))
if result.ExitCode != 0 {
t.Fatalf("privileged test suite failed in container (exit code %d)", result.ExitCode)
}
}
// findRepoRoot walks up from the test's working directory to the module root.
func findRepoRoot() (string, error) {
dir, err := os.Getwd()
if err != nil {
return "", err
}
for {
if _, statErr := os.Stat(filepath.Join(dir, "go.mod")); statErr == nil {
return dir, nil
}
parent := filepath.Dir(dir)
if parent == dir {
return "", fmt.Errorf("go.mod not found above %s", dir)
}
dir = parent
}
}
// dockerHost returns a DOCKER_HOST override when the default socket is missing.
// An empty result means the caller should leave DOCKER_HOST untouched (it is
// already set, or the default unix socket exists). When neither is present
// (common on macOS Docker Desktop, Colima and OrbStack, which use a per-user
// socket), it resolves the active docker context's endpoint.
func dockerHost() string {
if os.Getenv("DOCKER_HOST") != "" {
return ""
}
if _, err := os.Stat("/var/run/docker.sock"); err == nil {
return ""
}
out, err := exec.Command("docker", "context", "inspect", "-f", "{{.Endpoints.docker.Host}}").Output()
if err != nil {
return ""
}
return strings.TrimSpace(string(out))
}
// hostGoCaches resolves the host GOCACHE/GOMODCACHE so the container reuses the
// existing build/module cache for speed.
func hostGoCaches(t *testing.T) (string, string) {
t.Helper()
return goEnv(t, "GOCACHE"), goEnv(t, "GOMODCACHE")
}
func goEnv(t *testing.T, key string) string {
t.Helper()
var out bytes.Buffer
cmd := exec.Command("go", "env", key)
cmd.Stdout = &out
if err := cmd.Run(); err != nil {
t.Fatalf("go env %s: %v", key, err)
}
return strings.TrimSpace(out.String())
}
// buildTestScript builds the in-container command. PRIV_PKGS overrides the package
// list (default: the full filtered set); PRIV_RUN adds a -run test-name filter.
// Both empty reproduces the full privileged suite.
func buildTestScript() string {
pkgs := privilegedTestPackages + " | xargs"
if p := os.Getenv("PRIV_PKGS"); p != "" {
pkgs = "echo " + p + " | xargs"
}
runFilter := ""
if r := os.Getenv("PRIV_RUN"); r != "" {
runFilter = "-run '" + r + "' "
}
return fmt.Sprintf(
"apk update >/dev/null && apk add --no-cache %s >/dev/null && %s go test -buildvcs=false -tags 'devcert privileged' %s-v -timeout 20m -p 1",
alpinePackages, pkgs, runFilter,
)
}

View File

@@ -1,259 +0,0 @@
//go:build !(linux && 386)
package main
import (
"context"
"errors"
"fmt"
"os"
"os/exec"
"strings"
"time"
"fyne.io/fyne/v2"
"fyne.io/fyne/v2/container"
"fyne.io/fyne/v2/widget"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/approval"
"github.com/netbirdio/netbird/client/proto"
)
// Approval metadata that is remote-peer or dashboard controlled is passed to
// the forked netbird-ui via environment variables rather than argv, so it is
// not exposed to other local users through ps.
const (
envApprovalInitiator = "NB_APPROVAL_INITIATOR"
envApprovalPeerName = "NB_APPROVAL_PEER_NAME"
envApprovalSourceIP = "NB_APPROVAL_SOURCE_IP"
envApprovalUsername = "NB_APPROVAL_USERNAME"
envApprovalKeyFingerprint = "NB_APPROVAL_KEY_FINGERPRINT"
envApprovalSubject = "NB_APPROVAL_SUBJECT"
)
// handleApprovalEvent forks a netbird-ui child process to render the
// dialog on its own fyne main loop. Top-level windows opened from a
// background goroutine of the tray process don't render reliably on
// Linux/GTK, so the rest of the UI (settings, login URL, update) uses
// the same fork pattern.
func (s *serviceClient) handleApprovalEvent(ev *proto.SystemEvent) {
if ev == nil || ev.Category != proto.SystemEvent_APPROVAL {
return
}
requestID := ev.Metadata["request_id"]
if requestID == "" {
log.Warnf("approval event missing request_id: %v", ev.Metadata)
return
}
// Only the request id, kind, and deadline stay on argv: they are
// daemon-issued and non-sensitive. The remote-influenced fields go
// through the child's environment.
args := []string{
"--approval-request-id=" + requestID,
"--approval-kind=" + ev.Metadata["kind"],
"--approval-expires-at=" + ev.Metadata["expires_at"],
}
env := append(os.Environ(),
envApprovalInitiator+"="+ev.Metadata["initiator"],
envApprovalPeerName+"="+ev.Metadata["peer_name"],
envApprovalSourceIP+"="+ev.Metadata["source_ip"],
envApprovalUsername+"="+ev.Metadata["username"],
envApprovalKeyFingerprint+"="+ev.Metadata["peer_pubkey"],
envApprovalSubject+"="+ev.UserMessage,
)
go s.runApprovalCommand(s.ctx, env, args)
}
// runApprovalCommand forks netbird-ui to render the approval dialog,
// inheriting the parent environment plus the approval-specific variables. It
// mirrors runSelfCommand but sets cmd.Env so the sensitive metadata never
// appears on the child's argv.
func (s *serviceClient) runApprovalCommand(ctx context.Context, env, args []string) {
proc, err := os.Executable()
if err != nil {
log.Errorf("get executable path: %v", err)
return
}
cmdArgs := append([]string{"--approval=true", "--daemon-addr=" + s.addr}, args...)
cmd := exec.CommandContext(ctx, proc, cmdArgs...)
cmd.Env = env
if out := s.attachOutput(cmd); out != nil {
defer func() {
if err := out.Close(); err != nil {
log.Errorf("close log file %s: %v", s.logFile, err)
}
}()
}
log.Printf("running approval command: %s", cmd.String())
if err := cmd.Run(); err != nil {
var exitErr *exec.ExitError
if errors.As(err, &exitErr) {
log.Printf("approval command failed with exit code %d", exitErr.ExitCode())
}
}
}
// showApprovalUI runs the dialog on the forked process's fyne main loop
// and forwards the user's decision to the daemon via RespondApproval.
func (s *serviceClient) showApprovalUI(req approvalRequest) {
w := s.app.NewWindow(approvalTitle(req.kind))
w.Resize(fyne.NewSize(480, 260))
w.CenterOnScreen()
w.RequestFocus()
var rows []string
if req.initiator != "" {
// The display name comes from the management dashboard and is
// not cryptographically asserted by the connecting client. The
// key fingerprint that follows IS: it's the Noise_IK static
// public key the client just proved possession of. Show both
// so the user can sanity-check that "Alice" is really the
// Alice they trust.
rows = append(rows, "From user: "+req.initiator)
}
if fp := approval.ShortKeyFingerprint(req.keyFingerprint); fp != "" {
rows = append(rows, "Key fp: "+fp)
}
if req.peerName != "" {
rows = append(rows, "Via peer: "+req.peerName)
}
if req.sourceIP != "" && req.sourceIP != req.peerName {
rows = append(rows, "Source IP: "+req.sourceIP)
}
if req.username != "" {
rows = append(rows, "OS user: "+req.username)
}
if len(rows) == 0 {
rows = []string{"Remote: " + req.displayPeer()}
}
body := strings.Join(rows, "\n")
bodyLabel := widget.NewLabel(body)
bodyLabel.Wrapping = fyne.TextWrapWord
countdown := widget.NewLabel("")
deadline := req.deadline()
updateCountdown := func() {
remaining := time.Until(deadline).Round(time.Second)
if remaining < 0 {
remaining = 0
}
countdown.SetText(fmt.Sprintf("Auto-deny in %s", remaining))
}
updateCountdown()
type outcome struct {
accept bool
viewOnly bool
}
decided := make(chan outcome, 1)
decide := func(o outcome) {
select {
case decided <- o:
default:
}
}
allow := widget.NewButton("Allow", func() { decide(outcome{accept: true}) })
allow.Importance = widget.HighImportance
allowView := widget.NewButton("Allow (view only)", func() { decide(outcome{accept: true, viewOnly: true}) })
deny := widget.NewButton("Deny", func() { decide(outcome{accept: false}) })
header := widget.NewLabelWithStyle(req.subject, fyne.TextAlignLeading, fyne.TextStyle{Bold: true})
buttonRow := container.NewGridWithColumns(3, allow, allowView, deny)
info := container.NewVBox(header, widget.NewSeparator(), bodyLabel, widget.NewSeparator(), countdown)
w.SetContent(container.NewPadded(container.NewBorder(nil, buttonRow, nil, nil, info)))
w.SetCloseIntercept(func() { decide(outcome{}) })
go func() {
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
for range ticker.C {
if time.Until(deadline) <= 0 {
decide(outcome{})
return
}
fyne.Do(updateCountdown)
}
}()
go func() {
o := <-decided
s.sendApprovalResponse(req.requestID, o.accept, o.viewOnly)
fyne.Do(func() {
w.Close()
s.app.Quit()
})
}()
w.Show()
}
func (s *serviceClient) sendApprovalResponse(requestID string, accept, viewOnly bool) {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
log.Warnf("approval response: get daemon client: %v", err)
return
}
ctx, cancel := context.WithTimeout(s.ctx, defaultFailTimeout)
defer cancel()
if _, err := conn.RespondApproval(ctx, &proto.RespondApprovalRequest{
RequestId: requestID,
Accept: accept,
ViewOnly: viewOnly,
}); err != nil {
log.Warnf("approval response: %v", err)
}
}
// approvalRequest is the parsed --approval-* CLI args that the forked
// dialog process consumes.
type approvalRequest struct {
requestID string
kind string
initiator string
peerName string
sourceIP string
username string
subject string
expiresAt string
keyFingerprint string
}
func (r approvalRequest) displayPeer() string {
switch {
case r.initiator != "":
return r.initiator
case r.peerName != "":
return r.peerName
case r.sourceIP != "":
return r.sourceIP
default:
return "unknown peer"
}
}
// deadline returns the wall-clock auto-deny moment. Falls back to a short
// local window when the daemon's expires_at is missing/unparsable, so a
// stale value never leaves the dialog open indefinitely.
func (r approvalRequest) deadline() time.Time {
if t, err := time.Parse(time.RFC3339, r.expiresAt); err == nil {
return t
}
return time.Now().Add(13 * time.Second)
}
func approvalTitle(kind string) string {
switch kind {
case "vnc":
return "Allow VNC Connection?"
case "ssh":
return "Allow SSH Connection?"
default:
return "Allow Incoming Connection?"
}
}

View File

@@ -112,25 +112,13 @@ func main() {
showQuickActions: flags.showQuickActions,
showUpdate: flags.showUpdate,
showUpdateVersion: flags.showUpdateVersion,
showApproval: flags.showApproval,
approvalRequest: approvalRequest{
requestID: flags.approvalRequestID,
kind: flags.approvalKind,
initiator: os.Getenv(envApprovalInitiator),
peerName: os.Getenv(envApprovalPeerName),
sourceIP: os.Getenv(envApprovalSourceIP),
username: os.Getenv(envApprovalUsername),
subject: os.Getenv(envApprovalSubject),
expiresAt: flags.approvalExpiresAt,
keyFingerprint: os.Getenv(envApprovalKeyFingerprint),
},
})
// Watch for theme/settings changes to update the icon.
go watchSettingsChanges(a, client)
// Run in window mode if any UI flag was set.
if flags.showSettings || flags.showNetworks || flags.showDebug || flags.showLoginURL || flags.showProfiles || flags.showQuickActions || flags.showUpdate || flags.showApproval {
if flags.showSettings || flags.showNetworks || flags.showDebug || flags.showLoginURL || flags.showProfiles || flags.showQuickActions || flags.showUpdate {
a.Run()
return
}
@@ -167,11 +155,6 @@ type cliFlags struct {
saveLogsInFile bool
showUpdate bool
showUpdateVersion string
showApproval bool
approvalRequestID string
approvalKind string
approvalExpiresAt string
}
// parseFlags reads and returns all needed command-line flags.
@@ -193,10 +176,6 @@ func parseFlags() *cliFlags {
flag.BoolVar(&flags.showLoginURL, "login-url", false, "show login URL in a popup window")
flag.BoolVar(&flags.showUpdate, "update", false, "show update progress window")
flag.StringVar(&flags.showUpdateVersion, "update-version", "", "version to update to")
flag.BoolVar(&flags.showApproval, "approval", false, "show inbound-connection approval prompt window")
flag.StringVar(&flags.approvalRequestID, "approval-request-id", "", "approval prompt: daemon-issued request id")
flag.StringVar(&flags.approvalKind, "approval-kind", "", "approval prompt: subsystem kind (vnc, ssh, ...)")
flag.StringVar(&flags.approvalExpiresAt, "approval-expires-at", "", "approval prompt: RFC3339 deadline at which the daemon auto-denies")
flag.Parse()
return &flags
}
@@ -285,7 +264,6 @@ type serviceClient struct {
mQuit *systray.MenuItem
mNetworks *systray.MenuItem
mAllowSSH *systray.MenuItem
mAllowVNC *systray.MenuItem
mAutoConnect *systray.MenuItem
mEnableRosenpass *systray.MenuItem
mLazyConnEnabled *systray.MenuItem
@@ -324,8 +302,6 @@ type serviceClient struct {
sEnableSSHRemotePortForward *widget.Check
sDisableSSHAuth *widget.Check
iSSHJWTCacheTTL *widget.Entry
sServerVNCAllowed *widget.Check
sDisableVNCApproval *widget.Check
// observable settings over corresponding iMngURL and iPreSharedKey values.
managementURL string
@@ -347,8 +323,6 @@ type serviceClient struct {
enableSSHRemotePortForward bool
disableSSHAuth bool
sshJWTCacheTTL int
serverVNCAllowed bool
disableVNCApproval bool
connected bool
daemonVersion string
@@ -407,8 +381,6 @@ type newServiceClientArgs struct {
showQuickActions bool
showUpdate bool
showUpdateVersion string
showApproval bool
approvalRequest approvalRequest
}
// newServiceClient instance constructor
@@ -456,8 +428,6 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
}
case args.showUpdate:
s.showUpdateProgress(ctx, args.showUpdateVersion)
case args.showApproval:
s.showApprovalUI(args.approvalRequest)
}
return s
@@ -538,8 +508,6 @@ func (s *serviceClient) showSettingsUI() {
s.sEnableSSHRemotePortForward = widget.NewCheck("Enable SSH Remote Port Forwarding", nil)
s.sDisableSSHAuth = widget.NewCheck("Disable SSH Authentication", nil)
s.iSSHJWTCacheTTL = widget.NewEntry()
s.sServerVNCAllowed = widget.NewCheck("Allow embedded VNC server on this peer", nil)
s.sDisableVNCApproval = widget.NewCheck("Skip per-connection approval prompt for VNC", nil)
s.wSettings.SetContent(s.getSettingsForm())
s.wSettings.Resize(fyne.NewSize(600, 400))
@@ -652,8 +620,7 @@ func (s *serviceClient) hasSettingsChanged(iMngURL string, port, mtu int64) bool
s.disableServerRoutes != s.sDisableServerRoutes.Checked ||
s.disableIPv6 != s.sDisableIPv6.Checked ||
s.blockLANAccess != s.sBlockLANAccess.Checked ||
s.hasSSHChanges() ||
s.hasVNCChanges()
s.hasSSHChanges()
}
func (s *serviceClient) applySettingsChanges(iMngURL string, port, mtu int64) error {
@@ -712,8 +679,6 @@ func (s *serviceClient) buildSetConfigRequest(iMngURL string, port, mtu int64) (
req.EnableSSHLocalPortForwarding = &s.sEnableSSHLocalPortForward.Checked
req.EnableSSHRemotePortForwarding = &s.sEnableSSHRemotePortForward.Checked
req.DisableSSHAuth = &s.sDisableSSHAuth.Checked
req.ServerVNCAllowed = &s.sServerVNCAllowed.Checked
req.DisableVNCApproval = &s.sDisableVNCApproval.Checked
sshJWTCacheTTLText := strings.TrimSpace(s.iSSHJWTCacheTTL.Text)
if sshJWTCacheTTLText != "" {
@@ -782,12 +747,10 @@ func (s *serviceClient) getSettingsForm() fyne.CanvasObject {
connectionForm := s.getConnectionForm()
networkForm := s.getNetworkForm()
sshForm := s.getSSHForm()
vncForm := s.getVNCForm()
tabs := container.NewAppTabs(
container.NewTabItem("Connection", connectionForm),
container.NewTabItem("Network", networkForm),
container.NewTabItem("SSH", sshForm),
container.NewTabItem("VNC", vncForm),
)
saveButton := widget.NewButtonWithIcon("Save", theme.ConfirmIcon(), s.saveSettings)
saveButton.Importance = widget.HighImportance
@@ -828,15 +791,6 @@ func (s *serviceClient) getSSHForm() *widget.Form {
}
}
func (s *serviceClient) getVNCForm() *widget.Form {
return &widget.Form{
Items: []*widget.FormItem{
{Text: "Allow VNC Server", Widget: s.sServerVNCAllowed},
{Text: "Disable Connection Approval Prompt", Widget: s.sDisableVNCApproval},
},
}
}
func (s *serviceClient) hasSSHChanges() bool {
currentSSHJWTCacheTTL := s.sshJWTCacheTTL
if text := strings.TrimSpace(s.iSSHJWTCacheTTL.Text); text != "" {
@@ -855,11 +809,6 @@ func (s *serviceClient) hasSSHChanges() bool {
s.sshJWTCacheTTL != currentSSHJWTCacheTTL
}
func (s *serviceClient) hasVNCChanges() bool {
return s.serverVNCAllowed != s.sServerVNCAllowed.Checked ||
s.disableVNCApproval != s.sDisableVNCApproval.Checked
}
func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginResponse, error) {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
@@ -1143,7 +1092,6 @@ func (s *serviceClient) onTrayReady() {
s.mSettings = systray.AddMenuItem("Settings", disabledMenuDescr)
s.mAllowSSH = s.mSettings.AddSubMenuItemCheckbox("Allow SSH", allowSSHMenuDescr, false)
s.mAllowVNC = s.mSettings.AddSubMenuItemCheckbox("Allow VNC", allowVNCMenuDescr, false)
s.mAutoConnect = s.mSettings.AddSubMenuItemCheckbox("Connect on Startup", autoConnectMenuDescr, false)
s.mEnableRosenpass = s.mSettings.AddSubMenuItemCheckbox("Enable Quantum-Resistance", quantumResistanceMenuDescr, false)
s.mLazyConnEnabled = s.mSettings.AddSubMenuItemCheckbox("Enable Lazy Connections", lazyConnMenuDescr, false)
@@ -1224,7 +1172,6 @@ func (s *serviceClient) onTrayReady() {
s.eventManager = event.NewManager(s.notifier, s.addr)
s.eventManager.SetNotificationsEnabled(s.mNotifications.Checked())
s.eventManager.AddHandler(s.handleApprovalEvent)
s.eventManager.AddHandler(func(event *proto.SystemEvent) {
if event.Category == proto.SystemEvent_SYSTEM {
s.updateExitNodes()
@@ -1507,12 +1454,6 @@ func (s *serviceClient) getSrvConfig() {
if cfg.SSHJWTCacheTTL != nil {
s.sshJWTCacheTTL = *cfg.SSHJWTCacheTTL
}
if cfg.ServerVNCAllowed != nil {
s.serverVNCAllowed = *cfg.ServerVNCAllowed
}
if cfg.DisableVNCApproval != nil {
s.disableVNCApproval = *cfg.DisableVNCApproval
}
if s.showAdvancedSettings {
s.iMngURL.SetText(s.managementURL)
@@ -1568,12 +1509,6 @@ func (s *serviceClient) getSrvConfig() {
if cfg.SSHJWTCacheTTL != nil {
s.iSSHJWTCacheTTL.SetText(strconv.Itoa(*cfg.SSHJWTCacheTTL))
}
if cfg.ServerVNCAllowed != nil {
s.sServerVNCAllowed.SetChecked(*cfg.ServerVNCAllowed)
}
if cfg.DisableVNCApproval != nil {
s.sDisableVNCApproval.SetChecked(*cfg.DisableVNCApproval)
}
}
// MDM locks must run before the mNotifications-nil early return:
@@ -1640,8 +1575,6 @@ func protoConfigToConfig(cfg *proto.GetConfigResponse) *profilemanager.Config {
config.DisableAutoConnect = cfg.DisableAutoConnect
config.ServerSSHAllowed = &cfg.ServerSSHAllowed
config.ServerVNCAllowed = &cfg.ServerVNCAllowed
config.DisableVNCApproval = &cfg.DisableVNCApproval
config.RosenpassEnabled = cfg.RosenpassEnabled
config.RosenpassPermissive = cfg.RosenpassPermissive
config.DisableNotifications = &cfg.DisableNotifications
@@ -1737,12 +1670,6 @@ func (s *serviceClient) loadSettings() {
s.mAllowSSH.Uncheck()
}
if cfg.ServerVNCAllowed {
s.mAllowVNC.Check()
} else {
s.mAllowVNC.Uncheck()
}
if cfg.DisableAutoConnect {
s.mAutoConnect.Uncheck()
} else {
@@ -1905,7 +1832,6 @@ func (s *serviceClient) applyMDMLocksToSettingsForm(set map[string]bool) {
func (s *serviceClient) updateConfig() error {
disableAutoStart := !s.mAutoConnect.Checked()
sshAllowed := s.mAllowSSH.Checked()
vncAllowed := s.mAllowVNC.Checked()
rosenpassEnabled := s.mEnableRosenpass.Checked()
lazyConnectionEnabled := s.mLazyConnEnabled.Checked()
blockInbound := s.mBlockInbound.Checked()
@@ -1934,7 +1860,6 @@ func (s *serviceClient) updateConfig() error {
Username: currUser.Username,
DisableAutoConnect: &disableAutoStart,
ServerSSHAllowed: &sshAllowed,
ServerVNCAllowed: &vncAllowed,
RosenpassEnabled: &rosenpassEnabled,
LazyConnectionEnabled: &lazyConnectionEnabled,
BlockInbound: &blockInbound,

View File

@@ -2,7 +2,6 @@ package main
const (
allowSSHMenuDescr = "Allow SSH connections"
allowVNCMenuDescr = "Allow embedded VNC server"
autoConnectMenuDescr = "Connect automatically when the service starts"
quantumResistanceMenuDescr = "Enable post-quantum security via Rosenpass"
lazyConnMenuDescr = "[Experimental] Enable lazy connections"

View File

@@ -112,7 +112,7 @@ func (e *Manager) handleEvent(event *proto.SystemEvent) {
handlers := slices.Clone(e.handlers)
e.mu.Unlock()
if event.UserMessage != "" && (enabled || event.Severity == proto.SystemEvent_CRITICAL) && !isV6DefaultRoutePartner(event) && event.Category != proto.SystemEvent_APPROVAL {
if event.UserMessage != "" && (enabled || event.Severity == proto.SystemEvent_CRITICAL) && !isV6DefaultRoutePartner(event) {
title := e.getEventTitle(event)
body := event.UserMessage
id := event.Metadata["id"]

View File

@@ -39,8 +39,6 @@ func (h *eventHandler) listen(ctx context.Context) {
h.handleDisconnectClick()
case <-h.client.mAllowSSH.ClickedCh:
h.handleAllowSSHClick()
case <-h.client.mAllowVNC.ClickedCh:
h.handleAllowVNCClick()
case <-h.client.mAutoConnect.ClickedCh:
h.handleAutoConnectClick()
case <-h.client.mEnableRosenpass.ClickedCh:
@@ -136,15 +134,6 @@ func (h *eventHandler) handleAllowSSHClick() {
}
func (h *eventHandler) handleAllowVNCClick() {
h.toggleCheckbox(h.client.mAllowVNC)
if err := h.updateConfigWithErr(); err != nil {
h.toggleCheckbox(h.client.mAllowVNC) // revert checkbox state on error
log.Errorf("failed to update config: %v", err)
h.client.notifier.Send("Error", "Failed to update VNC settings")
}
}
func (h *eventHandler) handleAutoConnectClick() {
h.toggleCheckbox(h.client.mAutoConnect)
if err := h.updateConfigWithErr(); err != nil {

View File

@@ -1,31 +0,0 @@
// Package vnc holds shared constants for the NetBird embedded VNC stack
// so non-server consumers (CLI capture, debug tooling) can refer to the
// well-known ports without depending on internal engine packages.
package vnc
// External and internal listen ports for the embedded VNC server.
// ExternalPort is what dashboard / browser clients see; the daemon
// DNATs it to InternalPort, where the in-process VNC server actually
// listens. Both flow over the WireGuard interface. AgentLegacyPort is
// the TCP port the per-session agent used before it switched to Unix
// sockets; kept here so packet captures from older builds still get
// tagged, and so any future on-wire agent variant has a reserved port.
const (
ExternalPort uint16 = 5900
InternalPort uint16 = 25900
AgentLegacyPort uint16 = 15900
)
// WellKnownPorts is the unordered set of ports a packet capture should
// treat as carrying NetBird VNC traffic.
var WellKnownPorts = [...]uint16{ExternalPort, InternalPort, AgentLegacyPort}
// IsWellKnownPort reports whether port matches any of WellKnownPorts.
func IsWellKnownPort(port uint16) bool {
for _, p := range WellKnownPorts {
if port == p {
return true
}
}
return false
}

View File

@@ -1,434 +0,0 @@
//go:build darwin && !ios
package server
import (
"bytes"
"context"
"errors"
"fmt"
"net"
"os"
"os/exec"
"strconv"
"sync"
"syscall"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/configs"
)
// darwinAgentManager spawns a per-user VNC agent on demand and keeps it
// alive across multiple client connections within the same console-user
// session. A new agent is spawned the first time a client connects, or
// whenever the console user changes underneath us.
//
// Lifecycle is lazy by design: a daemon that never receives a VNC
// connection never spawns anything. The trade-off versus an eager spawn
// (the Windows model) is that the first VNC client pays the launchctl
// asuser + listen-readiness wait, ~hundreds of milliseconds in practice.
// That cost only repeats on user switch.
type darwinAgentManager struct {
mu sync.Mutex
authToken string
socketPath string
uid uint32
running bool
}
func newDarwinAgentManager(ctx context.Context) *darwinAgentManager {
m := &darwinAgentManager{}
go m.watchConsoleUser(ctx)
return m
}
// agentSocketName is the file name inside the per-uid socket directory
// the agent binds. The directory itself is created and chowned by the
// daemon (see prepareAgentSocketDir) so a non-root local user cannot
// pre-create or symlink the path before the agent listens.
const agentSocketName = "agent.sock"
// watchConsoleUser kills the cached agent whenever the console user
// changes (logout, fast user switch, login window). Without it the daemon
// keeps proxying to an agent whose TCC grant and WindowServer access
// belong to a user who is no longer at the screen, so the new user only
// ever sees the locked-screen wallpaper. Killing the agent breaks the
// loopback TCP that the daemon proxies into, the client disconnects, and
// the next reconnect runs ensure() against the new console uid.
func (m *darwinAgentManager) watchConsoleUser(ctx context.Context) {
t := time.NewTicker(2 * time.Second)
defer t.Stop()
for {
select {
case <-ctx.Done():
return
case <-t.C:
uid, err := consoleUserID()
m.mu.Lock()
if !m.running {
m.mu.Unlock()
continue
}
if err != nil || uid != m.uid {
prev := m.uid
m.killLocked()
m.mu.Unlock()
if err != nil {
log.Infof("console user gone (was uid=%d): %v; agent stopped", prev, err)
} else {
log.Infof("console user changed %d -> %d; agent stopped, will respawn on next connect", prev, uid)
}
continue
}
m.mu.Unlock()
}
}
}
// Resolve spawns or respawns the per-user agent process as needed and
// returns its Unix-socket path, shared token, and the uid the agent was
// spawned under (so the daemon can validate peer credentials before
// dispatching the token). Each call is serialized so concurrent VNC
// clients share the same agent.
func (m *darwinAgentManager) Resolve(ctx context.Context) (string, string, uint32, error) {
consoleUID, err := consoleUserID()
if err != nil {
return "", "", 0, fmt.Errorf("no console user: %w", err)
}
m.mu.Lock()
defer m.mu.Unlock()
if m.running && m.uid == consoleUID && vncAgentRunning() {
return m.socketPath, m.authToken, m.uid, nil
}
m.killLocked()
// Reap stray agents so the new token is the only accepted one.
killAllVNCAgents()
socketDir, err := prepareAgentSocketDir(consoleUID)
if err != nil {
return "", "", 0, fmt.Errorf("prepare agent socket dir: %w", err)
}
socketPath := socketDir + "/" + agentSocketName
if err := os.Remove(socketPath); err != nil && !errors.Is(err, os.ErrNotExist) {
log.Debugf("clear stale agent socket %s: %v", socketPath, err)
}
token, err := generateAuthToken()
if err != nil {
return "", "", 0, fmt.Errorf("generate agent auth token: %w", err)
}
if err := spawnAgentForUser(consoleUID, socketPath, token); err != nil {
return "", "", 0, err
}
if err := waitForAgent(ctx, socketPath, 5*time.Second); err != nil {
killAllVNCAgents()
return "", "", 0, fmt.Errorf("agent did not start listening: %w", err)
}
m.authToken = token
m.socketPath = socketPath
m.uid = consoleUID
m.running = true
log.Infof("spawned VNC agent for console uid=%d on %s", consoleUID, socketPath)
return socketPath, token, consoleUID, nil
}
// prepareAgentSocketDir creates a per-uid subdirectory under the netbird
// runtime directory where the agent will bind its Unix socket. The leaf is
// owned by uid with mode 0700, so only the target user and root can write
// there. The parent is created root-owned with mode 0755 if missing.
// Symlinks at the per-uid level are refused (replaced with a fresh
// directory) so a low-priv user cannot redirect the chown that follows.
func prepareAgentSocketDir(uid uint32) (string, error) {
parent := configs.RuntimeDir
if err := ensureAgentSocketParent(parent); err != nil {
return "", err
}
subdir := fmt.Sprintf("%s/vnc-%d", parent, uid)
if err := purgeStaleAgentSubdir(subdir, uid); err != nil {
return "", err
}
if err := os.Mkdir(subdir, 0o700); err != nil && !errors.Is(err, os.ErrExist) {
return "", fmt.Errorf("mkdir %s: %w", subdir, err)
}
if err := os.Chmod(subdir, 0o700); err != nil {
return "", fmt.Errorf("chmod %s: %w", subdir, err)
}
if err := os.Chown(subdir, int(uid), -1); err != nil {
return "", fmt.Errorf("chown %s -> uid %d: %w", subdir, uid, err)
}
return subdir, nil
}
// ensureAgentSocketParent verifies the runtime parent dir exists, is not a
// symlink, and is owned by root.
func ensureAgentSocketParent(parent string) error {
if parent == "" {
return fmt.Errorf("no runtime directory configured for this platform")
}
if err := os.MkdirAll(parent, 0o755); err != nil {
return fmt.Errorf("mkdir %s: %w", parent, err)
}
info, err := os.Lstat(parent)
if err != nil {
return fmt.Errorf("lstat %s: %w", parent, err)
}
if info.Mode()&os.ModeSymlink != 0 {
return fmt.Errorf("%s is a symlink", parent)
}
if st, ok := info.Sys().(*syscall.Stat_t); ok && st.Uid != 0 {
return fmt.Errorf("%s not owned by root (uid=%d)", parent, st.Uid)
}
return nil
}
// purgeStaleAgentSubdir removes a leftover subdir unless it is a real dir
// owned by uid with mode 0700. Lstat (not Stat) so a symlink is detected.
func purgeStaleAgentSubdir(subdir string, uid uint32) error {
info, err := os.Lstat(subdir)
if errors.Is(err, os.ErrNotExist) {
return nil
}
if err != nil {
return fmt.Errorf("lstat %s: %w", subdir, err)
}
if agentSubdirOK(info, uid) {
return nil
}
if err := os.RemoveAll(subdir); err != nil {
return fmt.Errorf("remove stale %s: %w", subdir, err)
}
return nil
}
func agentSubdirOK(info os.FileInfo, uid uint32) bool {
if info.Mode()&os.ModeSymlink != 0 || !info.IsDir() {
return false
}
st, ok := info.Sys().(*syscall.Stat_t)
if !ok {
return false
}
return st.Uid == uid && info.Mode().Perm() == 0o700
}
// stop terminates the spawned agent, if any. Intended for daemon shutdown.
func (m *darwinAgentManager) stop() {
m.mu.Lock()
defer m.mu.Unlock()
m.killLocked()
}
func (m *darwinAgentManager) killLocked() {
if !m.running {
return
}
killAllVNCAgents()
if m.socketPath != "" {
if err := os.Remove(m.socketPath); err != nil && !errors.Is(err, os.ErrNotExist) {
log.Debugf("remove agent socket %s: %v", m.socketPath, err)
}
}
m.running = false
m.authToken = ""
m.socketPath = ""
m.uid = 0
}
// consoleUserID returns the uid of the user currently sitting at the
// console (the one whose Aqua session is active). Returns
// errNoConsoleUser when nobody is logged in: at the login window
// /dev/console is owned by root.
func consoleUserID() (uint32, error) {
info, err := os.Stat("/dev/console")
if err != nil {
return 0, fmt.Errorf("stat /dev/console: %w", err)
}
st, ok := info.Sys().(*syscall.Stat_t)
if !ok {
return 0, fmt.Errorf("/dev/console stat has unexpected type")
}
if st.Uid == 0 {
return 0, errNoConsoleUser
}
return st.Uid, nil
}
// spawnAgentForUser uses launchctl asuser to start a netbird vnc-agent
// process inside the target user's launchd bootstrap namespace. That is
// the only spawn mode on macOS that gives the child access to the user's
// WindowServer. The agent's stderr is relogged into the daemon log so
// startup failures are not silently lost when the readiness check times
// out.
func spawnAgentForUser(uid uint32, socketPath, token string) error {
exe, err := os.Executable()
if err != nil {
return fmt.Errorf("resolve own executable: %w", err)
}
cmd := exec.Command(
"/bin/launchctl", "asuser", strconv.FormatUint(uint64(uid), 10),
exe, vncAgentSubcommand,
"--socket", socketPath,
// Drop privs inside the agent: launchctl asuser preserves the
// daemon's uid (root), so without this the capture/input/
// encoder paths would run as root for the lifetime of the
// session. validateAgentPeer on the daemon side also relies on
// the agent's effective uid matching consoleUID.
"--target-uid", strconv.FormatUint(uint64(uid), 10),
)
cmd.Env = append(os.Environ(), agentTokenEnvVar+"="+token)
stderr, err := cmd.StderrPipe()
if err != nil {
return fmt.Errorf("agent stderr pipe: %w", err)
}
if err := cmd.Start(); err != nil {
return fmt.Errorf("launchctl asuser: %w", err)
}
go func() {
defer stderr.Close()
relogAgentStream(stderr)
}()
go func() { _ = cmd.Wait() }()
return nil
}
// waitForAgent dials the agent's Unix socket until it answers. Used to
// gate proxy attempts until the spawned process has finished its Start.
func waitForAgent(ctx context.Context, socketPath string, wait time.Duration) error {
var d net.Dialer
deadline := time.Now().Add(wait)
for time.Now().Before(deadline) {
if ctx.Err() != nil {
return ctx.Err()
}
dialCtx, cancel := context.WithTimeout(ctx, 200*time.Millisecond)
c, err := d.DialContext(dialCtx, "unix", socketPath)
cancel()
if err == nil {
_ = c.Close()
return nil
}
time.Sleep(100 * time.Millisecond)
}
return fmt.Errorf("timeout dialing %s", socketPath)
}
// vncAgentRunning reports whether any vnc-agent process exists on the
// system. There is at most one agent per machine, so any match is "the"
// agent.
func vncAgentRunning() bool {
pids, err := vncAgentPIDs()
if err != nil {
log.Debugf("scan for vnc-agent: %v", err)
return false
}
return len(pids) > 0
}
// killAllVNCAgents sends SIGTERM to every process whose argv contains
// "vnc-agent", waits briefly for them to exit, and escalates to SIGKILL
// for any that remain. We enumerate kern.proc.all rather than
// kern.proc.uid because launchctl asuser preserves the caller's uid
// (root) on the spawned child, so a uid-scoped filter would never match.
func killAllVNCAgents() {
pids, err := vncAgentPIDs()
if err != nil {
log.Debugf("scan for vnc-agent: %v", err)
return
}
for _, pid := range pids {
_ = syscall.Kill(pid, syscall.SIGTERM)
}
if len(pids) == 0 {
return
}
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
remaining, _ := vncAgentPIDs()
if len(remaining) == 0 {
return
}
time.Sleep(100 * time.Millisecond)
}
leftover, _ := vncAgentPIDs()
for _, pid := range leftover {
_ = syscall.Kill(pid, syscall.SIGKILL)
}
}
// vncAgentPIDs returns the pids of vnc-agent subprocesses spawned from
// this binary. Matches exactly on argv[0] == our own executable path
// AND argv[1] == "vnc-agent" so unrelated processes that happen to have
// the same name elsewhere in argv are not targeted. Skips pid 0 and 1
// defensively.
func vncAgentPIDs() ([]int, error) {
procs, err := unix.SysctlKinfoProcSlice("kern.proc.all")
if err != nil {
return nil, fmt.Errorf("sysctl kern.proc.all: %w", err)
}
ownExe, err := os.Executable()
if err != nil {
return nil, fmt.Errorf("resolve own executable: %w", err)
}
var out []int
for i := range procs {
pid := int(procs[i].Proc.P_pid)
if pid <= 1 {
continue
}
argv, err := procArgv(pid)
if err != nil || !argvIsVNCAgent(argv, ownExe) {
continue
}
out = append(out, pid)
}
return out, nil
}
// procArgv reads the kernel's stored argv for pid via the kern.procargs2
// sysctl. Format: 4-byte argc, then argv[0..argc) each NUL-terminated,
// then envp, then padding. We only need argv so we stop after argc.
func procArgv(pid int) ([]string, error) {
raw, err := unix.SysctlRaw("kern.procargs2", pid)
if err != nil {
return nil, err
}
if len(raw) < 4 {
return nil, fmt.Errorf("procargs2 truncated")
}
argc := int(raw[0]) | int(raw[1])<<8 | int(raw[2])<<16 | int(raw[3])<<24
body := raw[4:]
// Skip the executable path (NUL-terminated) and any zero padding that
// follows before argv[0].
end := bytes.IndexByte(body, 0)
if end < 0 {
return nil, fmt.Errorf("procargs2 path unterminated")
}
body = body[end+1:]
for len(body) > 0 && body[0] == 0 {
body = body[1:]
}
args := make([]string, 0, argc)
for i := 0; i < argc; i++ {
end := bytes.IndexByte(body, 0)
if end < 0 {
break
}
args = append(args, string(body[:end]))
body = body[end+1:]
}
return args, nil
}
// argvIsVNCAgent reports whether argv belongs to a vnc-agent subprocess
// spawned from our binary. Requires argv[0] to match ownExe exactly and
// argv[1] to be the vnc-agent subcommand. Matches the spawn shape in
// spawnAgentForUser and rejects anything else.
func argvIsVNCAgent(argv []string, ownExe string) bool {
if len(argv) < 2 || ownExe == "" {
return false
}
return argv[0] == ownExe && argv[1] == vncAgentSubcommand
}

View File

@@ -1,305 +0,0 @@
//go:build darwin || windows
package server
import (
"bufio"
"bytes"
"context"
crand "crypto/rand"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"time"
log "github.com/sirupsen/logrus"
)
// errNoConsoleUser is the sentinel returned by sessionAgent.Resolve when
// the platform has no interactive user to attach a capture agent to (the
// macOS loginwindow state). Mapped to a distinct RFB reject code so the
// browser can show a meaningful message.
var errNoConsoleUser = errors.New("no user logged into console")
// sessionAgent abstracts the per-platform manager that spawns and tracks
// the user-session VNC agent. Resolve returns the agent's Unix-socket
// path, the shared per-spawn token, and the uid the agent was spawned
// under (used to validate peer credentials before the daemon hands the
// token to whoever is on the other end of the socket). Resolve may spawn
// the agent lazily.
type sessionAgent interface {
Resolve(ctx context.Context) (socketPath, token string, peerUID uint32, err error)
}
// prefixConn replays already-consumed header bytes ahead of the proxy
// stream by swapping in a different Reader on the same underlying Conn.
type prefixConn struct {
io.Reader
net.Conn
}
func (p *prefixConn) Read(b []byte) (int, error) { return p.Reader.Read(b) }
// handleServiceConnection runs the connection-header handshake (source
// check, Noise_IK auth) on conn, resolves the right per-session agent
// via sa, and proxies to it. Every accepted connection emits exactly one
// outcome line on the daemon log.
func (s *Server) handleServiceConnection(conn net.Conn, sa sessionAgent) {
start := time.Now()
connLog := s.log.WithField("remote", conn.RemoteAddr().String())
if !s.isAllowedSource(conn.RemoteAddr()) {
connLog.Info("VNC connection rejected: source not allowed")
_ = conn.Close()
return
}
var headerBuf bytes.Buffer
tee := io.TeeReader(conn, &headerBuf)
teeConn := &prefixConn{Reader: tee, Conn: conn}
header, err := s.readConnectionHeader(teeConn)
if err != nil {
connLog.Infof("VNC connection rejected: header read failed: %v", err)
_ = conn.Close()
return
}
authedLog, sessionUserID, ok := s.authorizeSession(conn, header, connLog)
if !ok {
authedLog.Info("VNC connection rejected: auth failed")
return
}
if err := s.registerConnAuth(conn, header); err != nil {
rejectConnection(conn, codeMessage(RejectCodeAuthForbidden, err.Error()))
authedLog.Warnf("VNC connection rejected: %v", err)
return
}
decision, err := s.gateApproval(conn, header)
if err != nil {
authedLog.Infof("VNC connection rejected: %v", err)
return
}
if decision.ViewOnly {
authedLog.Info("VNC connection approved by user (view-only)")
} else if s.requireApproval {
authedLog.Info("VNC connection approved by user")
}
socketPath, token, peerUID, err := sa.Resolve(s.ctx)
if err != nil {
code := RejectCodeCapturerError
if errors.Is(err, errNoConsoleUser) {
code = RejectCodeNoConsoleUser
}
rejectConnection(conn, codeMessage(code, err.Error()))
authedLog.Warnf("VNC connection rejected: agent unavailable: %v", err)
return
}
var initiator string
if s.authorizer != nil {
initiator = s.authorizer.LookupSessionDisplayName(header.clientStatic)
}
sessionID := s.addSession(ActiveSessionInfo{
RemoteAddress: conn.RemoteAddr().String(),
Mode: modeString(header.mode),
Username: header.username,
UserID: sessionUserID,
Initiator: initiator,
}, conn)
defer s.removeSession(sessionID)
replayConn := &prefixConn{
Reader: io.MultiReader(&headerBuf, conn),
Conn: conn,
}
if err := proxyToAgent(s.ctx, replayConn, socketPath, token, peerUID, decision.ViewOnly, authedLog); err != nil {
rejectConnection(conn, codeMessage(RejectCodeCapturerError, err.Error()))
authedLog.Warnf("VNC connection rejected: agent unreachable: %v", err)
return
}
authedLog.Infof("VNC connection closed (%dms)", time.Since(start).Milliseconds())
}
const (
// agentTokenLen is the size of the random per-spawn token in bytes.
agentTokenLen = 32
// agentTokenEnvVar names the environment variable the daemon uses to
// hand the per-spawn token to the agent child. Out-of-band channels
// like this keep the secret out of the command line, where listings
// such as `ps` or Windows tasklist would expose it.
agentTokenEnvVar = "NB_VNC_AGENT_TOKEN" // #nosec G101 -- env var name, not a credential
// vncAgentSubcommand is the CLI subcommand the daemon invokes to start
// the per-session agent process. Must match cmd.vncAgentCmd.Use in
// client/cmd/vnc_agent.go.
vncAgentSubcommand = "vnc-agent"
)
// generateAuthToken returns a fresh hex-encoded random token for one
// daemon→agent session. The daemon hands this to the spawned agent
// out-of-band (env var on Windows) and verifies it on every connection
// the agent accepts.
func generateAuthToken() (string, error) {
b := make([]byte, agentTokenLen)
if _, err := crand.Read(b); err != nil {
return "", fmt.Errorf("read random: %w", err)
}
return hex.EncodeToString(b), nil
}
// proxyToAgent dials the per-session agent's Unix socket, validates the
// peer's kernel-asserted uid (so the daemon never hands its per-spawn
// token to an impostor that won the listen race), writes the raw token
// bytes plus a single view-only flag byte, then copies bytes both ways
// until either side closes. The token + flag prefix must precede any RFB
// byte so the agent's verifyAgentToken can run first. Returns nil once a
// stream is established; the caller is responsible for sending an
// RFB-level rejection on error so the client sees a reason instead of a
// bare timeout. authedLog receives one audit line per dispatched
// preamble so an operator can correlate daemon→agent traffic with the
// remote session that triggered it.
func proxyToAgent(ctx context.Context, client net.Conn, socketPath, authToken string, peerUID uint32, viewOnly bool, authedLog *log.Entry) error {
tokenBytes, err := hex.DecodeString(authToken)
if err != nil || len(tokenBytes) != agentTokenLen {
return fmt.Errorf("invalid auth token (len=%d): %w", len(tokenBytes), err)
}
agentConn, err := dialAgentWithRetry(ctx, socketPath)
if err != nil {
return fmt.Errorf("dial agent at %s: %w", socketPath, err)
}
if err := validateAgentPeer(agentConn, peerUID); err != nil {
_ = agentConn.Close()
return fmt.Errorf("agent peer validation failed: %w", err)
}
preamble := make([]byte, len(tokenBytes)+1)
copy(preamble, tokenBytes)
if viewOnly {
preamble[len(tokenBytes)] = 1
}
if _, err := agentConn.Write(preamble); err != nil {
_ = agentConn.Close()
return fmt.Errorf("send auth preamble to agent: %w", err)
}
// Audit: one line per successfully-dispatched daemon→agent preamble.
// Token printed as its first 8 hex chars (enough to correlate, not
// enough to use). Kept at Info so the default deployment captures it.
tokenFp := authToken
if len(tokenFp) > 8 {
tokenFp = tokenFp[:8]
}
if authedLog != nil {
authedLog.Infof("VNC IPC: dispatched preamble to agent socket=%s peer_uid=%d view_only=%v token_fp=%s", socketPath, peerUID, viewOnly, tokenFp)
}
defer client.Close()
defer agentConn.Close()
log.Debugf("proxy connected to agent, starting bidirectional copy")
done := make(chan struct{}, 2)
cp := func(label string, dst, src net.Conn) {
n, err := io.Copy(dst, src)
log.Debugf("proxy %s: %d bytes, err=%v", label, n, err)
done <- struct{}{}
}
go cp("client->agent", agentConn, client)
go cp("agent->client", client, agentConn)
<-done
return nil
}
// relogAgentStream reads log lines from the agent's stderr and re-emits
// them through the daemon's logrus, so the merged log keeps a single
// format. JSON lines (the agent's normal output) are parsed and dispatched
// by level; plain-text lines (cobra errors, panic traces) are forwarded
// verbatim so early-startup failures stay visible.
func relogAgentStream(r io.Reader) {
entry := log.WithField("component", "vnc-agent")
scanner := bufio.NewScanner(r)
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
for scanner.Scan() {
line := scanner.Bytes()
if len(line) == 0 {
continue
}
if line[0] != '{' {
entry.Warn(string(line))
continue
}
var m map[string]any
if err := json.Unmarshal(line, &m); err != nil {
entry.Warn(string(line))
continue
}
msg, _ := m["msg"].(string)
if msg == "" {
continue
}
fields := make(log.Fields)
for k, v := range m {
switch k {
case "msg", "level", "time", "func":
continue
case "caller":
fields["source"] = v
default:
fields[k] = v
}
}
e := entry.WithFields(fields)
switch m["level"] {
case "error":
e.Error(msg)
case "warning":
e.Warn(msg)
case "debug":
e.Debug(msg)
case "trace":
e.Trace(msg)
default:
e.Info(msg)
}
}
}
// dialAgentWithRetry retries the loopback connect for up to ~10 s so the
// daemon does not race the agent's first listen. Returns the live conn or
// the final error. Aborts early when ctx is cancelled so a Stop() during
// service-mode startup doesn't leave a goroutine sleeping for 10 s.
func dialAgentWithRetry(ctx context.Context, addr string) (net.Conn, error) {
var d net.Dialer
var lastErr error
for range 50 {
if err := ctx.Err(); err != nil {
if lastErr == nil {
lastErr = err
}
return nil, lastErr
}
dialCtx, cancel := context.WithTimeout(ctx, time.Second)
c, err := d.DialContext(dialCtx, "unix", addr)
cancel()
if err == nil {
return c, nil
}
lastErr = err
select {
case <-ctx.Done():
if errors.Is(lastErr, context.Canceled) || errors.Is(lastErr, context.DeadlineExceeded) {
lastErr = ctx.Err()
}
return nil, lastErr
case <-time.After(200 * time.Millisecond):
}
}
return nil, lastErr
}

View File

@@ -1,46 +0,0 @@
//go:build darwin && !ios
package server
import (
"fmt"
"net"
"golang.org/x/sys/unix"
)
// validateAgentPeer enforces that the peer behind the just-connected Unix
// socket is the agent we expect it to be: a process running under
// expectedUID, with the right effective uid stamped by the kernel on the
// socket. Refuses (with a non-nil error) if anything else is listening on
// the path (an unrelated local process that won the listen race or
// squatted the path before us). Defends against the daemon shipping its
// per-spawn auth token to a process that isn't the spawned agent.
func validateAgentPeer(conn net.Conn, expectedUID uint32) error {
uconn, ok := conn.(*net.UnixConn)
if !ok {
return fmt.Errorf("peer cred: expected *net.UnixConn, got %T", conn)
}
raw, err := uconn.SyscallConn()
if err != nil {
return fmt.Errorf("peer cred: syscall conn: %w", err)
}
var cred *unix.Xucred
var inner error
ctlErr := raw.Control(func(fd uintptr) {
cred, inner = unix.GetsockoptXucred(int(fd), unix.SOL_LOCAL, unix.LOCAL_PEERCRED)
})
if ctlErr != nil {
return fmt.Errorf("peer cred: control: %w", ctlErr)
}
if inner != nil {
return fmt.Errorf("peer cred: getsockopt LOCAL_PEERCRED: %w", inner)
}
if cred == nil {
return fmt.Errorf("peer cred: nil xucred")
}
if cred.Uid != expectedUID {
return fmt.Errorf("peer cred: agent uid %d does not match expected %d", cred.Uid, expectedUID)
}
return nil
}

View File

@@ -1,115 +0,0 @@
//go:build darwin && !ios
package server
import (
"net"
"os"
"path/filepath"
"strings"
"sync"
"testing"
)
// TestValidateAgentPeerAcceptsOwnUID confirms the happy path: a Unix
// socket whose peer is the current process must validate when the
// expected uid matches the process's own. Both sides of a unix-socket
// pair share the same kernel cred, so this exercises the real getsockopt
// LOCAL_PEERCRED path.
func TestValidateAgentPeerAcceptsOwnUID(t *testing.T) {
dir := t.TempDir()
sockPath := filepath.Join(dir, "test.sock")
ln, err := net.Listen("unix", sockPath)
if err != nil {
t.Fatalf("listen: %v", err)
}
defer ln.Close()
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
c, err := ln.Accept()
if err == nil {
_ = c.Close()
}
}()
c, err := net.Dial("unix", sockPath)
if err != nil {
t.Fatalf("dial: %v", err)
}
defer c.Close()
if err := validateAgentPeer(c, uint32(os.Getuid())); err != nil {
t.Fatalf("validateAgentPeer rejected own uid: %v", err)
}
wg.Wait()
}
// TestValidateAgentPeerRejectsWrongUID ensures the validator fails when
// the expected uid differs from the kernel-reported peer uid. This is
// the path that catches a hostile process that won the listen race.
func TestValidateAgentPeerRejectsWrongUID(t *testing.T) {
dir := t.TempDir()
sockPath := filepath.Join(dir, "test.sock")
ln, err := net.Listen("unix", sockPath)
if err != nil {
t.Fatalf("listen: %v", err)
}
defer ln.Close()
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
c, err := ln.Accept()
if err == nil {
_ = c.Close()
}
}()
c, err := net.Dial("unix", sockPath)
if err != nil {
t.Fatalf("dial: %v", err)
}
defer c.Close()
// Pick a uid the test process certainly isn't running as.
wrongUID := uint32(os.Getuid()) + 1
err = validateAgentPeer(c, wrongUID)
if err == nil {
t.Fatal("expected mismatch error, got nil")
}
if !strings.Contains(err.Error(), "does not match expected") {
t.Fatalf("error should mention uid mismatch, got: %v", err)
}
wg.Wait()
}
// TestValidateAgentPeerRejectsNonUnix protects against being handed a
// non-Unix-socket connection (the validator can't enforce anything on
// e.g. a *net.TCPConn so it must refuse rather than silently pass).
func TestValidateAgentPeerRejectsNonUnix(t *testing.T) {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen tcp: %v", err)
}
defer ln.Close()
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
c, err := ln.Accept()
if err == nil {
_ = c.Close()
}
}()
c, err := net.Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatalf("dial tcp: %v", err)
}
defer c.Close()
if err := validateAgentPeer(c, 0); err == nil {
t.Fatal("expected refusal on non-unix conn, got nil")
}
wg.Wait()
}

View File

@@ -1,31 +0,0 @@
//go:build windows
package server
import (
"net"
)
// validateAgentPeer is a documented no-op on Windows. AF_UNIX on Windows
// exposes no SO_PEERCRED equivalent and no supported API to recover the
// peer process from an accepted AF_UNIX connection, so the daemon cannot
// match the connected peer against the agent PID it spawned the way the
// darwin path does via LOCAL_PEERCRED. The Windows trust model therefore
// rests on three other measures, none of which assume the socket path is
// secret:
//
// - the socket lives in a dedicated directory (agentSocketDir) created
// with a DACL granting only SYSTEM and Administrators, so an
// unprivileged local user cannot create or squat a socket there;
// - each spawn uses a cryptographically random socket name, so the path
// is unguessable before the agent binds it;
// - the daemon publishes the path only after confirming the spawned
// agent is listening (see waitForAgentListening), and gates every
// connection on the per-spawn auth-token preamble that follows this
// call.
//
// If a future Windows release exposes peer-PID retrieval for AF_UNIX,
// this function should verify the peer against the spawned agent PID.
func validateAgentPeer(_ net.Conn, _ uint32) error {
return nil
}

View File

@@ -1,747 +0,0 @@
//go:build windows
package server
import (
"context"
crand "crypto/rand"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
"net"
"os"
"path/filepath"
"runtime"
"sync"
"time"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
const (
stillActive = 259
tokenPrimary = 1
securityImpersonation = 2
tokenSessionID = 12
createUnicodeEnvironment = 0x00000400
createNoWindow = 0x08000000
createSuspended = 0x00000004
createBreakawayFromJob = 0x01000000
)
var (
kernel32 = windows.NewLazySystemDLL("kernel32.dll")
advapi32 = windows.NewLazySystemDLL("advapi32.dll")
userenv = windows.NewLazySystemDLL("userenv.dll")
procWTSGetActiveConsoleSessionId = kernel32.NewProc("WTSGetActiveConsoleSessionId")
procCreateJobObjectW = kernel32.NewProc("CreateJobObjectW")
procSetInformationJobObject = kernel32.NewProc("SetInformationJobObject")
procAssignProcessToJobObject = kernel32.NewProc("AssignProcessToJobObject")
procSetTokenInformation = advapi32.NewProc("SetTokenInformation")
procCreateEnvironmentBlock = userenv.NewProc("CreateEnvironmentBlock")
procDestroyEnvironmentBlock = userenv.NewProc("DestroyEnvironmentBlock")
wtsapi32 = windows.NewLazySystemDLL("wtsapi32.dll")
procWTSEnumerateSessionsW = wtsapi32.NewProc("WTSEnumerateSessionsW")
procWTSFreeMemory = wtsapi32.NewProc("WTSFreeMemory")
procWTSQuerySessionInformation = wtsapi32.NewProc("WTSQuerySessionInformationW")
)
// GetCurrentSessionID returns the session ID of the current process.
func GetCurrentSessionID() uint32 {
var token windows.Token
if err := windows.OpenProcessToken(windows.CurrentProcess(),
windows.TOKEN_QUERY, &token); err != nil {
return 0
}
defer token.Close()
var id uint32
var ret uint32
_ = windows.GetTokenInformation(token, windows.TokenSessionId,
(*byte)(unsafe.Pointer(&id)), 4, &ret)
return id
}
func getConsoleSessionID() uint32 {
r, _, _ := procWTSGetActiveConsoleSessionId.Call()
return uint32(r)
}
const (
wtsActive = 0
wtsConnected = 1
wtsDisconnected = 4
)
// getActiveSessionID returns the session ID of the best session to attach to.
// On a Windows Server with no console display attached, session 1 still
// reports WTSActive (login screen "owns" the console), so a naive
// first-active-wins pick lands on a session with no actual rendering.
// Preference order:
// 1. Active session with a user logged in (RDP user in session ≥2)
// 2. Active session without a user (console at login screen)
// 3. Console session ID
func getActiveSessionID() uint32 {
var sessionInfo uintptr
var count uint32
r, _, _ := procWTSEnumerateSessionsW.Call(
0, // WTS_CURRENT_SERVER_HANDLE
0, // reserved
1, // version
uintptr(unsafe.Pointer(&sessionInfo)),
uintptr(unsafe.Pointer(&count)),
)
if r == 0 || count == 0 {
return getConsoleSessionID()
}
defer func() { _, _, _ = procWTSFreeMemory.Call(sessionInfo) }()
type wtsSession struct {
SessionID uint32
Station *uint16
State uint32
}
sessions := unsafe.Slice((*wtsSession)(unsafe.Pointer(sessionInfo)), count)
var withUser uint32
var withUserFound bool
var anyActive uint32
var anyActiveFound bool
for _, s := range sessions {
if s.SessionID == 0 {
continue
}
if s.State != wtsActive {
continue
}
if !anyActiveFound {
anyActive = s.SessionID
anyActiveFound = true
}
if !withUserFound && wtsSessionHasUser(s.SessionID) {
withUser = s.SessionID
withUserFound = true
}
}
if withUserFound {
return withUser
}
if anyActiveFound {
return anyActive
}
return getConsoleSessionID()
}
// wtsSessionHasUser returns true if the session has a non-empty user name,
// i.e. someone is logged in (vs. the login/Welcome screen). The console
// session at the lock screen has WTSUserName == "".
const wtsUserName = 5
func wtsSessionHasUser(sessionID uint32) bool {
var buf uintptr
var bytesReturned uint32
r, _, _ := procWTSQuerySessionInformation.Call(
0, // WTS_CURRENT_SERVER_HANDLE
uintptr(sessionID),
uintptr(wtsUserName),
uintptr(unsafe.Pointer(&buf)),
uintptr(unsafe.Pointer(&bytesReturned)),
)
if r == 0 || buf == 0 {
return false
}
defer func() { _, _, _ = procWTSFreeMemory.Call(buf) }()
// First UTF-16 code unit non-zero ⇒ non-empty username.
return *(*uint16)(unsafe.Pointer(buf)) != 0
}
// getSystemTokenForSession duplicates the current SYSTEM token and sets its
// session ID so the spawned process runs in the target session. Using a SYSTEM
// token gives access to both Default and Winlogon desktops plus UIPI bypass.
func getSystemTokenForSession(sessionID uint32) (windows.Token, error) {
var cur windows.Token
if err := windows.OpenProcessToken(windows.CurrentProcess(),
windows.MAXIMUM_ALLOWED, &cur); err != nil {
return 0, fmt.Errorf("OpenProcessToken: %w", err)
}
defer cur.Close()
var dup windows.Token
if err := windows.DuplicateTokenEx(cur, windows.MAXIMUM_ALLOWED, nil,
securityImpersonation, tokenPrimary, &dup); err != nil {
return 0, fmt.Errorf("DuplicateTokenEx: %w", err)
}
sid := sessionID
r, _, err := procSetTokenInformation.Call(
uintptr(dup),
uintptr(tokenSessionID),
uintptr(unsafe.Pointer(&sid)),
unsafe.Sizeof(sid),
)
if r == 0 {
dup.Close()
return 0, fmt.Errorf("SetTokenInformation(SessionId=%d): %w", sessionID, err)
}
return dup, nil
}
// injectEnvVar appends a KEY=VALUE entry to a Unicode environment block.
// The block is a sequence of null-terminated UTF-16 strings, terminated by
// an extra null. Returns the new []uint16 backing slice; the caller must
// hold the returned slice alive until CreateProcessAsUser completes.
func injectEnvVar(envBlock uintptr, key, value string) []uint16 {
entry := key + "=" + value
// Walk the existing block to find its total length.
ptr := (*uint16)(unsafe.Pointer(envBlock))
var totalChars int
for {
ch := *(*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + uintptr(totalChars)*2))
if ch == 0 {
// Check for double-null terminator.
next := *(*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + uintptr(totalChars+1)*2))
totalChars++
if next == 0 {
// End of block (don't count the final null yet, we'll rebuild).
break
}
} else {
totalChars++
}
}
entryUTF16, _ := windows.UTF16FromString(entry)
// New block: existing entries + new entry (null-terminated) + final null.
newLen := totalChars + len(entryUTF16) + 1
newBlock := make([]uint16, newLen)
// Copy existing entries (up to but not including the final null).
for i := range totalChars {
newBlock[i] = *(*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(ptr)) + uintptr(i)*2))
}
copy(newBlock[totalChars:], entryUTF16)
newBlock[newLen-1] = 0 // final null terminator
return newBlock
}
func spawnAgentInSession(sessionID uint32, socketPath, authToken string, jobHandle windows.Handle) (windows.Handle, error) {
token, err := getSystemTokenForSession(sessionID)
if err != nil {
return 0, fmt.Errorf("get SYSTEM token for session %d: %w", sessionID, err)
}
defer token.Close()
var envBlock uintptr
r, _, e := procCreateEnvironmentBlock.Call(
uintptr(unsafe.Pointer(&envBlock)),
uintptr(token),
0,
)
if r == 0 {
// Without an environment block we cannot inject NB_VNC_AGENT_TOKEN;
// the agent would start unauthenticated. Abort instead of launching.
return 0, fmt.Errorf("CreateEnvironmentBlock: %w", e)
}
defer func() { _, _, _ = procDestroyEnvironmentBlock.Call(envBlock) }()
// Inject the auth token into the environment block so it doesn't appear
// in the process command line (visible via tasklist/wmic). injectedBlock
// must stay alive until CreateProcessAsUser returns.
injectedBlock := injectEnvVar(envBlock, agentTokenEnvVar, authToken)
exePath, err := os.Executable()
if err != nil {
return 0, fmt.Errorf("get executable path: %w", err)
}
cmdLine := fmt.Sprintf(`"%s" %s --socket %q`, exePath, vncAgentSubcommand, socketPath)
cmdLineW, err := windows.UTF16PtrFromString(cmdLine)
if err != nil {
return 0, fmt.Errorf("UTF16 cmdline: %w", err)
}
// Create an inheritable pipe for the agent's stderr so we can relog
// its output in the service process.
var sa windows.SecurityAttributes
sa.Length = uint32(unsafe.Sizeof(sa))
sa.InheritHandle = 1
var stderrRead, stderrWrite windows.Handle
if err := windows.CreatePipe(&stderrRead, &stderrWrite, &sa, 0); err != nil {
return 0, fmt.Errorf("create stderr pipe: %w", err)
}
// The read end must NOT be inherited by the child.
_ = windows.SetHandleInformation(stderrRead, windows.HANDLE_FLAG_INHERIT, 0)
desktop, _ := windows.UTF16PtrFromString(`WinSta0\Default`)
si := windows.StartupInfo{
Cb: uint32(unsafe.Sizeof(windows.StartupInfo{})),
Desktop: desktop,
Flags: windows.STARTF_USESHOWWINDOW | windows.STARTF_USESTDHANDLES,
ShowWindow: 0,
StdErr: stderrWrite,
StdOutput: stderrWrite,
}
var pi windows.ProcessInformation
var envPtr *uint16
if len(injectedBlock) > 0 {
envPtr = &injectedBlock[0]
} else if envBlock != 0 {
envPtr = (*uint16)(unsafe.Pointer(envBlock))
}
// CREATE_SUSPENDED so we can assign the process to our Job Object
// before it executes. Without this the agent could spawn its own child
// processes and have them inherit the SCM service-job (not ours), or
// briefly listen on the agent port before we tear it down on rollback.
// CREATE_BREAKAWAY_FROM_JOB lets the child leave the SCM-managed
// service job; harmless if that job allows breakaway, and is required
// before AssignProcessToJobObject can succeed in the no-nested-jobs case.
err = windows.CreateProcessAsUser(
token, nil, cmdLineW,
nil, nil, true, // inheritHandles=true for the pipe
createUnicodeEnvironment|createNoWindow|createSuspended|createBreakawayFromJob,
envPtr, nil, &si, &pi,
)
runtime.KeepAlive(injectedBlock)
// Close the write end in the parent so reads will get EOF when the child exits.
_ = windows.CloseHandle(stderrWrite)
if err != nil {
_ = windows.CloseHandle(stderrRead)
return 0, fmt.Errorf("CreateProcessAsUser: %w", err)
}
if jobHandle != 0 {
r, _, e := procAssignProcessToJobObject.Call(uintptr(jobHandle), uintptr(pi.Process))
if r == 0 {
log.Warnf("assign agent to job object: %v (orphan possible on service crash)", e)
}
}
if _, err := windows.ResumeThread(pi.Thread); err != nil {
_ = windows.CloseHandle(pi.Thread)
_ = windows.TerminateProcess(pi.Process, 1)
_ = windows.CloseHandle(pi.Process)
_ = windows.CloseHandle(stderrRead)
return 0, fmt.Errorf("ResumeThread: %w", err)
}
_ = windows.CloseHandle(pi.Thread)
// Relog agent output in the service with a [vnc-agent] prefix.
go relogAgentOutput(stderrRead)
log.Infof("spawned agent PID=%d in session %d on %s", pi.ProcessId, sessionID, socketPath)
return pi.Process, nil
}
// sessionManager monitors the active console session and ensures a VNC agent
// process is running in it. When the session changes (e.g., user switch, RDP
// connect/disconnect), it kills the old agent and spawns a new one. Each
// spawn picks a per-session Unix-socket path the agent binds and the
// daemon dials over local IPC.
type sessionManager struct {
mu sync.Mutex
agentProc windows.Handle
everSpawned bool
agentStartedAt time.Time
spawnFailures int
nextSpawnAt time.Time
sessionID uint32
authToken string
socketPath string
done chan struct{}
// jobHandle owns the agent processes via a Windows Job Object with
// JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE. When the service exits or crashes,
// the OS closes the handle and terminates every assigned agent: no
// orphaned agent processes holding a socket across restarts.
jobHandle windows.Handle
}
const (
// agentSocketDir is a dedicated subdirectory under C:\Windows\Temp that
// the daemon creates with a restrictive DACL (SYSTEM + Administrators
// only). The default ACL on C:\Windows\Temp grants BUILTIN\Users
// create-file rights, so the agent socket must not live directly there:
// an unprivileged local user could pre-create a predictable path and
// intercept the daemon→agent stream. Both the daemon and the agent run
// as SYSTEM, so a SYSTEM-write-only directory is sufficient.
agentSocketDir = `C:\Windows\Temp\netbird-vnc`
// agentSocketDirSDDL grants full access to Local System (SY) and the
// Builtin Administrators group (BA) only, with the DACL protected
// (P) from inheritance so the parent's BUILTIN\Users grant does not
// flow in. AI is omitted; PAI marks the DACL protected and auto-
// inherited entries cleared.
agentSocketDirSDDL = "D:PAI(A;;FA;;;SY)(A;;FA;;;BA)"
// agentSocketRandomLen is the number of random bytes mixed into each
// per-spawn socket name so the path is unguessable before the agent
// owns it.
agentSocketRandomLen = 16
// agentReadyTimeout bounds how long the daemon waits for the freshly
// spawned agent to bind and accept on its socket before treating the
// spawn as failed.
agentReadyTimeout = 5 * time.Second
)
func newSessionManager() *sessionManager {
m := &sessionManager{sessionID: ^uint32(0), done: make(chan struct{})}
if h, err := createKillOnCloseJob(); err != nil {
log.Warnf("create job object for vnc-agent (orphan agents possible after crash): %v", err)
} else {
m.jobHandle = h
}
return m
}
// createKillOnCloseJob returns a Job Object configured so that closing its
// handle (process exit or explicit Close) terminates every process assigned
// to it. Used to keep orphaned vnc-agent processes from outliving the service.
func createKillOnCloseJob() (windows.Handle, error) {
r, _, e := procCreateJobObjectW.Call(0, 0)
if r == 0 {
return 0, fmt.Errorf("CreateJobObject: %w", e)
}
job := windows.Handle(r)
// JOBOBJECT_EXTENDED_LIMIT_INFORMATION on amd64 = 144 bytes.
//
// JOBOBJECT_BASIC_LIMIT_INFORMATION (64 bytes with alignment padding)
// PerProcessUserTimeLimit LARGE_INTEGER off 0
// PerJobUserTimeLimit LARGE_INTEGER off 8
// LimitFlags DWORD off 16
// [4 byte pad to align SIZE_T]
// MinimumWorkingSetSize SIZE_T off 24
// MaximumWorkingSetSize SIZE_T off 32
// ActiveProcessLimit DWORD off 40
// [4 byte pad to align ULONG_PTR]
// Affinity ULONG_PTR off 48
// PriorityClass DWORD off 56
// SchedulingClass DWORD off 60
// IO_COUNTERS (48) + 4 * SIZE_T (32) = 144 total.
//
// We only set LimitFlags; the rest stays zero.
const sizeofExtended = 144
const offsetLimitFlags = 16
const jobObjectExtendedLimitInformation = 9
const jobObjectLimitKillOnJobClose = 0x00002000
var info [sizeofExtended]byte
binary.LittleEndian.PutUint32(info[offsetLimitFlags:offsetLimitFlags+4], jobObjectLimitKillOnJobClose)
r, _, e = procSetInformationJobObject.Call(
uintptr(job),
uintptr(jobObjectExtendedLimitInformation),
uintptr(unsafe.Pointer(&info[0])),
uintptr(sizeofExtended),
)
if r == 0 {
_ = windows.CloseHandle(job)
return 0, fmt.Errorf("SetInformationJobObject(KILL_ON_JOB_CLOSE): %w", e)
}
return job, nil
}
// Resolve returns the current agent socket path, shared token, and the
// uid the agent runs under (0 on Windows since the agent runs as
// SYSTEM in the interactive session; see validateAgentPeer for the
// Windows trust model). The path is only published after the spawned
// agent is confirmed listening, so a caller never receives a socket a
// squatter could be holding. When no agent is spawned yet (initial
// boot, between session switches, or permanently disabled when
// SE_TCB_NAME is missing) it surfaces a distinct error so the daemon
// can reject the connection with a meaningful message instead of timing
// out the proxy dial.
func (m *sessionManager) Resolve(_ context.Context) (string, string, uint32, error) {
m.mu.Lock()
defer m.mu.Unlock()
if m.socketPath == "" {
return "", "", 0, errAgentNotReady
}
return m.socketPath, m.authToken, 0, nil
}
var errAgentNotReady = errors.New("VNC agent not running yet")
// Stop signals the session manager to exit its polling loop and closes the
// Job Object handle, which Windows uses as the trigger to terminate every
// agent process this manager spawned.
func (m *sessionManager) Stop() {
select {
case <-m.done:
default:
close(m.done)
}
m.mu.Lock()
if m.jobHandle != 0 {
_ = windows.CloseHandle(m.jobHandle)
m.jobHandle = 0
}
m.mu.Unlock()
}
func (m *sessionManager) run() {
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
for {
if !m.tick() {
return
}
select {
case <-m.done:
m.mu.Lock()
m.killAgent()
m.mu.Unlock()
return
case <-ticker.C:
}
}
}
// tick performs one session/agent-state update. Returns false if the manager
// should permanently stop (e.g. missing SYSTEM privileges).
func (m *sessionManager) tick() bool {
sid := getActiveSessionID()
m.mu.Lock()
defer m.mu.Unlock()
m.handleSessionChange(sid)
m.reapExitedAgent()
return m.maybeSpawnAgent(sid)
}
func (m *sessionManager) handleSessionChange(sid uint32) {
if sid == m.sessionID {
return
}
log.Infof("active session changed: %d -> %d", m.sessionID, sid)
m.killAgent()
m.sessionID = sid
}
func (m *sessionManager) reapExitedAgent() {
if m.agentProc == 0 {
return
}
var code uint32
if err := windows.GetExitCodeProcess(m.agentProc, &code); err != nil {
log.Debugf("GetExitCodeProcess: %v", err)
return
}
if code == stillActive {
return
}
m.scheduleNextSpawn(code, time.Since(m.agentStartedAt))
if err := windows.CloseHandle(m.agentProc); err != nil {
log.Debugf("close agent handle: %v", err)
}
m.agentProc = 0
m.authToken = ""
m.socketPath = ""
}
// scheduleNextSpawn applies an exponential backoff on fast crashes (<5s) and
// resets immediately otherwise.
func (m *sessionManager) scheduleNextSpawn(exitCode uint32, lifetime time.Duration) {
if lifetime < 5*time.Second {
m.spawnFailures++
backoff := time.Duration(1<<min(m.spawnFailures, 5)) * time.Second
if backoff > 30*time.Second {
backoff = 30 * time.Second
}
m.nextSpawnAt = time.Now().Add(backoff)
log.Warnf("agent exited (code=%d) after %v, retrying in %v (failures=%d)", exitCode, lifetime.Round(time.Millisecond), backoff, m.spawnFailures)
return
}
m.spawnFailures = 0
m.nextSpawnAt = time.Time{}
log.Infof("agent exited (code=%d) after %v, respawning", exitCode, lifetime.Round(time.Second))
}
// maybeSpawnAgent spawns a new agent if there's no current one and the backoff
// window has elapsed. Returns false to permanently stop the manager when the
// service lacks the privileges needed to spawn cross-session.
func (m *sessionManager) maybeSpawnAgent(sid uint32) bool {
if m.agentProc != 0 || sid == 0xFFFFFFFF || !time.Now().After(m.nextSpawnAt) {
return true
}
if err := ensureAgentSocketDir(); err != nil {
log.Warnf("prepare agent socket dir: %v", err)
m.nextSpawnAt = time.Now().Add(5 * time.Second)
return true
}
// The leaf name carries a cryptographically random component so a local
// user cannot pre-create the path at a guessable location. The session
// id is kept for diagnostics only; security does not rely on it.
socketPath, err := newAgentSocketPath(sid)
if err != nil {
log.Warnf("generate agent socket path: %v", err)
return true
}
// Covers a previous-run crash that escaped Job Object kill-on-close.
if err := os.Remove(socketPath); err != nil && !os.IsNotExist(err) {
log.Debugf("clear stale agent socket %s: %v", socketPath, err)
}
token, err := generateAuthToken()
if err != nil {
log.Warnf("generate agent auth token: %v", err)
return true
}
h, err := spawnAgentInSession(sid, socketPath, token, m.jobHandle)
if err != nil {
if errors.Is(err, windows.ERROR_PRIVILEGE_NOT_HELD) {
// SE_TCB_NAME (token-impersonation across sessions) is only
// granted to SYSTEM. Without it spawnAgent will fail every 2
// seconds forever: log once and give up.
log.Warnf("VNC service mode disabled: agent spawn requires SYSTEM privileges (got: %v)", err)
return false
}
log.Warnf("spawn agent in session %d: %v", sid, err)
return true
}
// Gate on listen-readiness before publishing the path: do not hand a
// caller a socket the agent has not bound yet. On timeout, fail closed
// by killing the agent and leaving socketPath/authToken unset so
// Resolve keeps returning errAgentNotReady.
if err := waitForAgentListening(socketPath, agentReadyTimeout); err != nil {
log.Warnf("agent in session %d did not start listening: %v", sid, err)
_ = windows.TerminateProcess(h, 1)
_ = windows.CloseHandle(h)
if rmErr := os.Remove(socketPath); rmErr != nil && !os.IsNotExist(rmErr) {
log.Debugf("clear unready agent socket %s: %v", socketPath, rmErr)
}
m.scheduleNextSpawn(0, 0)
return true
}
m.authToken = token
m.socketPath = socketPath
m.agentProc = h
m.agentStartedAt = time.Now()
m.everSpawned = true
return true
}
// ensureAgentSocketDir creates the dedicated socket directory with a
// restrictive DACL (SYSTEM + Administrators only). A pre-existing directory
// is torn down and recreated rather than reused: it may have been created by
// an unprivileged user with a permissive ACL, and it only ever holds our
// transient sockets, so removing it loses nothing. Fails closed: returns an
// error if the directory cannot be created with the intended security.
func ensureAgentSocketDir() error {
sd, err := windows.SecurityDescriptorFromString(agentSocketDirSDDL)
if err != nil {
return fmt.Errorf("parse socket dir SDDL: %w", err)
}
var sa windows.SecurityAttributes
sa.Length = uint32(unsafe.Sizeof(sa))
sa.SecurityDescriptor = sd
dirW, err := windows.UTF16PtrFromString(agentSocketDir)
if err != nil {
return fmt.Errorf("encode socket dir path: %w", err)
}
err = windows.CreateDirectory(dirW, &sa)
if errors.Is(err, windows.ERROR_ALREADY_EXISTS) {
if rmErr := os.RemoveAll(agentSocketDir); rmErr != nil {
return fmt.Errorf("remove pre-existing socket dir %s: %w", agentSocketDir, rmErr)
}
err = windows.CreateDirectory(dirW, &sa)
}
if err != nil {
return fmt.Errorf("create socket dir %s: %w", agentSocketDir, err)
}
return nil
}
// newAgentSocketPath returns a per-spawn socket path inside the secured
// socket directory. The leaf name mixes a cryptographically random component
// with the session id (for diagnostics) so the path is unguessable before the
// agent binds it.
func newAgentSocketPath(sessionID uint32) (string, error) {
b := make([]byte, agentSocketRandomLen)
if _, err := crand.Read(b); err != nil {
return "", fmt.Errorf("read random: %w", err)
}
name := fmt.Sprintf("netbird-vnc-%d-%s.sock", sessionID, hex.EncodeToString(b))
return filepath.Join(agentSocketDir, name), nil
}
// waitForAgentListening dials the agent's Unix socket until it answers or the
// timeout elapses. Mirrors the darwin readiness gate so the daemon never
// exposes a socket path before the legitimate agent owns it.
func waitForAgentListening(socketPath string, wait time.Duration) error {
var d net.Dialer
deadline := time.Now().Add(wait)
var lastErr error
for time.Now().Before(deadline) {
c, err := d.Dial("unix", socketPath)
if err == nil {
_ = c.Close()
return nil
}
lastErr = err
time.Sleep(100 * time.Millisecond)
}
if lastErr == nil {
lastErr = fmt.Errorf("timeout")
}
return fmt.Errorf("dial %s: %w", socketPath, lastErr)
}
func (m *sessionManager) killAgent() {
if m.agentProc == 0 {
return
}
_ = windows.TerminateProcess(m.agentProc, 0)
_ = windows.CloseHandle(m.agentProc)
m.agentProc = 0
m.authToken = ""
m.socketPath = ""
log.Info("killed old agent")
}
// relogAgentOutput reads log lines from the agent's stderr pipe and
// relogs them with the service's formatter. The *os.File owns the
// underlying handle, so closing it suffices.
func relogAgentOutput(pipe windows.Handle) {
f := os.NewFile(uintptr(pipe), "vnc-agent-stderr")
defer func() { _ = f.Close() }()
relogAgentStream(f)
}
// logCleanupCall invokes a Windows syscall used solely as a cleanup primitive
// (CloseClipboard, ReleaseDC, etc.) and logs failures at trace level. The
// indirection lets us satisfy errcheck without scattering ignored returns at
// each call site, while still capturing diagnostic info when the OS reports
// a failure.
func logCleanupCall(name string, proc *windows.LazyProc) {
r, _, err := proc.Call()
if r == 0 && err != nil && err != windows.NTE_OP_OK {
log.Tracef("%s: %v", name, err)
}
}
// logCleanupCallArgs is logCleanupCall with one argument; common pattern for
// release-by-handle syscalls.
func logCleanupCallArgs(name string, proc *windows.LazyProc, args ...uintptr) {
r, _, err := proc.Call(args...)
if r == 0 && err != nil && err != windows.NTE_OP_OK {
log.Tracef("%s: %v", name, err)
}
}

Some files were not shown because too many files have changed in this diff Show More