mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-22 15:59:59 +00:00
Compare commits
3 Commits
embedded-v
...
socket-grp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d65927275d | ||
|
|
064f7bf0fd | ||
|
|
644615fed6 |
@@ -3,14 +3,12 @@ package cmd
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os/user"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
@@ -87,73 +85,6 @@ var persistenceCmd = &cobra.Command{
|
||||
RunE: setSyncResponsePersistence,
|
||||
}
|
||||
|
||||
var debugConfigCmd = &cobra.Command{
|
||||
Use: "config",
|
||||
Example: " netbird debug config",
|
||||
Short: "Dump the effective configuration",
|
||||
Long: "Prints the daemon's resolved configuration (after applying defaults, file, env, CLI input, and MDM policy overrides) as JSON. Includes the list of MDM-managed fields.",
|
||||
RunE: debugConfigDump,
|
||||
}
|
||||
|
||||
// debugConfigDump implements `netbird debug config`. It resolves the
|
||||
// active profile, queries the daemon for the effective configuration
|
||||
// via GetConfig, and prints the resulting GetConfigResponse as JSON
|
||||
// (via protojson with EmitUnpopulated=true so the output is stable
|
||||
// across runs and includes zero-valued fields).
|
||||
//
|
||||
// Useful for verifying MDM enforcement end-to-end: the response's
|
||||
// mDMManagedFields array is the single source of truth for "which
|
||||
// fields is the daemon currently enforcing from the MDM source", and
|
||||
// every config field side-by-side with that list confirms the merge
|
||||
// result. Secrets in the response (e.g. PreSharedKey) are already
|
||||
// redacted by the daemon-side handler.
|
||||
func debugConfigDump(cmd *cobra.Command, _ []string) error {
|
||||
pm := profilemanager.NewProfileManager()
|
||||
activeProf, err := pm.GetActiveProfile()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get active profile: %v", err)
|
||||
}
|
||||
currUser, err := user.Current()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get current user: %v", err)
|
||||
}
|
||||
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf(errCloseConnection, err)
|
||||
}
|
||||
}()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
resp, err := client.GetConfig(cmd.Context(), &proto.GetConfigRequest{
|
||||
ProfileName: activeProf.Name,
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get config: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
// Use protojson so well-known fields render correctly; emit defaults so
|
||||
// the operator sees every field even when zero/empty.
|
||||
m := protojson.MarshalOptions{Multiline: true, Indent: " ", EmitUnpopulated: true}
|
||||
out, err := m.Marshal(resp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal config: %w", err)
|
||||
}
|
||||
cmd.Println(string(out))
|
||||
return nil
|
||||
}
|
||||
|
||||
// debugBundle requests the daemon to create a debug bundle and prints
|
||||
// the resulting local file path and, if uploaded, the uploaded file
|
||||
// key. It uses the package flags (anonymize, system info, log file
|
||||
// count, CLI version, optional upload URL) to configure the bundle
|
||||
// request. Returns an error if the RPC fails or if the daemon reports
|
||||
// an upload failure reason.
|
||||
func debugBundle(cmd *cobra.Command, _ []string) error {
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,301 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
KubernetesDNSSuffix = "netbird-kubeapi-proxy"
|
||||
)
|
||||
|
||||
var kubernetesCmd = &cobra.Command{
|
||||
Use: "kubernetes",
|
||||
Short: "Kubernetes cluster commands.",
|
||||
Long: "Kubernetes cluster commands.",
|
||||
}
|
||||
|
||||
var kubernetesListCmd = &cobra.Command{
|
||||
Use: "list",
|
||||
RunE: kubernetesList,
|
||||
Short: "List Kubernetes clusters.",
|
||||
Long: "List Kubernetes clusters by discovering NetBird peers running netbird-kubeapi-proxy.",
|
||||
}
|
||||
|
||||
var kubernetesWriteKubeconfigCmd = &cobra.Command{
|
||||
Use: "write-kubeconfig",
|
||||
RunE: kubernetesWriteKubeconfig,
|
||||
Args: cobra.ExactArgs(1),
|
||||
Short: "Write kubeconfig for a Kubernetes cluster.",
|
||||
Long: "Updates kubeconfig in place to allow token-less access to the Kubernetes cluster through NetBird.",
|
||||
}
|
||||
|
||||
func init() {
|
||||
kubernetesWriteKubeconfigCmd.Flags().String("kubeconfig", "", "path to kubeconfig file")
|
||||
}
|
||||
|
||||
func kubernetesList(cmd *cobra.Command, _ []string) error {
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
statusResp, err := client.Status(cmd.Context(), &proto.StatusRequest{GetFullPeerStatus: true})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
kcs, err := getKubernetesClusters(cmd.Context(), statusResp.FullStatus.Peers, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(kcs) == 0 {
|
||||
cmd.Println("No Kubernetes clusters available.")
|
||||
return nil
|
||||
}
|
||||
cmd.Println("Available Kubernetes clusters:")
|
||||
for _, k := range kcs {
|
||||
cmd.Printf("\n - Name: %s\n FQDN: %s\n Version: %s\n", k.name, k.url.Host, k.version)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func kubernetesWriteKubeconfig(cmd *cobra.Command, args []string) error {
|
||||
kubeconfigPath, err := resolveKubeconfigPath(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
statusResp, err := client.Status(cmd.Context(), &proto.StatusRequest{GetFullPeerStatus: true})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
clusterName := args[0]
|
||||
kcs, err := getKubernetesClusters(cmd.Context(), statusResp.FullStatus.Peers, clusterName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(kcs) == 0 {
|
||||
return fmt.Errorf("kubernetes cluster named %s not found", clusterName)
|
||||
}
|
||||
if len(kcs) > 1 {
|
||||
return fmt.Errorf("too many Kubernetes clusters returned")
|
||||
}
|
||||
err = writeKubeconfig(kubeconfigPath, kcs[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type kubernetesCluster struct {
|
||||
name string
|
||||
url *url.URL
|
||||
version string
|
||||
}
|
||||
|
||||
func getKubernetesClusters(ctx context.Context, peers []*proto.PeerState, nameFilter string) ([]kubernetesCluster, error) {
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
transport.TLSClientConfig = &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
httpClient := &http.Client{
|
||||
Transport: transport,
|
||||
}
|
||||
resolver := net.Resolver{
|
||||
// Required so both DNS records are returned.
|
||||
// https://github.com/golang/go/issues/17093
|
||||
PreferGo: true,
|
||||
}
|
||||
|
||||
kcs := []kubernetesCluster{}
|
||||
attempted := map[string]struct{}{}
|
||||
for _, peer := range peers {
|
||||
fqdns, err := resolver.LookupAddr(ctx, peer.IP)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, fqdn := range fqdns {
|
||||
if _, ok := attempted[fqdn]; ok {
|
||||
continue
|
||||
}
|
||||
attempted[fqdn] = struct{}{}
|
||||
comps := strings.Split(fqdn, ".")
|
||||
if len(comps) < 2 {
|
||||
continue
|
||||
}
|
||||
if comps[1] != KubernetesDNSSuffix {
|
||||
continue
|
||||
}
|
||||
if nameFilter != "" && nameFilter != comps[0] {
|
||||
continue
|
||||
}
|
||||
clusterURL, clusterVersion, err := fingerprintClusters(ctx, httpClient, fqdn)
|
||||
if err != nil {
|
||||
log.Debugf("could not fingerprint Kubernetes cluster %s %q", fqdn, err)
|
||||
continue
|
||||
}
|
||||
kc := kubernetesCluster{
|
||||
name: comps[0],
|
||||
url: clusterURL,
|
||||
version: clusterVersion,
|
||||
}
|
||||
if nameFilter != "" {
|
||||
return []kubernetesCluster{kc}, nil
|
||||
}
|
||||
kcs = append(kcs, kc)
|
||||
}
|
||||
}
|
||||
return kcs, nil
|
||||
}
|
||||
|
||||
func fingerprintClusters(ctx context.Context, httpClient *http.Client, fqdn string) (*url.URL, string, error) {
|
||||
clusterURL, err := url.Parse("https://" + fqdn)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
versionURL, err := clusterURL.Parse("/version")
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, versionURL.String(), nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, "", fmt.Errorf("expected %d response but got %s", http.StatusOK, resp.Status)
|
||||
}
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
versionData := map[string]string{}
|
||||
err = json.Unmarshal(b, &versionData)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
version, ok := versionData["gitVersion"]
|
||||
if !ok {
|
||||
return nil, "", errors.New("no version found in response")
|
||||
}
|
||||
return clusterURL, version, nil
|
||||
}
|
||||
|
||||
func resolveKubeconfigPath(cmd *cobra.Command) (string, error) {
|
||||
if cmd.Flags().Changed("kubeconfig") {
|
||||
path, err := cmd.Flags().GetString("kubeconfig")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return path, nil
|
||||
}
|
||||
if env := os.Getenv("KUBECONFIG"); env != "" {
|
||||
return env, nil
|
||||
}
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("could not determine home directory: %w", err)
|
||||
}
|
||||
return filepath.Join(home, ".kube", "config"), nil
|
||||
}
|
||||
|
||||
func writeKubeconfig(kubeconfigPath string, kc kubernetesCluster) error {
|
||||
b, err := os.ReadFile(kubeconfigPath)
|
||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return err
|
||||
}
|
||||
var cfg map[string]any
|
||||
if err := yaml.Unmarshal(b, &cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
if cfg == nil {
|
||||
cfg = map[string]any{
|
||||
"apiVersion": "v1",
|
||||
"kind": "Config",
|
||||
}
|
||||
}
|
||||
|
||||
cfg["clusters"] = appendWithName(cfg["clusters"], map[string]any{
|
||||
"name": kc.name,
|
||||
"cluster": map[string]any{
|
||||
"server": kc.url.String(),
|
||||
"insecure-skip-tls-verify": true,
|
||||
},
|
||||
})
|
||||
cfg["users"] = appendWithName(cfg["users"], map[string]any{
|
||||
"name": "netbird",
|
||||
"user": map[string]any{
|
||||
"token": "none",
|
||||
},
|
||||
})
|
||||
cfg["contexts"] = appendWithName(cfg["contexts"], map[string]any{
|
||||
"name": kc.name,
|
||||
"context": map[string]any{
|
||||
"cluster": kc.name,
|
||||
"user": "netbird",
|
||||
"namespace": "default",
|
||||
},
|
||||
})
|
||||
cfg["current-context"] = kc.name
|
||||
|
||||
out, err := yaml.Marshal(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.WriteFile(kubeconfigPath, out, 0o600); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func appendWithName(data any, add map[string]any) any {
|
||||
if data == nil {
|
||||
return []any{add}
|
||||
}
|
||||
v, ok := data.([]any)
|
||||
if !ok {
|
||||
return []any{add}
|
||||
}
|
||||
i := slices.IndexFunc(v, func(item any) bool {
|
||||
m, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return m["name"] == add["name"]
|
||||
})
|
||||
if i == -1 {
|
||||
return append(v, add)
|
||||
}
|
||||
v[i] = add
|
||||
return v
|
||||
}
|
||||
@@ -1,120 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFingerprintClusters(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
//nolint: errcheck
|
||||
w.Write([]byte(`{"gitVersion": "foobar"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
clusterURL, clusterVersion, err := fingerprintClusters(t.Context(), srv.Client(), srv.Listener.Addr().String())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, srv.URL, clusterURL.String())
|
||||
require.Equal(t, "foobar", clusterVersion)
|
||||
}
|
||||
|
||||
func TestResolveKubeconfigPath(t *testing.T) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
t.Fatalf("could not determine home directory: %v", err)
|
||||
}
|
||||
defaultPath := filepath.Join(home, ".kube", "config")
|
||||
path, err := resolveKubeconfigPath(&cobra.Command{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, defaultPath, path)
|
||||
|
||||
flagPath := "flag-path"
|
||||
cmd := &cobra.Command{}
|
||||
cmd.Flags().String("kubeconfig", "", "")
|
||||
err = cmd.Flags().Set("kubeconfig", flagPath)
|
||||
require.NoError(t, err)
|
||||
path, err = resolveKubeconfigPath(cmd)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, flagPath, path)
|
||||
|
||||
envPath := "env-path"
|
||||
t.Setenv("KUBECONFIG", envPath)
|
||||
path, err = resolveKubeconfigPath(&cobra.Command{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, envPath, path)
|
||||
}
|
||||
|
||||
func TestWriteKubeconfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
existing string
|
||||
}{
|
||||
{
|
||||
name: "empty file",
|
||||
},
|
||||
{
|
||||
name: "existing content",
|
||||
existing: `apiVersion: v1
|
||||
clusters:
|
||||
- cluster:
|
||||
insecure-skip-tls-verify: true
|
||||
server: https://foobar.com
|
||||
name: foo
|
||||
current-context: test
|
||||
kind: Config
|
||||
users: []
|
||||
`,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
kubeconfigPath := filepath.Join(t.TempDir(), "config")
|
||||
err := os.WriteFile(kubeconfigPath, []byte(tt.existing), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
kc := kubernetesCluster{
|
||||
name: "foo",
|
||||
url: &url.URL{Scheme: "https", Host: "example.com"},
|
||||
}
|
||||
err = writeKubeconfig(kubeconfigPath, kc)
|
||||
require.NoError(t, err)
|
||||
|
||||
b, err := os.ReadFile(kubeconfigPath)
|
||||
require.NoError(t, err)
|
||||
expected := `apiVersion: v1
|
||||
clusters:
|
||||
- cluster:
|
||||
insecure-skip-tls-verify: true
|
||||
server: https://example.com
|
||||
name: foo
|
||||
contexts:
|
||||
- context:
|
||||
cluster: foo
|
||||
namespace: default
|
||||
user: netbird
|
||||
name: foo
|
||||
current-context: foo
|
||||
kind: Config
|
||||
users:
|
||||
- name: netbird
|
||||
user:
|
||||
token: none
|
||||
`
|
||||
require.Equal(t, expected, string(b))
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
36
client/cmd/peercred_bsd.go
Normal file
36
client/cmd/peercred_bsd.go
Normal file
@@ -0,0 +1,36 @@
|
||||
//go:build darwin || freebsd
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// peerUID returns the uid of the process on the other end of a unix socket
|
||||
// connection, read via LOCAL_PEERCRED (xucred). Note: xucred carries the uid
|
||||
// and group list but no pid, so audit on these platforms is uid-based.
|
||||
func peerUID(c net.Conn) (int, error) {
|
||||
uc, ok := c.(*net.UnixConn)
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("connection is not a unix socket: %T", c)
|
||||
}
|
||||
raw, err := uc.SyscallConn()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("raw conn: %w", err)
|
||||
}
|
||||
|
||||
var cred *unix.Xucred
|
||||
var credErr error
|
||||
if err := raw.Control(func(fd uintptr) {
|
||||
cred, credErr = unix.GetsockoptXucred(int(fd), unix.SOL_LOCAL, unix.LOCAL_PEERCRED)
|
||||
}); err != nil {
|
||||
return 0, fmt.Errorf("getsockopt control: %w", err)
|
||||
}
|
||||
if credErr != nil {
|
||||
return 0, fmt.Errorf("LOCAL_PEERCRED: %w", credErr)
|
||||
}
|
||||
return int(cred.Uid), nil
|
||||
}
|
||||
35
client/cmd/peercred_linux.go
Normal file
35
client/cmd/peercred_linux.go
Normal file
@@ -0,0 +1,35 @@
|
||||
//go:build linux
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// peerUID returns the uid of the process on the other end of a unix socket
|
||||
// connection, read from the kernel via SO_PEERCRED.
|
||||
func peerUID(c net.Conn) (int, error) {
|
||||
uc, ok := c.(*net.UnixConn)
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("connection is not a unix socket: %T", c)
|
||||
}
|
||||
raw, err := uc.SyscallConn()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("raw conn: %w", err)
|
||||
}
|
||||
|
||||
var cred *unix.Ucred
|
||||
var credErr error
|
||||
if err := raw.Control(func(fd uintptr) {
|
||||
cred, credErr = unix.GetsockoptUcred(int(fd), unix.SOL_SOCKET, unix.SO_PEERCRED)
|
||||
}); err != nil {
|
||||
return 0, fmt.Errorf("getsockopt control: %w", err)
|
||||
}
|
||||
if credErr != nil {
|
||||
return 0, fmt.Errorf("SO_PEERCRED: %w", credErr)
|
||||
}
|
||||
return int(cred.Uid), nil
|
||||
}
|
||||
16
client/cmd/peercred_unsupported.go
Normal file
16
client/cmd/peercred_unsupported.go
Normal file
@@ -0,0 +1,16 @@
|
||||
//go:build !linux && !darwin && !freebsd
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// peerUID is unimplemented on this platform, so the trust-on-first-use socket
|
||||
// migration cannot run here. Configure --socket-owner explicitly, or use
|
||||
// --disable-strict-socket. (Windows uses a TCP socket and never reaches this.)
|
||||
func peerUID(net.Conn) (int, error) {
|
||||
return 0, fmt.Errorf("peer credential check not supported on %s", runtime.GOOS)
|
||||
}
|
||||
@@ -77,6 +77,8 @@ var (
|
||||
updateSettingsDisabled bool
|
||||
captureEnabled bool
|
||||
networksDisabled bool
|
||||
socketOwner string
|
||||
strictSocketDisabled bool
|
||||
|
||||
rootCmd = &cobra.Command{
|
||||
Use: "netbird",
|
||||
@@ -95,9 +97,7 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
// Execute runs the appropriate Cobra command for the CLI.
|
||||
// If the process is the update binary it delegates to updateCmd; otherwise it runs the root command.
|
||||
// It returns any error produced during command execution.
|
||||
// Execute executes the root command.
|
||||
func Execute() error {
|
||||
if isUpdateBinary() {
|
||||
return updateCmd.Execute()
|
||||
@@ -105,16 +105,6 @@ func Execute() error {
|
||||
return rootCmd.Execute()
|
||||
}
|
||||
|
||||
// init initialises package-level defaults and configures the root
|
||||
// Cobra command tree. Sets platform-specific config / log directory
|
||||
// paths (including legacy Wiretrustee fallbacks) and a default daemon
|
||||
// address; registers persistent CLI flags (daemon address,
|
||||
// management / admin URLs, logging, setup key (file and inline,
|
||||
// mutually exclusive), preshared key, hostname, anonymise, config
|
||||
// path); attaches top-level and nested subcommands to the root
|
||||
// command; and registers `up`-specific persistent flags (external IP
|
||||
// maps, custom DNS resolver address, Rosenpass options, auto-connect
|
||||
// disabling, lazy connection).
|
||||
func init() {
|
||||
defaultConfigPathDir = "/etc/netbird/"
|
||||
defaultLogFileDir = "/var/log/netbird/"
|
||||
@@ -180,12 +170,6 @@ func init() {
|
||||
logCmd.AddCommand(logLevelCmd)
|
||||
debugCmd.AddCommand(forCmd)
|
||||
debugCmd.AddCommand(persistenceCmd)
|
||||
debugCmd.AddCommand(debugConfigCmd)
|
||||
|
||||
// kubernetes commands
|
||||
rootCmd.AddCommand(kubernetesCmd)
|
||||
kubernetesCmd.AddCommand(kubernetesListCmd)
|
||||
kubernetesCmd.AddCommand(kubernetesWriteKubeconfigCmd)
|
||||
|
||||
// profile commands
|
||||
profileCmd.AddCommand(profileListCmd)
|
||||
|
||||
@@ -57,6 +57,9 @@ func init() {
|
||||
installCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)
|
||||
reconfigureCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)
|
||||
|
||||
serviceCmd.PersistentFlags().StringVar(&socketOwner, "socket-owner", "", "user to own the daemon control socket; restricts it to that user plus the netbird group (0660). If unset, the first client to connect claims ownership (trust-on-first-use)")
|
||||
serviceCmd.PersistentFlags().BoolVar(&strictSocketDisabled, "disable-strict-socket", false, "leave the daemon control socket world-writable (0666) instead of restricting it; set via the (root-only) service command")
|
||||
|
||||
rootCmd.AddCommand(serviceCmd)
|
||||
}
|
||||
|
||||
|
||||
@@ -4,10 +4,15 @@ package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
@@ -16,6 +21,7 @@ import (
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/shell"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/server"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
@@ -54,10 +60,36 @@ func (p *program) Start(svc service.Service) error {
|
||||
go func() {
|
||||
defer listen.Close()
|
||||
|
||||
srvListener := listen
|
||||
if split[0] == "unix" {
|
||||
if err := os.Chmod(split[1], 0666); err != nil {
|
||||
log.Errorf("failed setting daemon permissions: %v", split[1])
|
||||
return
|
||||
owner := effectiveSocketOwner()
|
||||
switch {
|
||||
case strictSocketDisabled:
|
||||
// Opt-out (root-only, via service.json): leave it world-writable.
|
||||
if err := os.Chmod(split[1], 0666); err != nil {
|
||||
log.Errorf("failed setting daemon permissions: %v", split[1])
|
||||
return
|
||||
}
|
||||
case owner != "":
|
||||
// Seeded owner (flag, MDM, or persisted TOFU result): restrict
|
||||
// before serving so there is no open window.
|
||||
uid, err := lookupUser(owner)
|
||||
if err != nil {
|
||||
log.Errorf("lookup socket owner %q: %v", owner, err)
|
||||
return
|
||||
}
|
||||
if err := restrictSocket(split[1], uid); err != nil {
|
||||
log.Errorf("restrict socket to %q: %v", owner, err)
|
||||
return
|
||||
}
|
||||
default:
|
||||
// Trust-on-first-use: open the socket now; tofuListener locks it
|
||||
// to the first caller's uid on the first connection.
|
||||
if err := os.Chmod(split[1], 0666); err != nil {
|
||||
log.Errorf("failed setting daemon permissions: %v", split[1])
|
||||
return
|
||||
}
|
||||
srvListener = &tofuListener{Listener: listen, path: split[1], owner: -1}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,13 +104,180 @@ func (p *program) Start(svc service.Service) error {
|
||||
p.serverInstanceMu.Unlock()
|
||||
|
||||
log.Printf("started daemon server: %v", split[1])
|
||||
if err := p.serv.Serve(listen); err != nil {
|
||||
if err := p.serv.Serve(srvListener); err != nil {
|
||||
log.Errorf("failed to serve daemon requests: %v", err)
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func lookupUser(username string) (int, error) {
|
||||
u, err := shell.LookupWithGetent(username)
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("lookup user %s: %w", username, err)
|
||||
}
|
||||
uid, err := strconv.Atoi(u.Uid)
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("parse uid %s: %w", u.Uid, err)
|
||||
}
|
||||
return uid, nil
|
||||
}
|
||||
|
||||
// addGroup creates a system group if it doesn't already exist and returns the gid.
|
||||
// Must run as root.
|
||||
func addGroup(name string) (int, error) {
|
||||
group, err := shell.LookupGroupWithGetent(name)
|
||||
if err == nil {
|
||||
gid, err := strconv.ParseInt(group.Gid, 10, 64)
|
||||
return int(gid), err
|
||||
}
|
||||
|
||||
// looup failed, create the group
|
||||
groupadd, err := exec.LookPath("groupadd")
|
||||
if err != nil {
|
||||
// Fallback for Alpine/BusyBox systems.
|
||||
if groupadd, err = exec.LookPath("addgroup"); err != nil {
|
||||
return -1, errors.New("neither groupadd nor addgroup found")
|
||||
}
|
||||
}
|
||||
|
||||
// Use --system for a service/daemon group (no login, low GID).
|
||||
out, err := exec.Command(groupadd, "--system", name).CombinedOutput()
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("create group %q: %w: %s", name, err, out)
|
||||
}
|
||||
if group, err := shell.LookupWithGetent(name); err == nil {
|
||||
gid, err := strconv.ParseInt(group.Gid, 10, 64)
|
||||
return int(gid), err
|
||||
}
|
||||
return -1, fmt.Errorf("lookup group %q: %w", name, err)
|
||||
}
|
||||
|
||||
// restrictSocket locks the unix socket down to the owner uid plus the netbird
|
||||
// group (0660). If the group cannot be created or applied, it fails closed to
|
||||
// owner-only 0600 — it never leaves the socket world-writable.
|
||||
func restrictSocket(path string, uid int) error {
|
||||
// TODO: introduce flag to configure this (LDAP/AD usecase)
|
||||
gid, err := addGroup("netbird")
|
||||
if err != nil {
|
||||
log.Errorf("create netbird group, failing closed to owner-only 0600: %v", err)
|
||||
return chownChmod(path, uid, -1, 0600)
|
||||
}
|
||||
if err := chownChmod(path, uid, gid, 0660); err != nil {
|
||||
log.Errorf("apply netbird group to socket, failing closed to owner-only 0600: %v", err)
|
||||
return chownChmod(path, uid, -1, 0600)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// chownChmod sets ownership and mode on the socket. A gid of -1 leaves the
|
||||
// group unchanged.
|
||||
func chownChmod(path string, uid, gid int, mode os.FileMode) error {
|
||||
if err := os.Chown(path, uid, gid); err != nil {
|
||||
return fmt.Errorf("chown socket %s: %w", path, err)
|
||||
}
|
||||
if err := os.Chmod(path, mode); err != nil {
|
||||
return fmt.Errorf("chmod socket %s: %w", path, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// tofuListener implements trust-on-first-use for the daemon control socket.
|
||||
// The socket starts world-writable; the first caller's uid (read via the
|
||||
// platform peer-credential mechanism) becomes the owner. On that first
|
||||
// connection the socket is restricted (see restrictSocket) and the owner is
|
||||
// persisted so the open window never reopens on later starts. Connections that
|
||||
// raced in during the open window and are neither the owner nor root are
|
||||
// dropped. Changing the socket mode does not disturb the already-open
|
||||
// connection, so the first caller's request is served normally.
|
||||
type tofuListener struct {
|
||||
net.Listener
|
||||
path string
|
||||
mu sync.Mutex
|
||||
owner int // -1 until claimed
|
||||
}
|
||||
|
||||
func (l *tofuListener) Accept() (net.Conn, error) {
|
||||
for {
|
||||
c, err := l.Listener.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
uid, err := peerUID(c)
|
||||
if err != nil {
|
||||
log.Errorf("read peer credentials, dropping connection: %v", err)
|
||||
_ = c.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
l.mu.Lock()
|
||||
if l.owner == -1 {
|
||||
if err := restrictSocket(l.path, uid); err != nil {
|
||||
l.mu.Unlock()
|
||||
_ = c.Close()
|
||||
// Refuse to serve on a socket we could not lock down.
|
||||
return nil, fmt.Errorf("restrict socket on first connection: %w", err)
|
||||
}
|
||||
l.owner = uid
|
||||
persistSocketOwner(uid)
|
||||
log.Infof("control socket restricted to first caller (uid %d)", uid)
|
||||
l.mu.Unlock()
|
||||
return c, nil
|
||||
}
|
||||
owner := l.owner
|
||||
l.mu.Unlock()
|
||||
|
||||
// New connects are already gated by the 0660 perms set above; this only
|
||||
// drops anything that slipped in during the brief open window.
|
||||
if uid != owner && uid != 0 {
|
||||
log.Warnf("dropping non-owner connection (uid %d) during socket bootstrap", uid)
|
||||
_ = c.Close()
|
||||
continue
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
}
|
||||
|
||||
// effectiveSocketOwner returns the configured socket owner: the --socket-owner
|
||||
// flag when set, otherwise the owner persisted by a previous TOFU migration.
|
||||
func effectiveSocketOwner() string {
|
||||
if socketOwner != "" {
|
||||
return socketOwner
|
||||
}
|
||||
params, err := loadServiceParams()
|
||||
if err != nil {
|
||||
log.Errorf("load service params for socket owner: %v", err)
|
||||
return ""
|
||||
}
|
||||
if params != nil {
|
||||
return params.SocketOwner
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// persistSocketOwner records the TOFU-selected owner (by username) so the next
|
||||
// daemon start restricts the socket immediately, with no open window.
|
||||
func persistSocketOwner(uid int) {
|
||||
u, err := user.LookupId(strconv.Itoa(uid))
|
||||
if err != nil {
|
||||
log.Errorf("resolve uid %d to username for persistence: %v", uid, err)
|
||||
return
|
||||
}
|
||||
params, err := loadServiceParams()
|
||||
if err != nil {
|
||||
log.Errorf("load service params to persist socket owner: %v", err)
|
||||
return
|
||||
}
|
||||
if params == nil {
|
||||
params = currentServiceParams()
|
||||
}
|
||||
params.SocketOwner = u.Username
|
||||
if err := saveServiceParams(params); err != nil {
|
||||
log.Errorf("persist socket owner: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *program) Stop(srv service.Service) error {
|
||||
p.serverInstanceMu.Lock()
|
||||
if p.serverInstance != nil {
|
||||
|
||||
@@ -67,6 +67,14 @@ func buildServiceArguments() []string {
|
||||
args = append(args, "--disable-networks")
|
||||
}
|
||||
|
||||
if socketOwner != "" {
|
||||
args = append(args, "--socket-owner", socketOwner)
|
||||
}
|
||||
|
||||
if strictSocketDisabled {
|
||||
args = append(args, "--disable-strict-socket")
|
||||
}
|
||||
|
||||
return args
|
||||
}
|
||||
|
||||
@@ -127,6 +135,8 @@ var installCmd = &cobra.Command{
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Printf("SUDO_UID: %s\n", os.Getenv("SUDO_UID"))
|
||||
|
||||
if err := loadAndApplyServiceParams(cmd); err != nil {
|
||||
cmd.PrintErrf("Warning: failed to load saved service params: %v\n", err)
|
||||
}
|
||||
|
||||
@@ -30,6 +30,8 @@ type serviceParams struct {
|
||||
DisableUpdateSettings bool `json:"disable_update_settings,omitempty"`
|
||||
EnableCapture bool `json:"enable_capture,omitempty"`
|
||||
DisableNetworks bool `json:"disable_networks,omitempty"`
|
||||
SocketOwner string `json:"socket_owner,omitempty"`
|
||||
DisableStrictSocket bool `json:"disable_strict_socket,omitempty"`
|
||||
ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"`
|
||||
}
|
||||
|
||||
@@ -82,6 +84,8 @@ func currentServiceParams() *serviceParams {
|
||||
DisableUpdateSettings: updateSettingsDisabled,
|
||||
EnableCapture: captureEnabled,
|
||||
DisableNetworks: networksDisabled,
|
||||
SocketOwner: socketOwner,
|
||||
DisableStrictSocket: strictSocketDisabled,
|
||||
}
|
||||
|
||||
if len(serviceEnvVars) > 0 {
|
||||
@@ -154,6 +158,14 @@ func applyServiceParams(cmd *cobra.Command, params *serviceParams) {
|
||||
networksDisabled = params.DisableNetworks
|
||||
}
|
||||
|
||||
if !serviceCmd.PersistentFlags().Changed("socket-owner") {
|
||||
socketOwner = params.SocketOwner
|
||||
}
|
||||
|
||||
if !serviceCmd.PersistentFlags().Changed("disable-strict-socket") {
|
||||
strictSocketDisabled = params.DisableStrictSocket
|
||||
}
|
||||
|
||||
applyServiceEnvParams(cmd, params)
|
||||
}
|
||||
|
||||
|
||||
106
client/cmd/up.go
106
client/cmd/up.go
@@ -361,12 +361,6 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
|
||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||
req.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
if cmd.Flag(serverVNCAllowedFlag).Changed {
|
||||
req.ServerVNCAllowed = &serverVNCAllowed
|
||||
}
|
||||
if cmd.Flag(disableVNCApprovalFlag).Changed {
|
||||
req.DisableVNCApproval = &disableVNCApproval
|
||||
}
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
req.EnableSSHRoot = &enableSSHRoot
|
||||
}
|
||||
@@ -473,14 +467,30 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
|
||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||
ic.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
if cmd.Flag(serverVNCAllowedFlag).Changed {
|
||||
ic.ServerVNCAllowed = &serverVNCAllowed
|
||||
}
|
||||
if cmd.Flag(disableVNCApprovalFlag).Changed {
|
||||
ic.DisableVNCApproval = &disableVNCApproval
|
||||
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
ic.EnableSSHRoot = &enableSSHRoot
|
||||
}
|
||||
|
||||
applySSHFlagsToConfig(cmd, &ic)
|
||||
if cmd.Flag(enableSSHSFTPFlag).Changed {
|
||||
ic.EnableSSHSFTP = &enableSSHSFTP
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
|
||||
ic.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
|
||||
ic.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
|
||||
}
|
||||
|
||||
if cmd.Flag(disableSSHAuthFlag).Changed {
|
||||
ic.DisableSSHAuth = &disableSSHAuth
|
||||
}
|
||||
|
||||
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||
ic.SSHJWTCacheTTL = &sshJWTCacheTTL
|
||||
}
|
||||
|
||||
if cmd.Flag(interfaceNameFlag).Changed {
|
||||
if err := parseInterfaceName(interfaceName); err != nil {
|
||||
@@ -556,49 +566,6 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
|
||||
return &ic, nil
|
||||
}
|
||||
|
||||
func applySSHFlagsToConfig(cmd *cobra.Command, ic *profilemanager.ConfigInput) {
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
ic.EnableSSHRoot = &enableSSHRoot
|
||||
}
|
||||
if cmd.Flag(enableSSHSFTPFlag).Changed {
|
||||
ic.EnableSSHSFTP = &enableSSHSFTP
|
||||
}
|
||||
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
|
||||
ic.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
|
||||
}
|
||||
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
|
||||
ic.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
|
||||
}
|
||||
if cmd.Flag(disableSSHAuthFlag).Changed {
|
||||
ic.DisableSSHAuth = &disableSSHAuth
|
||||
}
|
||||
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||
ic.SSHJWTCacheTTL = &sshJWTCacheTTL
|
||||
}
|
||||
}
|
||||
|
||||
func applySSHFlagsToLogin(cmd *cobra.Command, req *proto.LoginRequest) {
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
req.EnableSSHRoot = &enableSSHRoot
|
||||
}
|
||||
if cmd.Flag(enableSSHSFTPFlag).Changed {
|
||||
req.EnableSSHSFTP = &enableSSHSFTP
|
||||
}
|
||||
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
|
||||
req.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
|
||||
}
|
||||
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
|
||||
req.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
|
||||
}
|
||||
if cmd.Flag(disableSSHAuthFlag).Changed {
|
||||
req.DisableSSHAuth = &disableSSHAuth
|
||||
}
|
||||
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||
ttl := int32(sshJWTCacheTTL)
|
||||
req.SshJWTCacheTTL = &ttl
|
||||
}
|
||||
}
|
||||
|
||||
func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte, cmd *cobra.Command) (*proto.LoginRequest, error) {
|
||||
loginRequest := proto.LoginRequest{
|
||||
SetupKey: providedSetupKey,
|
||||
@@ -628,14 +595,31 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
|
||||
if cmd.Flag(serverSSHAllowedFlag).Changed {
|
||||
loginRequest.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
if cmd.Flag(serverVNCAllowedFlag).Changed {
|
||||
loginRequest.ServerVNCAllowed = &serverVNCAllowed
|
||||
}
|
||||
if cmd.Flag(disableVNCApprovalFlag).Changed {
|
||||
loginRequest.DisableVNCApproval = &disableVNCApproval
|
||||
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
loginRequest.EnableSSHRoot = &enableSSHRoot
|
||||
}
|
||||
|
||||
applySSHFlagsToLogin(cmd, &loginRequest)
|
||||
if cmd.Flag(enableSSHSFTPFlag).Changed {
|
||||
loginRequest.EnableSSHSFTP = &enableSSHSFTP
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
|
||||
loginRequest.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
|
||||
loginRequest.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
|
||||
}
|
||||
|
||||
if cmd.Flag(disableSSHAuthFlag).Changed {
|
||||
loginRequest.DisableSSHAuth = &disableSSHAuth
|
||||
}
|
||||
|
||||
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
|
||||
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
|
||||
loginRequest.SshJWTCacheTTL = &sshJWTCacheTTL32
|
||||
}
|
||||
|
||||
if cmd.Flag(disableAutoConnectFlag).Changed {
|
||||
loginRequest.DisableAutoConnect = &autoConnectDisabled
|
||||
|
||||
@@ -1,100 +0,0 @@
|
||||
//go:build windows || (darwin && !ios)
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
)
|
||||
|
||||
var (
|
||||
vncAgentSocket string
|
||||
vncAgentTargetUID uint32
|
||||
)
|
||||
|
||||
func init() {
|
||||
vncAgentCmd.Flags().StringVar(&vncAgentSocket, "socket", "", "Unix-domain socket path the agent listens on (required)")
|
||||
vncAgentCmd.Flags().Uint32Var(&vncAgentTargetUID, "target-uid", 0, "uid the agent should drop privileges to before listening (darwin only; 0 = stay as current uid)")
|
||||
rootCmd.AddCommand(vncAgentCmd)
|
||||
}
|
||||
|
||||
// vncAgentCmd runs a VNC server inside the user's interactive session,
|
||||
// listening on a Unix-domain socket. The NetBird service spawns it: on
|
||||
// Windows via CreateProcessAsUser into the console session, on macOS via
|
||||
// launchctl asuser into the Aqua session.
|
||||
var vncAgentCmd = &cobra.Command{
|
||||
Use: "vnc-agent",
|
||||
Short: "Run VNC capture agent (internal, spawned by service)",
|
||||
Hidden: true,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
log.SetReportCaller(true)
|
||||
log.SetFormatter(&log.JSONFormatter{})
|
||||
log.SetOutput(os.Stderr)
|
||||
|
||||
if vncAgentSocket == "" {
|
||||
return fmt.Errorf("--socket is required")
|
||||
}
|
||||
|
||||
token := os.Getenv("NB_VNC_AGENT_TOKEN")
|
||||
if token == "" {
|
||||
return fmt.Errorf("NB_VNC_AGENT_TOKEN not set; agent requires a token from the service")
|
||||
}
|
||||
// Purge the token from env so it doesn't leak via /proc/<pid>/environ.
|
||||
if err := os.Unsetenv("NB_VNC_AGENT_TOKEN"); err != nil {
|
||||
log.Debugf("unset NB_VNC_AGENT_TOKEN: %v", err)
|
||||
}
|
||||
|
||||
// Drop root privileges to the target console user BEFORE creating
|
||||
// the listening socket: keeps a post-auth bug in the encoder /
|
||||
// input / capture paths confined to the user's own privileges
|
||||
// rather than escalating to host root, and makes the daemon's
|
||||
// LOCAL_PEERCRED check see the right uid. No-op on Windows
|
||||
// (both processes run as SYSTEM) and when --target-uid is 0.
|
||||
if vncAgentTargetUID != 0 {
|
||||
if err := dropAgentPrivileges(vncAgentTargetUID); err != nil {
|
||||
return fmt.Errorf("drop privileges to uid %d: %w", vncAgentTargetUID, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := os.Remove(vncAgentSocket); err != nil && !os.IsNotExist(err) {
|
||||
log.Debugf("remove stale socket %s: %v", vncAgentSocket, err)
|
||||
}
|
||||
ln, err := net.Listen("unix", vncAgentSocket)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen on %s: %w", vncAgentSocket, err)
|
||||
}
|
||||
if err := os.Chmod(vncAgentSocket, 0o600); err != nil {
|
||||
log.Debugf("chmod %s: %v", vncAgentSocket, err)
|
||||
}
|
||||
|
||||
capturer, injector, err := newAgentResources()
|
||||
if err != nil {
|
||||
_ = ln.Close()
|
||||
return err
|
||||
}
|
||||
srv := vncserver.New(vncserver.Config{
|
||||
Capturer: capturer,
|
||||
Injector: injector,
|
||||
DisableAuth: true,
|
||||
AgentTokenHex: token,
|
||||
Listener: ln,
|
||||
})
|
||||
|
||||
if err := srv.Start(cmd.Context(), netip.AddrPort{}, netip.Prefix{}); err != nil {
|
||||
return fmt.Errorf("start vnc server: %w", err)
|
||||
}
|
||||
log.Infof("vnc-agent listening on %s, ready", vncAgentSocket)
|
||||
|
||||
<-cmd.Context().Done()
|
||||
log.Info("vnc-agent context cancelled, shutting down")
|
||||
return srv.Stop()
|
||||
},
|
||||
SilenceUsage: true,
|
||||
}
|
||||
@@ -1,18 +0,0 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
)
|
||||
|
||||
func newAgentResources() (vncserver.ScreenCapturer, vncserver.InputInjector, error) {
|
||||
capturer := vncserver.NewMacPoller()
|
||||
injector, err := vncserver.NewMacInputInjector()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("macOS input injector: %w", err)
|
||||
}
|
||||
return capturer, injector, nil
|
||||
}
|
||||
@@ -1,74 +0,0 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/user"
|
||||
"strconv"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// dropAgentPrivileges drops the vnc-agent process from root (its
|
||||
// launchctl-asuser-inherited starting uid) to the target console user
|
||||
// before any other initialisation runs. Without this the agent runs as
|
||||
// root for the lifetime of the session; any post-auth memory-safety
|
||||
// issue in the capture/input/encode paths would then be a root-level
|
||||
// RCE on the host instead of a user-level one. Also makes the daemon's
|
||||
// LOCAL_PEERCRED check correctly identify the agent as the console user,
|
||||
// not as root.
|
||||
//
|
||||
// Returns an error when the agent is running as a non-root uid that
|
||||
// differs from targetUID: non-root can only setuid to itself, so a
|
||||
// mismatch here means the spawn went to the wrong session.
|
||||
func dropAgentPrivileges(targetUID uint32) error {
|
||||
if targetUID == 0 {
|
||||
return fmt.Errorf("refusing to keep agent running as root (target uid 0)")
|
||||
}
|
||||
cur := uint32(os.Getuid())
|
||||
if cur == targetUID {
|
||||
return nil
|
||||
}
|
||||
if cur != 0 {
|
||||
return fmt.Errorf("agent uid %d does not match expected %d and we lack root to fix it", cur, targetUID)
|
||||
}
|
||||
// Resolve the target user's real primary group rather than reusing
|
||||
// targetUID as the gid: a user's primary group on macOS is typically
|
||||
// staff(20), not gid==uid. Fail closed if the lookup fails.
|
||||
targetGID, err := primaryGroupID(targetUID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Drop supplementary groups first: setgid alone doesn't touch the
|
||||
// auxiliary group list, leaving root's groups attached would let the
|
||||
// dropped process write to root-only group-writable files.
|
||||
if err := syscall.Setgroups([]int{}); err != nil {
|
||||
return fmt.Errorf("setgroups([]): %w", err)
|
||||
}
|
||||
if err := syscall.Setgid(targetGID); err != nil {
|
||||
return fmt.Errorf("setgid(%d): %w", targetGID, err)
|
||||
}
|
||||
if err := syscall.Setuid(int(targetUID)); err != nil {
|
||||
return fmt.Errorf("setuid(%d): %w", targetUID, err)
|
||||
}
|
||||
if uint32(os.Getuid()) != targetUID || uint32(os.Geteuid()) != targetUID {
|
||||
return fmt.Errorf("setuid verification: uid=%d euid=%d, expected %d", os.Getuid(), os.Geteuid(), targetUID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// primaryGroupID resolves the real primary group id of the user with the
|
||||
// given uid. Fails closed: a lookup or parse error returns an error so the
|
||||
// caller never falls back to using uid as the gid.
|
||||
func primaryGroupID(targetUID uint32) (int, error) {
|
||||
u, err := user.LookupId(strconv.Itoa(int(targetUID)))
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("look up uid %d: %w", targetUID, err)
|
||||
}
|
||||
gid, err := strconv.Atoi(u.Gid)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parse gid %q for uid %d: %w", u.Gid, targetUID, err)
|
||||
}
|
||||
return gid, nil
|
||||
}
|
||||
@@ -1,55 +0,0 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestDropAgentPrivileges_RefusesRootTarget locks in the contract that
|
||||
// dropAgentPrivileges must never be a no-op when asked to keep the
|
||||
// agent as root (target uid 0). A future caller that passes 0 by
|
||||
// mistake would otherwise leave the post-auth attack surface running
|
||||
// with full root privileges.
|
||||
func TestDropAgentPrivileges_RefusesRootTarget(t *testing.T) {
|
||||
err := dropAgentPrivileges(0)
|
||||
if err == nil {
|
||||
t.Fatal("expected refusal for target uid 0, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "root") {
|
||||
t.Fatalf("error should mention root, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDropAgentPrivileges_NoOpWhenAlreadyTarget covers the dev path
|
||||
// where the agent is launched by hand as the target user (no root
|
||||
// available, no setuid needed). The helper must succeed silently
|
||||
// instead of trying (and failing) a setuid to its current uid.
|
||||
func TestDropAgentPrivileges_NoOpWhenAlreadyTarget(t *testing.T) {
|
||||
// Skip when running as root: the early-return path we want to
|
||||
// cover only fires when current uid == target uid.
|
||||
uid := currentUIDForTest()
|
||||
if uid == 0 {
|
||||
t.Skip("test must not run as root; cannot exercise the no-op early-return")
|
||||
}
|
||||
if err := dropAgentPrivileges(uid); err != nil {
|
||||
t.Fatalf("expected no-op when current uid == target, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDropAgentPrivileges_RefusesMismatchedNonRoot guards the "non-root
|
||||
// caller tries to setuid to a different uid" path: setuid would fail
|
||||
// with EPERM anyway, but the helper should surface a clear error
|
||||
// before issuing the syscall so a misconfigured spawn (wrong --target-uid
|
||||
// flag) is debuggable.
|
||||
func TestDropAgentPrivileges_RefusesMismatchedNonRoot(t *testing.T) {
|
||||
uid := currentUIDForTest()
|
||||
if uid == 0 {
|
||||
t.Skip("test must not run as root; covered case requires non-root caller")
|
||||
}
|
||||
err := dropAgentPrivileges(uid + 1)
|
||||
if err == nil {
|
||||
t.Fatal("expected refusal when non-root caller asks to setuid elsewhere")
|
||||
}
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package cmd
|
||||
|
||||
import "os"
|
||||
|
||||
// currentUIDForTest exposes os.Getuid for the darwin dropprivs tests
|
||||
// without leaking an os import into the test file itself.
|
||||
func currentUIDForTest() uint32 {
|
||||
return uint32(os.Getuid())
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package cmd
|
||||
|
||||
// dropAgentPrivileges is a no-op on Windows: the agent and the daemon
|
||||
// both run as SYSTEM (the daemon spawns the agent into the interactive
|
||||
// session via CreateProcessAsUser with an impersonation token, but the
|
||||
// resulting process still runs under SYSTEM, not under the user's
|
||||
// account). The Windows path relies on the DACL-restricted socket
|
||||
// directory, the unpredictable per-spawn socket name, the listen-readiness
|
||||
// gate, and the per-spawn token for integrity instead.
|
||||
func dropAgentPrivileges(_ uint32) error {
|
||||
return nil
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
)
|
||||
|
||||
func newAgentResources() (vncserver.ScreenCapturer, vncserver.InputInjector, error) {
|
||||
sessionID := vncserver.GetCurrentSessionID()
|
||||
log.Infof("VNC agent running in Windows session %d", sessionID)
|
||||
return vncserver.NewDesktopCapturer(), vncserver.NewWindowsInputInjector(), nil
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
package cmd
|
||||
|
||||
const (
|
||||
serverVNCAllowedFlag = "allow-server-vnc"
|
||||
disableVNCApprovalFlag = "disable-vnc-approval"
|
||||
)
|
||||
|
||||
var (
|
||||
serverVNCAllowed bool
|
||||
disableVNCApproval bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
upCmd.PersistentFlags().BoolVar(&serverVNCAllowed, serverVNCAllowedFlag, false, "Allow embedded VNC server on peer")
|
||||
upCmd.PersistentFlags().BoolVar(&disableVNCApproval, disableVNCApprovalFlag, false, "Disable per-connection user approval prompts for the embedded VNC server")
|
||||
}
|
||||
@@ -6,30 +6,19 @@ import (
|
||||
"runtime"
|
||||
)
|
||||
|
||||
var (
|
||||
// StateDir holds persistent state (config, profiles, install metadata).
|
||||
StateDir string
|
||||
// RuntimeDir holds ephemeral artifacts that should not survive reboot,
|
||||
// such as Unix sockets for daemon and per-session IPC. Empty on
|
||||
// platforms without a conventional /var/run-style location.
|
||||
RuntimeDir string
|
||||
)
|
||||
var StateDir string
|
||||
|
||||
func init() {
|
||||
StateDir = os.Getenv("NB_STATE_DIR")
|
||||
if StateDir != "" {
|
||||
return
|
||||
}
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
StateDir = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird")
|
||||
case "darwin", "linux":
|
||||
StateDir = "/var/lib/netbird"
|
||||
RuntimeDir = "/var/run/netbird"
|
||||
case "freebsd", "openbsd", "netbsd", "dragonfly":
|
||||
StateDir = "/var/db/netbird"
|
||||
RuntimeDir = "/var/run/netbird"
|
||||
}
|
||||
if v := os.Getenv("NB_STATE_DIR"); v != "" {
|
||||
StateDir = v
|
||||
}
|
||||
if v := os.Getenv("NB_RUNTIME_DIR"); v != "" {
|
||||
RuntimeDir = v
|
||||
}
|
||||
}
|
||||
|
||||
@@ -279,10 +279,6 @@ func (c *Client) Start(startCtx context.Context) error {
|
||||
|
||||
select {
|
||||
case <-startCtx.Done():
|
||||
// Cancel the client context before stopping: Engine.Start blocks on the
|
||||
// signal stream while holding the engine mutex and only unblocks on
|
||||
// cancellation. Stopping first would deadlock on that mutex.
|
||||
cancel()
|
||||
if stopErr := client.Stop(); stopErr != nil {
|
||||
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
|
||||
}
|
||||
@@ -446,8 +442,8 @@ func (c *Client) Expose(ctx context.Context, req ExposeRequest) (*ExposeSession,
|
||||
|
||||
// IdentityForIP looks up a remote peer by its tunnel IP using the
|
||||
// embedded client's status recorder. Returns the peer's WireGuard public
|
||||
// key and FQDN. ok=false means the IP doesn't belong to an active peer
|
||||
// — offline roster peers are treated as unknown, same as foreign IPs.
|
||||
// key and FQDN. ok=false means the IP isn't in this client's peer
|
||||
// roster — callers should treat that as "unknown peer".
|
||||
func (c *Client) IdentityForIP(ip netip.Addr) (pubKey, fqdn string, ok bool) {
|
||||
if !ip.IsValid() || c.recorder == nil {
|
||||
return "", "", false
|
||||
|
||||
@@ -1,168 +0,0 @@
|
||||
package embed
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
mgmt "github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
const testSetupKey = "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
|
||||
|
||||
// TestClientStartTimeoutRollback reproduces a deadlock between Engine.Start and
|
||||
// Engine.Stop. The signal endpoint accepts gRPC connections but never serves the
|
||||
// SignalExchange service, so Engine.Start parks in WaitStreamConnected while
|
||||
// holding the engine mutex. When the Start context expires, the rollback path
|
||||
// calls ConnectClient.Stop, which must not block forever acquiring that mutex.
|
||||
func TestClientStartTimeoutRollback(t *testing.T) {
|
||||
signalAddr := startBlackholeSignal(t)
|
||||
mgmAddr := startManagement(t, signalAddr)
|
||||
|
||||
wgPort := 0
|
||||
client, err := New(Options{
|
||||
DeviceName: "embed-rollback-test",
|
||||
SetupKey: testSetupKey,
|
||||
ManagementURL: "http://" + mgmAddr,
|
||||
WireguardPort: &wgPort,
|
||||
})
|
||||
require.NoError(t, err, "embed client creation must succeed")
|
||||
|
||||
startCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
startErr := make(chan error, 1)
|
||||
go func() {
|
||||
startErr <- client.Start(startCtx)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-startErr:
|
||||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
case <-time.After(60 * time.Second):
|
||||
t.Fatal("client.Start did not return after its context expired: Engine.Stop deadlocked against Engine.Start waiting for the signal stream")
|
||||
}
|
||||
}
|
||||
|
||||
// startBlackholeSignal starts a gRPC server without the SignalExchange service
|
||||
// registered. Connections succeed, but the signal stream can never be
|
||||
// established, which keeps Engine.Start parked in WaitStreamConnected.
|
||||
func startBlackholeSignal(t *testing.T) string {
|
||||
t.Helper()
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
s := grpc.NewServer()
|
||||
go func() {
|
||||
if err := s.Serve(lis); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
t.Cleanup(s.Stop)
|
||||
|
||||
return lis.Addr().String()
|
||||
}
|
||||
|
||||
func startManagement(t *testing.T, signalAddr string) string {
|
||||
t.Helper()
|
||||
|
||||
cfg := &config.Config{
|
||||
Stuns: []*config.Host{},
|
||||
TURNConfig: &config.TURNConfig{},
|
||||
Relay: &config.Relay{
|
||||
Addresses: []string{"127.0.0.1:1234"},
|
||||
CredentialsTTL: util.Duration{Duration: time.Hour},
|
||||
Secret: "222222222222222222",
|
||||
},
|
||||
Signal: &config.Host{
|
||||
Proto: "http",
|
||||
URI: signalAddr,
|
||||
},
|
||||
Datadir: t.TempDir(),
|
||||
HttpConfig: nil,
|
||||
}
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
s := grpc.NewServer()
|
||||
|
||||
testStore, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", cfg.Datadir)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
|
||||
permissionsManager := permissions.NewManager(testStore)
|
||||
peersManager := peers.NewManager(testStore, permissionsManager)
|
||||
jobManager := job.NewJobManager(nil, testStore, peersManager)
|
||||
|
||||
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
iv, err := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
|
||||
require.NoError(t, err)
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
settingsMockManager.EXPECT().
|
||||
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(&types.Settings{}, nil).
|
||||
AnyTimes()
|
||||
settingsMockManager.EXPECT().
|
||||
GetExtraSettings(gomock.Any(), gomock.Any()).
|
||||
Return(&types.ExtraSettings{}, nil).
|
||||
AnyTimes()
|
||||
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := mgmt.NewAccountRequestBuffer(context.Background(), testStore)
|
||||
networkMapController := controller.NewController(context.Background(), testStore, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(testStore, peersManager), cfg)
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), cfg, testStore, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
|
||||
require.NoError(t, err)
|
||||
|
||||
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, cfg.TURNConfig, cfg.Relay, settingsMockManager, groupsManager)
|
||||
require.NoError(t, err)
|
||||
|
||||
mgmtServer, err := nbgrpc.NewServer(cfg, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil, nil)
|
||||
require.NoError(t, err)
|
||||
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
|
||||
|
||||
go func() {
|
||||
if err := s.Serve(lis); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
t.Cleanup(s.Stop)
|
||||
|
||||
return lis.Addr().String()
|
||||
}
|
||||
@@ -3,7 +3,6 @@ package iptables
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"net"
|
||||
"slices"
|
||||
|
||||
@@ -422,17 +421,12 @@ func (m *aclManager) updateState() {
|
||||
currentState.Lock()
|
||||
defer currentState.Unlock()
|
||||
|
||||
// Clone the maps so the persisted state holds a private snapshot. The
|
||||
// live maps keep being mutated by subsequent rule operations while the
|
||||
// state manager marshals the state from its periodic-save goroutine.
|
||||
// Sharing them by reference races the two and aborts the process with a
|
||||
// concurrent map iteration and write.
|
||||
if m.v6 {
|
||||
currentState.ACLEntries6 = maps.Clone(m.entries)
|
||||
currentState.ACLIPsetStore6 = m.ipsetStore.clone()
|
||||
currentState.ACLEntries6 = m.entries
|
||||
currentState.ACLIPsetStore6 = m.ipsetStore
|
||||
} else {
|
||||
currentState.ACLEntries = maps.Clone(m.entries)
|
||||
currentState.ACLIPsetStore = m.ipsetStore.clone()
|
||||
currentState.ACLEntries = m.entries
|
||||
currentState.ACLIPsetStore = m.ipsetStore
|
||||
}
|
||||
|
||||
if err := m.stateManager.UpdateState(currentState); err != nil {
|
||||
|
||||
@@ -4,7 +4,6 @@ package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"maps"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -750,17 +749,11 @@ func (r *router) updateState() {
|
||||
currentState.Lock()
|
||||
defer currentState.Unlock()
|
||||
|
||||
// Clone the rule map so the persisted state holds a private snapshot. The
|
||||
// live map keeps being mutated by subsequent rule operations while the
|
||||
// state manager marshals the state from its periodic-save goroutine.
|
||||
// Sharing it by reference races the two and aborts the process with a
|
||||
// concurrent map iteration and write. The ipset counter guards itself
|
||||
// during marshaling, so it can be shared directly.
|
||||
if r.v6 {
|
||||
currentState.RouteRules6 = maps.Clone(r.rules)
|
||||
currentState.RouteRules6 = r.rules
|
||||
currentState.RouteIPsetCounter6 = r.ipsetCounter
|
||||
} else {
|
||||
currentState.RouteRules = maps.Clone(r.rules)
|
||||
currentState.RouteRules = r.rules
|
||||
currentState.RouteIPsetCounter = r.ipsetCounter
|
||||
}
|
||||
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"maps"
|
||||
)
|
||||
import "encoding/json"
|
||||
|
||||
type ipList struct {
|
||||
ips map[string]struct{}
|
||||
@@ -22,14 +19,6 @@ func (s *ipList) addIP(ip string) {
|
||||
s.ips[ip] = struct{}{}
|
||||
}
|
||||
|
||||
// clone returns a deep copy of the ipList with its own ips map.
|
||||
func (s *ipList) clone() *ipList {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
return &ipList{ips: maps.Clone(s.ips)}
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler
|
||||
func (s *ipList) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
@@ -66,19 +55,6 @@ func newIpsetStore() *ipsetStore {
|
||||
}
|
||||
}
|
||||
|
||||
// clone returns a deep copy of the ipsetStore with its own ipsets map and
|
||||
// independent ipList entries.
|
||||
func (s *ipsetStore) clone() *ipsetStore {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := &ipsetStore{ipsets: make(map[string]*ipList, len(s.ipsets))}
|
||||
for name, list := range s.ipsets {
|
||||
cloned.ipsets[name] = list.clone()
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func (s *ipsetStore) ipset(ipsetName string) (*ipList, bool) {
|
||||
r, ok := s.ipsets[ipsetName]
|
||||
return r, ok
|
||||
|
||||
@@ -1,219 +0,0 @@
|
||||
// Package approval brokers per-attempt user-accept prompts for inbound
|
||||
// remote access (VNC today, SSH and others in the future). A caller pushes
|
||||
// a Prompt; the broker emits a SystemEvent on the daemon→UI stream and
|
||||
// blocks until the UI calls the daemon's RespondApproval RPC, the per-
|
||||
// request timeout fires, or no subscriber is connected. The latter case
|
||||
// fails closed so a backgrounded UI cannot silently bypass the gate.
|
||||
package approval
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
// Metadata keys the broker reserves on the emitted SystemEvent. Callers
|
||||
// should not set these themselves; values in Prompt.Metadata that collide
|
||||
// are overwritten by the broker.
|
||||
const (
|
||||
MetaRequestID = "request_id"
|
||||
MetaKind = "kind"
|
||||
MetaExpiresAt = "expires_at"
|
||||
)
|
||||
|
||||
// ShortKeyFingerprint formats a hex-encoded Noise_IK static pubkey as a
|
||||
// short, eyeball-able fingerprint to display in the approval dialog.
|
||||
// The dashboard-supplied display name attached to a SessionPubKey isn't
|
||||
// cryptographically asserted by the connecting client, so the prompt
|
||||
// must also show something that IS: the key fingerprint, a hash of
|
||||
// the static public key the client just proved possession of during the
|
||||
// Noise handshake. Returns the empty string when the input is too short
|
||||
// to plausibly be a hex pubkey, so the row is omitted rather than
|
||||
// rendered as a misleading partial.
|
||||
//
|
||||
// Output format: 16 hex chars grouped as XXXX-XXXX-XXXX-XXXX (64 bits of
|
||||
// fingerprint, resistant to random-prefix collisions and easy for a human
|
||||
// to compare with an out-of-band reference).
|
||||
func ShortKeyFingerprint(hexKey string) string {
|
||||
if len(hexKey) < 8 {
|
||||
return ""
|
||||
}
|
||||
src := hexKey
|
||||
if len(src) > 16 {
|
||||
src = src[:16]
|
||||
}
|
||||
var out []byte
|
||||
for i, c := range src {
|
||||
if i > 0 && i%4 == 0 {
|
||||
out = append(out, '-')
|
||||
}
|
||||
out = append(out, byte(c))
|
||||
}
|
||||
return string(out)
|
||||
}
|
||||
|
||||
// Kind values for the well-known prompt subjects. New subsystems should
|
||||
// add a constant here so the UI can dispatch on a known string.
|
||||
const (
|
||||
KindVNC = "vnc"
|
||||
KindSSH = "ssh"
|
||||
)
|
||||
|
||||
// DefaultTimeout is the wall-clock window the user has to accept or deny a
|
||||
// pending approval before the broker fails closed and returns ErrTimeout.
|
||||
// Kept well under typical VNC client and dashboard connection timeouts so
|
||||
// the RFB rejection actually reaches the browser instead of racing the
|
||||
// browser's own "connection timed out" message.
|
||||
const DefaultTimeout = 15 * time.Second
|
||||
|
||||
// timeoutValue returns the active timeout. It's a var so tests in this
|
||||
// package can shorten the wait without exposing a setter on the public
|
||||
// API. Production code always sees DefaultTimeout.
|
||||
var timeoutValue = func() time.Duration { return DefaultTimeout }
|
||||
|
||||
// ErrNoSubscriber indicates no UI is connected to consume the prompt.
|
||||
// The caller must reject the underlying connection (fail-closed).
|
||||
var ErrNoSubscriber = errors.New("no UI subscriber connected for approval")
|
||||
|
||||
// ErrTimeout indicates the user did not respond within DefaultTimeout.
|
||||
var ErrTimeout = errors.New("approval timed out")
|
||||
|
||||
// ErrDenied indicates the user explicitly denied the connection.
|
||||
var ErrDenied = errors.New("approval denied")
|
||||
|
||||
// EventPublisher is the subset of peer.Status used to emit prompts.
|
||||
type EventPublisher interface {
|
||||
PublishEvent(
|
||||
severity proto.SystemEvent_Severity,
|
||||
category proto.SystemEvent_Category,
|
||||
msg string,
|
||||
userMsg string,
|
||||
metadata map[string]string,
|
||||
)
|
||||
HasEventSubscribers() bool
|
||||
}
|
||||
|
||||
// Prompt describes the pending request shown to the user. Kind selects
|
||||
// the UI dispatch path (e.g. "vnc", "ssh"). Subject is the human-readable
|
||||
// one-liner the UI may show as a title or notification body. Metadata is
|
||||
// passed through verbatim and is the subsystem-specific payload (peer
|
||||
// name, source IP, mode, etc.).
|
||||
type Prompt struct {
|
||||
Kind string
|
||||
Subject string
|
||||
Metadata map[string]string
|
||||
}
|
||||
|
||||
// Decision carries the user's response to an approval prompt. ViewOnly is
|
||||
// only meaningful when Accept is true; it lets the host grant the
|
||||
// connection but signal the requester that input control is withheld.
|
||||
type Decision struct {
|
||||
Accept bool
|
||||
ViewOnly bool
|
||||
}
|
||||
|
||||
// Broker holds in-flight approval requests keyed by request ID.
|
||||
type Broker struct {
|
||||
pub EventPublisher
|
||||
|
||||
mu sync.Mutex
|
||||
pending map[string]chan Decision
|
||||
}
|
||||
|
||||
// New returns a broker that publishes prompts via pub.
|
||||
func New(pub EventPublisher) *Broker {
|
||||
return &Broker{
|
||||
pub: pub,
|
||||
pending: make(map[string]chan Decision),
|
||||
}
|
||||
}
|
||||
|
||||
// Request emits a SystemEvent for p and blocks until the UI calls Respond,
|
||||
// ctx is cancelled, or DefaultTimeout elapses. Returns a Decision when
|
||||
// the user replied; ErrDenied / ErrTimeout / ErrNoSubscriber / ctx.Err
|
||||
// otherwise. Callers must treat any non-nil error as a deny.
|
||||
func (b *Broker) Request(ctx context.Context, p Prompt) (Decision, error) {
|
||||
var zero Decision
|
||||
if b == nil || b.pub == nil {
|
||||
return zero, fmt.Errorf("approval broker not configured")
|
||||
}
|
||||
if !b.pub.HasEventSubscribers() {
|
||||
return zero, ErrNoSubscriber
|
||||
}
|
||||
|
||||
id := uuid.NewString()
|
||||
resp := make(chan Decision, 1)
|
||||
|
||||
b.mu.Lock()
|
||||
b.pending[id] = resp
|
||||
b.mu.Unlock()
|
||||
|
||||
defer b.dropPending(id)
|
||||
|
||||
timeout := timeoutValue()
|
||||
expiresAt := time.Now().Add(timeout)
|
||||
meta := make(map[string]string, len(p.Metadata)+3)
|
||||
for k, v := range p.Metadata {
|
||||
meta[k] = v
|
||||
}
|
||||
meta[MetaRequestID] = id
|
||||
meta[MetaKind] = p.Kind
|
||||
meta[MetaExpiresAt] = expiresAt.UTC().Format(time.RFC3339)
|
||||
|
||||
subject := p.Subject
|
||||
if subject == "" {
|
||||
subject = fmt.Sprintf("%s connection requires approval", p.Kind)
|
||||
}
|
||||
b.pub.PublishEvent(proto.SystemEvent_INFO, proto.SystemEvent_APPROVAL, subject, subject, meta)
|
||||
log.Debugf("approval request %s (%s) emitted: %s", id, p.Kind, subject)
|
||||
|
||||
timer := time.NewTimer(timeout)
|
||||
defer timer.Stop()
|
||||
|
||||
select {
|
||||
case d := <-resp:
|
||||
if !d.Accept {
|
||||
return zero, ErrDenied
|
||||
}
|
||||
return d, nil
|
||||
case <-timer.C:
|
||||
return zero, ErrTimeout
|
||||
case <-ctx.Done():
|
||||
return zero, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// Respond delivers the user's decision for id. Returns true when a pending
|
||||
// request matched and was woken, false when id was unknown or already done.
|
||||
func (b *Broker) Respond(id string, d Decision) bool {
|
||||
if b == nil {
|
||||
return false
|
||||
}
|
||||
b.mu.Lock()
|
||||
ch, ok := b.pending[id]
|
||||
if ok {
|
||||
delete(b.pending, id)
|
||||
}
|
||||
b.mu.Unlock()
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
select {
|
||||
case ch <- d:
|
||||
default:
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (b *Broker) dropPending(id string) {
|
||||
b.mu.Lock()
|
||||
delete(b.pending, id)
|
||||
b.mu.Unlock()
|
||||
}
|
||||
@@ -1,434 +0,0 @@
|
||||
package approval
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
// fakePublisher records published events and reports whether subscribers
|
||||
// are connected. The subscribers flag is the security-critical signal:
|
||||
// when false the broker must refuse to emit and the gate must fail closed.
|
||||
type fakePublisher struct {
|
||||
mu sync.Mutex
|
||||
subscribers bool
|
||||
events []*proto.SystemEvent
|
||||
}
|
||||
|
||||
func (p *fakePublisher) PublishEvent(
|
||||
severity proto.SystemEvent_Severity,
|
||||
category proto.SystemEvent_Category,
|
||||
msg string,
|
||||
userMsg string,
|
||||
metadata map[string]string,
|
||||
) {
|
||||
p.mu.Lock()
|
||||
p.events = append(p.events, &proto.SystemEvent{
|
||||
Severity: severity,
|
||||
Category: category,
|
||||
Message: msg,
|
||||
UserMessage: userMsg,
|
||||
Metadata: metadata,
|
||||
})
|
||||
p.mu.Unlock()
|
||||
}
|
||||
|
||||
func (p *fakePublisher) HasEventSubscribers() bool {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
return p.subscribers
|
||||
}
|
||||
|
||||
func (p *fakePublisher) lastEvent(t *testing.T) *proto.SystemEvent {
|
||||
t.Helper()
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
require.NotEmpty(t, p.events, "publisher saw no events")
|
||||
return p.events[len(p.events)-1]
|
||||
}
|
||||
|
||||
func (p *fakePublisher) eventCount() int {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
return len(p.events)
|
||||
}
|
||||
|
||||
// TestRequestNoSubscriberFailsClosed is the core fail-closed invariant:
|
||||
// when the UI is not subscribed, the broker must refuse without emitting
|
||||
// an event or arming a waiter. A regression here is a silent bypass.
|
||||
func TestRequestNoSubscriberFailsClosed(t *testing.T) {
|
||||
pub := &fakePublisher{subscribers: false}
|
||||
b := New(pub)
|
||||
|
||||
_, err := b.Request(context.Background(), Prompt{Kind: KindVNC, Subject: "test"})
|
||||
assert.ErrorIs(t, err, ErrNoSubscriber)
|
||||
assert.Equal(t, 0, pub.eventCount(), "no event must be emitted when fail-closed")
|
||||
|
||||
b.mu.Lock()
|
||||
pending := len(b.pending)
|
||||
b.mu.Unlock()
|
||||
assert.Equal(t, 0, pending, "no waiter must be registered on fail-closed")
|
||||
}
|
||||
|
||||
// TestRequestTimeoutDenies verifies that a request without a UI response
|
||||
// returns ErrTimeout (deny) rather than nil (silent accept). Uses a short
|
||||
// per-test broker timeout via Respond after the fact to keep the test fast.
|
||||
func TestRequestTimeoutDenies(t *testing.T) {
|
||||
// Replace DefaultTimeout for the lifetime of this test.
|
||||
orig := DefaultTimeout
|
||||
defaultTimeout(t, 60*time.Millisecond)
|
||||
defer defaultTimeout(t, orig)
|
||||
|
||||
pub := &fakePublisher{subscribers: true}
|
||||
b := New(pub)
|
||||
|
||||
start := time.Now()
|
||||
_, err := b.Request(context.Background(), Prompt{Kind: KindVNC, Subject: "test"})
|
||||
assert.ErrorIs(t, err, ErrTimeout, "missing user response must yield ErrTimeout, not nil")
|
||||
assert.GreaterOrEqual(t, time.Since(start), 50*time.Millisecond, "timeout fired prematurely")
|
||||
}
|
||||
|
||||
// TestRequestDenied returns ErrDenied when the UI responds with false.
|
||||
func TestRequestDenied(t *testing.T) {
|
||||
pub := &fakePublisher{subscribers: true}
|
||||
b := New(pub)
|
||||
|
||||
var requestID string
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- requestErr(b, context.Background(), Prompt{Kind: KindVNC, Subject: "test"})
|
||||
}()
|
||||
|
||||
requestID = waitForRequestID(t, pub)
|
||||
require.True(t, b.Respond(requestID, Decision{Accept: false}))
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
assert.ErrorIs(t, err, ErrDenied)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Request did not return after Respond(false)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRequestAccepted is the happy path. Failure here doesn't bypass the
|
||||
// gate but breaks the feature.
|
||||
func TestRequestAccepted(t *testing.T) {
|
||||
pub := &fakePublisher{subscribers: true}
|
||||
b := New(pub)
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- requestErr(b, context.Background(), Prompt{Kind: KindVNC, Subject: "test"})
|
||||
}()
|
||||
|
||||
id := waitForRequestID(t, pub)
|
||||
require.True(t, b.Respond(id, Decision{Accept: true}))
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
assert.NoError(t, err)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Request did not return after Respond(true)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRequestCtxCancelDenies verifies that an upstream cancel (e.g. the
|
||||
// engine shutting down mid-prompt) returns the cancel error rather than
|
||||
// nil. A nil here would be a silent bypass on shutdown races.
|
||||
func TestRequestCtxCancelDenies(t *testing.T) {
|
||||
pub := &fakePublisher{subscribers: true}
|
||||
b := New(pub)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- requestErr(b, ctx, Prompt{Kind: KindVNC, Subject: "test"})
|
||||
}()
|
||||
|
||||
// Wait until the prompt is in flight so cancel races a live waiter.
|
||||
_ = waitForRequestID(t, pub)
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
assert.ErrorIs(t, err, context.Canceled)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Request did not return after ctx cancel")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRespondUnknownIsNoop ensures a stray RespondApproval RPC cannot
|
||||
// affect or accidentally accept any in-flight request whose id it doesn't
|
||||
// match. Also confirms it doesn't panic.
|
||||
func TestRespondUnknownIsNoop(t *testing.T) {
|
||||
pub := &fakePublisher{subscribers: true}
|
||||
b := New(pub)
|
||||
|
||||
// No in-flight prompts: Respond returns false.
|
||||
assert.False(t, b.Respond("does-not-exist", Decision{Accept: true}))
|
||||
|
||||
// With an in-flight prompt, a wrong id still returns false and the
|
||||
// prompt remains armed (eventually timing out as a deny).
|
||||
defaultTimeout(t, 60*time.Millisecond)
|
||||
defer defaultTimeout(t, DefaultTimeout)
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- requestErr(b, context.Background(), Prompt{Kind: KindVNC})
|
||||
}()
|
||||
realID := waitForRequestID(t, pub)
|
||||
assert.False(t, b.Respond("totally-bogus", Decision{Accept: true}), "unknown id must not match")
|
||||
assert.NotEqual(t, "totally-bogus", realID)
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
assert.ErrorIs(t, err, ErrTimeout, "armed prompt must still time out, not accept")
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("prompt did not resolve")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRespondAfterTimeoutNoop confirms a late accept response can't
|
||||
// retroactively flip a denied (timed-out) request. The dropPending defer
|
||||
// in Request must have removed the entry by the time Respond races in.
|
||||
func TestRespondAfterTimeoutNoop(t *testing.T) {
|
||||
defaultTimeout(t, 30*time.Millisecond)
|
||||
defer defaultTimeout(t, DefaultTimeout)
|
||||
|
||||
pub := &fakePublisher{subscribers: true}
|
||||
b := New(pub)
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- requestErr(b, context.Background(), Prompt{Kind: KindVNC})
|
||||
}()
|
||||
id := waitForRequestID(t, pub)
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
require.ErrorIs(t, err, ErrTimeout)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("prompt did not time out")
|
||||
}
|
||||
|
||||
assert.False(t, b.Respond(id, Decision{Accept: true}), "late respond must be no-op")
|
||||
}
|
||||
|
||||
// TestRespondDoubleNoop ensures a duplicate ack from the UI doesn't leak
|
||||
// past the matched waiter or panic on a closed/full channel.
|
||||
func TestRespondDoubleNoop(t *testing.T) {
|
||||
pub := &fakePublisher{subscribers: true}
|
||||
b := New(pub)
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- requestErr(b, context.Background(), Prompt{Kind: KindVNC})
|
||||
}()
|
||||
id := waitForRequestID(t, pub)
|
||||
require.True(t, b.Respond(id, Decision{Accept: true}))
|
||||
assert.False(t, b.Respond(id, Decision{Accept: false}), "second response must be no-op")
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
assert.NoError(t, err)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("prompt did not resolve")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNilBrokerRequestErrors guards the engine pre-init path where the
|
||||
// broker may not yet exist (or its publisher is nil): Request must
|
||||
// error, never silently accept.
|
||||
func TestNilBrokerRequestErrors(t *testing.T) {
|
||||
var b *Broker
|
||||
_, err := b.Request(context.Background(), Prompt{Kind: KindVNC})
|
||||
assert.Error(t, err, "nil broker must error, never silently accept")
|
||||
|
||||
b2 := New(nil)
|
||||
_, err = b2.Request(context.Background(), Prompt{Kind: KindVNC})
|
||||
assert.Error(t, err, "broker with nil publisher must error, never silently accept")
|
||||
}
|
||||
|
||||
// TestPromptMetadataInjected confirms the broker stamps request_id, kind,
|
||||
// and expires_at on the emitted event. The UI relies on these keys; if
|
||||
// they are dropped, the user cannot route the prompt and the response
|
||||
// path breaks (which fails closed via timeout).
|
||||
func TestPromptMetadataInjected(t *testing.T) {
|
||||
pub := &fakePublisher{subscribers: true}
|
||||
b := New(pub)
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- requestErr(b, context.Background(), Prompt{
|
||||
Kind: KindVNC,
|
||||
Subject: "VNC connection from peerA",
|
||||
Metadata: map[string]string{"peer_name": "peerA"},
|
||||
})
|
||||
}()
|
||||
|
||||
id := waitForRequestID(t, pub)
|
||||
ev := pub.lastEvent(t)
|
||||
|
||||
assert.Equal(t, proto.SystemEvent_APPROVAL, ev.Category)
|
||||
assert.Equal(t, KindVNC, ev.Metadata[MetaKind])
|
||||
assert.Equal(t, id, ev.Metadata[MetaRequestID])
|
||||
assert.NotEmpty(t, ev.Metadata[MetaExpiresAt])
|
||||
assert.Equal(t, "peerA", ev.Metadata["peer_name"], "caller metadata must pass through")
|
||||
|
||||
require.True(t, b.Respond(id, Decision{Accept: true}))
|
||||
<-done
|
||||
}
|
||||
|
||||
// TestConcurrentRequests verifies that two concurrent prompts are tracked
|
||||
// independently. A bug that aliases ids would let one Respond unblock
|
||||
// the wrong waiter (a silent accept across prompts).
|
||||
func TestConcurrentRequests(t *testing.T) {
|
||||
pub := &fakePublisher{subscribers: true}
|
||||
b := New(pub)
|
||||
|
||||
const n = 20
|
||||
results := make(chan error, n)
|
||||
for i := 0; i < n; i++ {
|
||||
go func() {
|
||||
results <- requestErr(b, context.Background(), Prompt{Kind: KindVNC})
|
||||
}()
|
||||
}
|
||||
|
||||
ids := waitForNRequestIDs(t, pub, n)
|
||||
require.Len(t, ids, n)
|
||||
|
||||
// Deny exactly half, accept the rest. Track outcome per id so we can
|
||||
// match each Request's return value against the response we sent.
|
||||
denySet := make(map[string]bool, n)
|
||||
for i, id := range ids {
|
||||
deny := i%2 == 0
|
||||
denySet[id] = deny
|
||||
require.True(t, b.Respond(id, Decision{Accept: !deny}))
|
||||
}
|
||||
|
||||
// Collect all returns and check no nil errors slipped past a deny.
|
||||
var accepted, denied atomic.Int32
|
||||
for i := 0; i < n; i++ {
|
||||
select {
|
||||
case err := <-results:
|
||||
if err == nil {
|
||||
accepted.Add(1)
|
||||
} else {
|
||||
assert.ErrorIs(t, err, ErrDenied)
|
||||
denied.Add(1)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatalf("only got %d/%d responses", i, n)
|
||||
}
|
||||
}
|
||||
assert.Equal(t, int32(n/2), denied.Load())
|
||||
assert.Equal(t, int32(n/2), accepted.Load())
|
||||
}
|
||||
|
||||
// waitForRequestID blocks until the publisher sees its next event and
|
||||
// returns the request_id stamped on it.
|
||||
func waitForRequestID(t *testing.T, pub *fakePublisher) string {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
pub.mu.Lock()
|
||||
count := len(pub.events)
|
||||
var id string
|
||||
if count > 0 {
|
||||
id = pub.events[count-1].Metadata[MetaRequestID]
|
||||
}
|
||||
pub.mu.Unlock()
|
||||
if id != "" {
|
||||
return id
|
||||
}
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
}
|
||||
t.Fatal("timeout waiting for emitted event")
|
||||
return ""
|
||||
}
|
||||
|
||||
func waitForNRequestIDs(t *testing.T, pub *fakePublisher, n int) []string {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
pub.mu.Lock()
|
||||
count := len(pub.events)
|
||||
pub.mu.Unlock()
|
||||
if count >= n {
|
||||
break
|
||||
}
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
}
|
||||
pub.mu.Lock()
|
||||
defer pub.mu.Unlock()
|
||||
out := make([]string, 0, len(pub.events))
|
||||
seen := make(map[string]struct{}, len(pub.events))
|
||||
for _, ev := range pub.events {
|
||||
id := ev.Metadata[MetaRequestID]
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
if _, dup := seen[id]; dup {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
out = append(out, id)
|
||||
}
|
||||
if len(out) < n {
|
||||
t.Fatalf("only got %d/%d request ids", len(out), n)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// defaultTimeout swaps the broker's per-request wall-clock window so the
|
||||
// timeout tests run quickly. Restores the prior value on the next call.
|
||||
func defaultTimeout(t *testing.T, d time.Duration) {
|
||||
t.Helper()
|
||||
if d <= 0 {
|
||||
t.Fatal("defaultTimeout must be > 0")
|
||||
}
|
||||
timeoutValue = func() time.Duration { return d }
|
||||
}
|
||||
|
||||
// requestErr wraps Broker.Request to drop the Decision when tests only
|
||||
// care about the error path. Keeps the goroutine bodies tight.
|
||||
func requestErr(b *Broker, ctx context.Context, p Prompt) error {
|
||||
_, err := b.Request(ctx, p)
|
||||
return err
|
||||
}
|
||||
|
||||
// TestRequestViewOnly checks the view-only outcome flows through Request's
|
||||
// Decision return without being silently swallowed.
|
||||
func TestRequestViewOnly(t *testing.T) {
|
||||
pub := &fakePublisher{subscribers: true}
|
||||
b := New(pub)
|
||||
|
||||
type result struct {
|
||||
d Decision
|
||||
err error
|
||||
}
|
||||
done := make(chan result, 1)
|
||||
go func() {
|
||||
d, err := b.Request(context.Background(), Prompt{Kind: KindVNC})
|
||||
done <- result{d, err}
|
||||
}()
|
||||
|
||||
id := waitForRequestID(t, pub)
|
||||
require.True(t, b.Respond(id, Decision{Accept: true, ViewOnly: true}))
|
||||
|
||||
select {
|
||||
case r := <-done:
|
||||
assert.NoError(t, r.err)
|
||||
assert.True(t, r.d.Accept)
|
||||
assert.True(t, r.d.ViewOnly, "ViewOnly must survive the round-trip")
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("view-only request did not resolve")
|
||||
}
|
||||
}
|
||||
@@ -1,62 +0,0 @@
|
||||
package approval
|
||||
|
||||
import "testing"
|
||||
|
||||
// TestShortKeyFingerprint locks in the format the VNC approval prompt
|
||||
// shows to the user. The fingerprint is the user's only cryptographic
|
||||
// anchor against a malicious management server that pushes a spoofed
|
||||
// display name, so accidental changes to its format would silently
|
||||
// undermine that defence.
|
||||
func TestShortKeyFingerprint(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "full_32_byte_pubkey",
|
||||
in: "0123456789abcdeffedcba9876543210ffeeddccbbaa99887766554433221100",
|
||||
want: "0123-4567-89ab-cdef",
|
||||
},
|
||||
{
|
||||
name: "exactly_16_chars",
|
||||
in: "0123456789abcdef",
|
||||
want: "0123-4567-89ab-cdef",
|
||||
},
|
||||
{
|
||||
name: "borderline_8_chars",
|
||||
in: "01234567",
|
||||
want: "0123-4567",
|
||||
},
|
||||
{
|
||||
name: "too_short_returns_empty",
|
||||
in: "0123",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "empty_returns_empty",
|
||||
in: "",
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := ShortKeyFingerprint(tc.in)
|
||||
if got != tc.want {
|
||||
t.Fatalf("ShortKeyFingerprint(%q) = %q, want %q", tc.in, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestShortKeyFingerprint_DistinctKeysDistinctOutputs guards against a
|
||||
// formatting bug that would collapse different prefixes onto the same
|
||||
// displayed fingerprint and let an attacker substitute their pubkey for
|
||||
// a victim's while keeping the prompt visually identical.
|
||||
func TestShortKeyFingerprint_DistinctKeysDistinctOutputs(t *testing.T) {
|
||||
a := ShortKeyFingerprint("0123456789abcdef" + "rest_of_pubkey_ignored")
|
||||
b := ShortKeyFingerprint("0123456789abcde0" + "rest_of_pubkey_ignored")
|
||||
if a == b {
|
||||
t.Fatalf("expected distinct outputs for distinct prefixes, both = %q", a)
|
||||
}
|
||||
}
|
||||
@@ -315,7 +315,6 @@ func (a *Auth) setSystemInfoFlags(info *system.Info) {
|
||||
a.config.RosenpassEnabled,
|
||||
a.config.RosenpassPermissive,
|
||||
a.config.ServerSSHAllowed,
|
||||
a.config.ServerVNCAllowed,
|
||||
a.config.DisableClientRoutes,
|
||||
a.config.DisableServerRoutes,
|
||||
a.config.DisableDNS,
|
||||
|
||||
@@ -568,8 +568,6 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
|
||||
RosenpassEnabled: config.RosenpassEnabled,
|
||||
RosenpassPermissive: config.RosenpassPermissive,
|
||||
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
|
||||
ServerVNCAllowed: config.ServerVNCAllowed != nil && *config.ServerVNCAllowed,
|
||||
DisableVNCApproval: config.DisableVNCApproval,
|
||||
EnableSSHRoot: config.EnableSSHRoot,
|
||||
EnableSSHSFTP: config.EnableSSHSFTP,
|
||||
EnableSSHLocalPortForwarding: config.EnableSSHLocalPortForwarding,
|
||||
@@ -652,7 +650,6 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
|
||||
config.RosenpassEnabled,
|
||||
config.RosenpassPermissive,
|
||||
config.ServerSSHAllowed,
|
||||
config.ServerVNCAllowed,
|
||||
config.DisableClientRoutes,
|
||||
config.DisableServerRoutes,
|
||||
config.DisableDNS,
|
||||
|
||||
@@ -516,14 +516,6 @@ func (g *BundleGenerator) addConfig() error {
|
||||
}
|
||||
}
|
||||
|
||||
// Surface the set of MDM-enforced keys so a support engineer reading
|
||||
// the bundle can tell which field values are user-set vs MDM-overridden.
|
||||
// Same semantics as the mDMManagedFields list returned by the
|
||||
// GetConfig RPC consumed by `netbird debug config`.
|
||||
if managed := g.internalConfig.Policy().ManagedKeys(); len(managed) > 0 {
|
||||
configContent.WriteString(fmt.Sprintf("MDMManagedFields: %v\n", managed))
|
||||
}
|
||||
|
||||
configReader := strings.NewReader(configContent.String())
|
||||
if err := g.addFileToZip(configReader, "config.txt"); err != nil {
|
||||
return fmt.Errorf("add config file to zip: %w", err)
|
||||
@@ -652,12 +644,6 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
|
||||
if g.internalConfig.SSHJWTCacheTTL != nil {
|
||||
configContent.WriteString(fmt.Sprintf("SSHJWTCacheTTL: %d\n", *g.internalConfig.SSHJWTCacheTTL))
|
||||
}
|
||||
if g.internalConfig.ServerVNCAllowed != nil {
|
||||
configContent.WriteString(fmt.Sprintf("ServerVNCAllowed: %v\n", *g.internalConfig.ServerVNCAllowed))
|
||||
}
|
||||
if g.internalConfig.DisableVNCApproval != nil {
|
||||
configContent.WriteString(fmt.Sprintf("DisableVNCApproval: %v\n", *g.internalConfig.DisableVNCApproval))
|
||||
}
|
||||
|
||||
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
|
||||
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))
|
||||
|
||||
@@ -843,7 +843,6 @@ func TestAddConfig_AllFieldsCovered(t *testing.T) {
|
||||
"PreSharedKey": "sensitive: WireGuard pre-shared key",
|
||||
"SSHKey": "sensitive: SSH private key",
|
||||
"ClientCertKeyPair": "non-config: parsed cert pair, not serialized",
|
||||
"policy": "non-config: in-memory MDM policy snapshot, surfaced via Config.Policy() / GetConfigResponse.MDMManagedFields",
|
||||
}
|
||||
|
||||
mURL, _ := url.Parse("https://api.example.com:443")
|
||||
@@ -863,8 +862,6 @@ func TestAddConfig_AllFieldsCovered(t *testing.T) {
|
||||
RosenpassEnabled: true,
|
||||
RosenpassPermissive: true,
|
||||
ServerSSHAllowed: &bTrue,
|
||||
ServerVNCAllowed: &bTrue,
|
||||
DisableVNCApproval: &bTrue,
|
||||
EnableSSHRoot: &bTrue,
|
||||
EnableSSHSFTP: &bTrue,
|
||||
EnableSSHLocalPortForwarding: &bTrue,
|
||||
|
||||
@@ -482,7 +482,7 @@ func (d *Resolver) logDNSError(logger *log.Entry, hostname string, qtype uint16,
|
||||
// completely when every proxy peer is offline (the upstream may still
|
||||
// be reachable some other way, or the peerstore may be stale).
|
||||
func (d *Resolver) filterDisconnectedPeerAnswers(logger *log.Entry, question dns.Question, records []dns.RR) []dns.RR {
|
||||
if len(records) < 2 {
|
||||
if len(records) == 0 {
|
||||
return records
|
||||
}
|
||||
d.mu.RLock()
|
||||
|
||||
@@ -2738,17 +2738,6 @@ func TestLocalResolver_FilterDisconnectedPeerAnswers(t *testing.T) {
|
||||
connByIP: nil,
|
||||
wantInOrder: []string{"100.64.0.10", "100.64.0.11"},
|
||||
},
|
||||
{
|
||||
// A single answer is never filtered: dropping it would only
|
||||
// trigger the empty-answer escape hatch, so the fast path
|
||||
// returns it untouched.
|
||||
name: "single disconnected answer passes through",
|
||||
records: []nbdns.SimpleRecord{disconnectedRec},
|
||||
connByIP: map[string]ipState{
|
||||
"100.64.0.11": {known: true, connected: false},
|
||||
},
|
||||
wantInOrder: []string{"100.64.0.11"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
|
||||
@@ -14,10 +14,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// errNoSuitableAddress mirrors the unexported error string the net package
|
||||
// uses when a resolved host has no addresses of the requested family.
|
||||
const errNoSuitableAddress = "no suitable address found"
|
||||
|
||||
// GenerateRequestID creates a random 8-character hex string for request tracing.
|
||||
func GenerateRequestID() string {
|
||||
bytes := make([]byte, 4)
|
||||
@@ -130,14 +126,6 @@ func LookupIP(ctx context.Context, r resolver, network, host string, qtype uint1
|
||||
}
|
||||
|
||||
func getRcodeForError(ctx context.Context, r resolver, host string, qtype uint16, err error) int {
|
||||
// The net package returns this AddrError when the host resolves but has
|
||||
// no addresses of the requested family. The domain exists, so answer
|
||||
// NODATA instead of SERVFAIL.
|
||||
var addrErr *net.AddrError
|
||||
if errors.As(err, &addrErr) && addrErr.Err == errNoSuitableAddress {
|
||||
return dns.RcodeSuccess
|
||||
}
|
||||
|
||||
var dnsErr *net.DNSError
|
||||
if !errors.As(err, &dnsErr) {
|
||||
return dns.RcodeServerFailure
|
||||
|
||||
@@ -1,122 +0,0 @@
|
||||
package resutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type mockResolver struct {
|
||||
// results maps network ("ip4"/"ip6") to the lookup outcome.
|
||||
results map[string]mockLookup
|
||||
}
|
||||
|
||||
type mockLookup struct {
|
||||
ips []netip.Addr
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockResolver) LookupNetIP(_ context.Context, network, _ string) ([]netip.Addr, error) {
|
||||
res, ok := m.results[network]
|
||||
if !ok {
|
||||
return nil, errors.New("unexpected network: " + network)
|
||||
}
|
||||
return res.ips, res.err
|
||||
}
|
||||
|
||||
func TestLookupIP_Success(t *testing.T) {
|
||||
r := &mockResolver{results: map[string]mockLookup{
|
||||
"ip4": {ips: []netip.Addr{netip.MustParseAddr("::ffff:192.0.2.1")}},
|
||||
}}
|
||||
|
||||
result := LookupIP(context.Background(), r, "ip4", "example.com.", dns.TypeA)
|
||||
|
||||
assert.Equal(t, dns.RcodeSuccess, result.Rcode, "successful lookup should return NOERROR")
|
||||
require.Len(t, result.IPs, 1, "should return the resolved address")
|
||||
assert.Equal(t, netip.MustParseAddr("192.0.2.1"), result.IPs[0], "v4-mapped address should be unmapped")
|
||||
}
|
||||
|
||||
func TestLookupIP_NoSuitableAddress(t *testing.T) {
|
||||
// The net package returns this AddrError when the host resolves but has
|
||||
// no addresses of the requested family (e.g. AAAA query for a v4-only
|
||||
// hosts file entry). The domain exists, so this is NODATA, not SERVFAIL.
|
||||
r := &mockResolver{results: map[string]mockLookup{
|
||||
"ip6": {err: &net.AddrError{Err: "no suitable address found", Addr: "example.com."}},
|
||||
}}
|
||||
|
||||
result := LookupIP(context.Background(), r, "ip6", "example.com.", dns.TypeAAAA)
|
||||
|
||||
assert.Equal(t, dns.RcodeSuccess, result.Rcode, "no suitable address should map to NODATA")
|
||||
assert.Empty(t, result.IPs, "NODATA response should carry no addresses")
|
||||
}
|
||||
|
||||
// TestErrNoSuitableAddressMatchesNetPackage pins our copy of the error string
|
||||
// to what the net package actually emits. A literal IP of the wrong family
|
||||
// takes the same filterAddrList path as a resolved hostname, without network
|
||||
// access.
|
||||
func TestErrNoSuitableAddressMatchesNetPackage(t *testing.T) {
|
||||
_, err := (&net.Resolver{}).LookupNetIP(context.Background(), "ip6", "192.0.2.1")
|
||||
require.Error(t, err)
|
||||
|
||||
var addrErr *net.AddrError
|
||||
require.ErrorAs(t, err, &addrErr, "wrong-family lookup should return AddrError")
|
||||
assert.Equal(t, errNoSuitableAddress, addrErr.Err, "net package error string should match our constant")
|
||||
}
|
||||
|
||||
func TestLookupIP_OtherAddrError(t *testing.T) {
|
||||
r := &mockResolver{results: map[string]mockLookup{
|
||||
"ip4": {err: &net.AddrError{Err: "some other address problem", Addr: "example.com."}},
|
||||
}}
|
||||
|
||||
result := LookupIP(context.Background(), r, "ip4", "example.com.", dns.TypeA)
|
||||
|
||||
assert.Equal(t, dns.RcodeServerFailure, result.Rcode, "unrecognized AddrError should map to SERVFAIL")
|
||||
}
|
||||
|
||||
func TestLookupIP_NotFoundNXDomain(t *testing.T) {
|
||||
r := &mockResolver{results: map[string]mockLookup{
|
||||
"ip4": {err: &net.DNSError{Err: "no such host", Name: "example.com.", IsNotFound: true}},
|
||||
"ip6": {err: &net.DNSError{Err: "no such host", Name: "example.com.", IsNotFound: true}},
|
||||
}}
|
||||
|
||||
result := LookupIP(context.Background(), r, "ip4", "example.com.", dns.TypeA)
|
||||
|
||||
assert.Equal(t, dns.RcodeNameError, result.Rcode, "not found for both families should map to NXDOMAIN")
|
||||
}
|
||||
|
||||
func TestLookupIP_NotFoundNoData(t *testing.T) {
|
||||
r := &mockResolver{results: map[string]mockLookup{
|
||||
"ip6": {err: &net.DNSError{Err: "no such host", Name: "example.com.", IsNotFound: true}},
|
||||
"ip4": {ips: []netip.Addr{netip.MustParseAddr("192.0.2.1")}},
|
||||
}}
|
||||
|
||||
result := LookupIP(context.Background(), r, "ip6", "example.com.", dns.TypeAAAA)
|
||||
|
||||
assert.Equal(t, dns.RcodeSuccess, result.Rcode, "not found with the other family present should map to NODATA")
|
||||
}
|
||||
|
||||
func TestLookupIP_GenericError(t *testing.T) {
|
||||
r := &mockResolver{results: map[string]mockLookup{
|
||||
"ip4": {err: errors.New("connection refused")},
|
||||
}}
|
||||
|
||||
result := LookupIP(context.Background(), r, "ip4", "example.com.", dns.TypeA)
|
||||
|
||||
assert.Equal(t, dns.RcodeServerFailure, result.Rcode, "generic error should map to SERVFAIL")
|
||||
}
|
||||
|
||||
func TestLookupIP_DNSErrorNotIsNotFound(t *testing.T) {
|
||||
r := &mockResolver{results: map[string]mockLookup{
|
||||
"ip4": {err: &net.DNSError{Err: "server misbehaving", Name: "example.com.", IsTemporary: true}},
|
||||
}}
|
||||
|
||||
result := LookupIP(context.Background(), r, "ip4", "example.com.", dns.TypeA)
|
||||
|
||||
assert.Equal(t, dns.RcodeServerFailure, result.Rcode, "upstream failure should map to SERVFAIL")
|
||||
}
|
||||
@@ -777,24 +777,13 @@ func (s *DefaultServer) applyHostConfig() {
|
||||
// context is released rather than leaked until GC.
|
||||
func (s *DefaultServer) registerFallback() {
|
||||
originalNameservers := s.hostManager.getOriginalNameservers()
|
||||
|
||||
serverIP := s.service.RuntimeIP()
|
||||
var servers []netip.AddrPort
|
||||
for _, ns := range originalNameservers {
|
||||
if ns == serverIP {
|
||||
log.Debugf("skipping original nameserver %s as it is the same as the server IP %s", ns, serverIP)
|
||||
continue
|
||||
}
|
||||
servers = append(servers, netip.AddrPortFrom(ns, DefaultPort))
|
||||
}
|
||||
|
||||
if len(servers) == 0 {
|
||||
if len(originalNameservers) == 0 {
|
||||
log.Debugf("no fallback upstreams to register; clearing PriorityFallback handler")
|
||||
s.clearFallback()
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("registering original nameservers %v as upstream handlers with priority %d", servers, PriorityFallback)
|
||||
log.Infof("registering original nameservers %v as upstream handlers with priority %d", originalNameservers, PriorityFallback)
|
||||
|
||||
handler, err := newUpstreamResolver(
|
||||
s.ctx,
|
||||
@@ -808,6 +797,11 @@ func (s *DefaultServer) registerFallback() {
|
||||
return
|
||||
}
|
||||
handler.selectedRoutes = s.selectedRoutes
|
||||
|
||||
var servers []netip.AddrPort
|
||||
for _, ns := range originalNameservers {
|
||||
servers = append(servers, netip.AddrPortFrom(ns, DefaultPort))
|
||||
}
|
||||
handler.addRace(servers)
|
||||
|
||||
prev := s.fallbackHandler
|
||||
|
||||
@@ -34,7 +34,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/acl"
|
||||
"github.com/netbirdio/netbird/client/internal/approval"
|
||||
"github.com/netbirdio/netbird/client/internal/debug"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||
@@ -125,8 +124,6 @@ type EngineConfig struct {
|
||||
RosenpassPermissive bool
|
||||
|
||||
ServerSSHAllowed bool
|
||||
ServerVNCAllowed bool
|
||||
DisableVNCApproval *bool
|
||||
EnableSSHRoot *bool
|
||||
EnableSSHSFTP *bool
|
||||
EnableSSHLocalPortForwarding *bool
|
||||
@@ -212,9 +209,7 @@ type Engine struct {
|
||||
|
||||
networkMonitor *networkmonitor.NetworkMonitor
|
||||
|
||||
sshServer sshServer
|
||||
vncSrv vncServer
|
||||
approvalBroker *approval.Broker
|
||||
sshServer sshServer
|
||||
|
||||
statusRecorder *peer.Status
|
||||
|
||||
@@ -300,7 +295,6 @@ func NewEngine(
|
||||
TURNs: []*stun.URI{},
|
||||
networkSerial: 0,
|
||||
statusRecorder: services.StatusRecorder,
|
||||
approvalBroker: approval.New(services.StatusRecorder),
|
||||
stateManager: services.StateManager,
|
||||
portForwardManager: portforward.NewManager(),
|
||||
checks: services.Checks,
|
||||
@@ -337,10 +331,6 @@ func (e *Engine) Stop() error {
|
||||
log.Warnf("failed to stop SSH server: %v", err)
|
||||
}
|
||||
|
||||
if err := e.stopVNCServer(); err != nil {
|
||||
log.Warnf("failed to stop VNC server: %v", err)
|
||||
}
|
||||
|
||||
e.cleanupSSHConfig()
|
||||
|
||||
if e.ingressGatewayMgr != nil {
|
||||
@@ -890,25 +880,62 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate)
|
||||
}
|
||||
|
||||
if err := e.updateNetbirdConfig(update.GetNetbirdConfig()); err != nil {
|
||||
return err
|
||||
}
|
||||
if update.GetNetbirdConfig() != nil {
|
||||
wCfg := update.GetNetbirdConfig()
|
||||
err := e.updateTURNs(wCfg.GetTurns())
|
||||
if err != nil {
|
||||
return fmt.Errorf("update TURNs: %w", err)
|
||||
}
|
||||
|
||||
// Posture checks are bound to the network map presence:
|
||||
// NetworkMap != nil, checks present -> apply the received checks
|
||||
// NetworkMap != nil, checks nil -> posture checks were removed, clear them
|
||||
// NetworkMap == nil -> config-only update (e.g. relay token rotation),
|
||||
// leave the previously applied checks untouched
|
||||
nm := update.GetNetworkMap()
|
||||
if nm == nil {
|
||||
return nil
|
||||
err = e.updateSTUNs(wCfg.GetStuns())
|
||||
if err != nil {
|
||||
return fmt.Errorf("update STUNs: %w", err)
|
||||
}
|
||||
|
||||
var stunTurn []*stun.URI
|
||||
stunTurn = append(stunTurn, e.STUNs...)
|
||||
stunTurn = append(stunTurn, e.TURNs...)
|
||||
e.stunTurn.Store(stunTurn)
|
||||
|
||||
err = e.handleRelayUpdate(wCfg.GetRelay())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = e.handleFlowUpdate(wCfg.GetFlow())
|
||||
if err != nil {
|
||||
return fmt.Errorf("handle the flow configuration: %w", err)
|
||||
}
|
||||
|
||||
if err := e.PopulateNetbirdConfig(wCfg, nil); err != nil {
|
||||
log.Warnf("Failed to update DNS server config: %v", err)
|
||||
}
|
||||
|
||||
// todo update signal
|
||||
}
|
||||
|
||||
if err := e.updateChecksIfNew(update.Checks); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
e.persistSyncResponse(update)
|
||||
nm := update.GetNetworkMap()
|
||||
if nm == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Persist sync response under the dedicated lock (syncRespMux), not under syncMsgMux.
|
||||
// A non-nil syncStore is what marks persistence as enabled. Hold the lock for
|
||||
// the whole Set so the store cannot be cleared (disabled / engine close)
|
||||
// mid-call and have this write resurrect a file that was just removed.
|
||||
e.syncRespMux.RLock()
|
||||
if e.syncStore != nil {
|
||||
if err := e.syncStore.Set(update); err != nil {
|
||||
log.Errorf("failed to persist sync response: %v", err)
|
||||
} else {
|
||||
log.Debugf("sync response persisted with serial %d", nm.GetSerial())
|
||||
}
|
||||
}
|
||||
e.syncRespMux.RUnlock()
|
||||
|
||||
// only apply new changes and ignore old ones
|
||||
if err := e.updateNetworkMap(nm); err != nil {
|
||||
@@ -920,64 +947,6 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateNetbirdConfig applies the management-provided NetBird configuration:
|
||||
// STUN/TURN and relay servers, flow logging and DNS settings. A nil config is a no-op,
|
||||
// which is the case for sync updates carrying only a network map.
|
||||
func (e *Engine) updateNetbirdConfig(wCfg *mgmProto.NetbirdConfig) error {
|
||||
if wCfg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := e.updateTURNs(wCfg.GetTurns()); err != nil {
|
||||
return fmt.Errorf("update TURNs: %w", err)
|
||||
}
|
||||
|
||||
if err := e.updateSTUNs(wCfg.GetStuns()); err != nil {
|
||||
return fmt.Errorf("update STUNs: %w", err)
|
||||
}
|
||||
|
||||
var stunTurn []*stun.URI
|
||||
stunTurn = append(stunTurn, e.STUNs...)
|
||||
stunTurn = append(stunTurn, e.TURNs...)
|
||||
e.stunTurn.Store(stunTurn)
|
||||
|
||||
if err := e.handleRelayUpdate(wCfg.GetRelay()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := e.handleFlowUpdate(wCfg.GetFlow()); err != nil {
|
||||
return fmt.Errorf("handle the flow configuration: %w", err)
|
||||
}
|
||||
|
||||
if err := e.PopulateNetbirdConfig(wCfg, nil); err != nil {
|
||||
log.Warnf("Failed to update DNS server config: %v", err)
|
||||
}
|
||||
|
||||
// todo update signal
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// persistSyncResponse stores the full sync response so it can be restored on the next
|
||||
// startup. Persistence is enabled only when syncStore is set. The dedicated syncRespMux
|
||||
// (not syncMsgMux) is held for the whole Set so the store cannot be cleared (disabled /
|
||||
// engine close) mid-call and have this write resurrect a file that was just removed.
|
||||
func (e *Engine) persistSyncResponse(update *mgmProto.SyncResponse) {
|
||||
e.syncRespMux.RLock()
|
||||
defer e.syncRespMux.RUnlock()
|
||||
|
||||
if e.syncStore == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := e.syncStore.Set(update); err != nil {
|
||||
log.Errorf("failed to persist sync response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("sync response persisted with serial %d", update.GetNetworkMap().GetSerial())
|
||||
}
|
||||
|
||||
func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error {
|
||||
if update != nil {
|
||||
// when we receive token we expect valid address list too
|
||||
@@ -1051,7 +1020,6 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
||||
e.config.RosenpassEnabled,
|
||||
e.config.RosenpassPermissive,
|
||||
&e.config.ServerSSHAllowed,
|
||||
&e.config.ServerVNCAllowed,
|
||||
e.config.DisableClientRoutes,
|
||||
e.config.DisableServerRoutes,
|
||||
e.config.DisableDNS,
|
||||
@@ -1099,10 +1067,6 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
}
|
||||
}
|
||||
|
||||
if err := e.updateVNC(); err != nil {
|
||||
log.Warnf("failed handling VNC server setup: %v", err)
|
||||
}
|
||||
|
||||
state := e.statusRecorder.GetLocalPeerState()
|
||||
state.IP = e.wgInterface.Address().String()
|
||||
state.IPv6 = e.wgInterface.Address().IPv6String()
|
||||
@@ -1230,7 +1194,6 @@ func (e *Engine) receiveManagementEvents() {
|
||||
e.config.RosenpassEnabled,
|
||||
e.config.RosenpassPermissive,
|
||||
&e.config.ServerSSHAllowed,
|
||||
&e.config.ServerVNCAllowed,
|
||||
e.config.DisableClientRoutes,
|
||||
e.config.DisableServerRoutes,
|
||||
e.config.DisableDNS,
|
||||
@@ -1420,11 +1383,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
e.updateSSHServerAuth(networkMap.GetSshAuth())
|
||||
}
|
||||
|
||||
// VNC auth: always sync, including nil so cleared auth on the management
|
||||
// side is applied locally, and so it isn't skipped on the RemotePeersIsEmpty
|
||||
// cleanup path.
|
||||
e.updateVNCServerAuth(networkMap.GetVncAuth())
|
||||
|
||||
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
|
||||
excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, remotePeers)
|
||||
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
|
||||
@@ -1892,7 +1850,6 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
|
||||
e.config.RosenpassEnabled,
|
||||
e.config.RosenpassPermissive,
|
||||
&e.config.ServerSSHAllowed,
|
||||
&e.config.ServerVNCAllowed,
|
||||
e.config.DisableClientRoutes,
|
||||
e.config.DisableServerRoutes,
|
||||
e.config.DisableDNS,
|
||||
@@ -2677,16 +2634,3 @@ func decodeRelayIP(b []byte) netip.Addr {
|
||||
}
|
||||
return ip.Unmap()
|
||||
}
|
||||
|
||||
// RespondApproval relays the user's decision for a pending approval to
|
||||
// the broker. viewOnly is honoured only when accept is true. Returns
|
||||
// true when the request_id matched a live prompt.
|
||||
func (e *Engine) RespondApproval(requestID string, accept, viewOnly bool) bool {
|
||||
if e == nil || e.approvalBroker == nil {
|
||||
return false
|
||||
}
|
||||
return e.approvalBroker.Respond(requestID, approval.Decision{
|
||||
Accept: accept,
|
||||
ViewOnly: accept && viewOnly,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -12,10 +12,10 @@ import (
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
||||
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
sshauth "github.com/netbirdio/netbird/shared/sessionauth"
|
||||
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
|
||||
)
|
||||
|
||||
@@ -237,18 +237,22 @@ func (e *Engine) startSSHServer(jwtConfig *sshserver.JWTConfig) error {
|
||||
return errors.New("wg interface not initialized")
|
||||
}
|
||||
|
||||
wgAddr := e.wgInterface.Address()
|
||||
serverConfig := &sshserver.Config{
|
||||
HostKeyPEM: e.config.SSHKey,
|
||||
JWT: jwtConfig,
|
||||
NetstackNet: e.wgInterface.GetNet(),
|
||||
NetworkValidation: wgAddr,
|
||||
HostKeyPEM: e.config.SSHKey,
|
||||
JWT: jwtConfig,
|
||||
}
|
||||
server := sshserver.New(serverConfig)
|
||||
|
||||
wgAddr := e.wgInterface.Address()
|
||||
server.SetNetworkValidation(wgAddr)
|
||||
|
||||
netbirdIP := wgAddr.IP
|
||||
listenAddr := netip.AddrPortFrom(netbirdIP, sshserver.InternalSSHPort)
|
||||
|
||||
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||
server.SetNetstackNet(netstackNet)
|
||||
}
|
||||
|
||||
e.configureSSHServer(server)
|
||||
|
||||
if err := server.Start(e.ctx, listenAddr); err != nil {
|
||||
|
||||
@@ -1,302 +0,0 @@
|
||||
//go:build !js && !ios && !android
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/approval"
|
||||
"github.com/netbirdio/netbird/client/internal/metrics"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/vnc"
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
sshauth "github.com/netbirdio/netbird/shared/sessionauth"
|
||||
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
|
||||
)
|
||||
|
||||
type vncServer interface {
|
||||
Start(ctx context.Context, addr netip.AddrPort, network netip.Prefix) error
|
||||
Stop() error
|
||||
ActiveSessions() []vncserver.ActiveSessionInfo
|
||||
}
|
||||
|
||||
func (e *Engine) setupVNCPortRedirection() error {
|
||||
if e.firewall == nil || e.wgInterface == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
localAddr := e.wgInterface.Address().IP
|
||||
if !localAddr.IsValid() {
|
||||
return errors.New("invalid local NetBird address")
|
||||
}
|
||||
|
||||
if err := e.firewall.AddInboundDNAT(localAddr, firewallManager.ProtocolTCP, vnc.ExternalPort, vnc.InternalPort); err != nil {
|
||||
return fmt.Errorf("add VNC port redirection: %w", err)
|
||||
}
|
||||
log.Infof("VNC port redirection: %s:%d -> %s:%d", localAddr, vnc.ExternalPort, localAddr, vnc.InternalPort)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) cleanupVNCPortRedirection() error {
|
||||
if e.firewall == nil || e.wgInterface == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
localAddr := e.wgInterface.Address().IP
|
||||
if !localAddr.IsValid() {
|
||||
return errors.New("invalid local NetBird address")
|
||||
}
|
||||
|
||||
if err := e.firewall.RemoveInboundDNAT(localAddr, firewallManager.ProtocolTCP, vnc.ExternalPort, vnc.InternalPort); err != nil {
|
||||
return fmt.Errorf("remove VNC port redirection: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateVNC handles starting/stopping the VNC server based on the config flag.
|
||||
func (e *Engine) updateVNC() error {
|
||||
if !e.config.ServerVNCAllowed {
|
||||
if e.vncSrv != nil {
|
||||
log.Info("VNC server disabled, stopping")
|
||||
}
|
||||
return e.stopVNCServer()
|
||||
}
|
||||
|
||||
if e.config.BlockInbound {
|
||||
log.Info("VNC server disabled because inbound connections are blocked")
|
||||
return e.stopVNCServer()
|
||||
}
|
||||
|
||||
if e.vncSrv != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return e.startVNCServer()
|
||||
}
|
||||
|
||||
func (e *Engine) startVNCServer() error {
|
||||
if e.wgInterface == nil {
|
||||
return errors.New("wg interface not initialized")
|
||||
}
|
||||
|
||||
capturer, injector, ok := newPlatformVNC()
|
||||
if !ok {
|
||||
log.Debug("VNC server not supported on this platform")
|
||||
return nil
|
||||
}
|
||||
|
||||
netbirdIP := e.wgInterface.Address().IP
|
||||
|
||||
var sessionRecorder func(vncserver.SessionTick)
|
||||
if e.clientMetrics != nil {
|
||||
sessionRecorder = func(t vncserver.SessionTick) {
|
||||
e.clientMetrics.RecordVNCSessionTick(e.ctx, metrics.VNCSessionTick{
|
||||
Period: t.Period,
|
||||
BytesOut: t.BytesOut,
|
||||
Writes: t.Writes,
|
||||
FBUs: t.FBUs,
|
||||
MaxFBUBytes: t.MaxFBUBytes,
|
||||
MaxFBURects: t.MaxFBURects,
|
||||
MaxWriteBytes: t.MaxWriteBytes,
|
||||
WriteNanos: t.WriteNanos,
|
||||
})
|
||||
}
|
||||
}
|
||||
serviceMode := vncNeedsServiceMode()
|
||||
if serviceMode {
|
||||
log.Info("VNC: running as system service, enabling service mode (per-session agent proxy)")
|
||||
}
|
||||
requireApproval := e.config.DisableVNCApproval == nil || !*e.config.DisableVNCApproval
|
||||
srv := vncserver.New(vncserver.Config{
|
||||
Capturer: capturer,
|
||||
Injector: injector,
|
||||
IdentityKey: e.config.WgPrivateKey[:],
|
||||
ServiceMode: serviceMode,
|
||||
SessionRecorder: sessionRecorder,
|
||||
NetstackNet: e.wgInterface.GetNet(),
|
||||
RequireApproval: requireApproval,
|
||||
Approver: &vncApprover{broker: e.approvalBroker, statusRecorder: e.statusRecorder},
|
||||
})
|
||||
|
||||
listenAddr := netip.AddrPortFrom(netbirdIP, vnc.InternalPort)
|
||||
network := e.wgInterface.Address().Network
|
||||
if err := srv.Start(e.ctx, listenAddr, network); err != nil {
|
||||
return fmt.Errorf("start VNC server: %w", err)
|
||||
}
|
||||
|
||||
e.vncSrv = srv
|
||||
|
||||
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
|
||||
if registrar, ok := e.firewall.(interface {
|
||||
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||
}); ok {
|
||||
registrar.RegisterNetstackService(nftypes.TCP, vnc.InternalPort)
|
||||
log.Debugf("registered VNC service with netstack for TCP:%d", vnc.InternalPort)
|
||||
}
|
||||
}
|
||||
|
||||
if err := e.setupVNCPortRedirection(); err != nil {
|
||||
log.Warnf("setup VNC port redirection: %v", err)
|
||||
}
|
||||
|
||||
log.Info("VNC server enabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateVNCServerAuth updates VNC fine-grained access control from management.
|
||||
// A nil vncAuth clears all authorized users and session pubkeys so management
|
||||
// can revoke access by omitting the field on the next sync.
|
||||
func (e *Engine) updateVNCServerAuth(vncAuth *mgmProto.VNCAuth) {
|
||||
if e.vncSrv == nil {
|
||||
return
|
||||
}
|
||||
|
||||
vncSrv, ok := e.vncSrv.(*vncserver.Server)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if vncAuth == nil {
|
||||
vncSrv.UpdateVNCAuth(&sshauth.Config{})
|
||||
return
|
||||
}
|
||||
|
||||
protoUsers := vncAuth.GetAuthorizedUsers()
|
||||
authorizedUsers := make([]sshuserhash.UserIDHash, len(protoUsers))
|
||||
for i, hash := range protoUsers {
|
||||
if len(hash) != 16 {
|
||||
log.Warnf("invalid VNC auth hash length %d, expected 16", len(hash))
|
||||
return
|
||||
}
|
||||
authorizedUsers[i] = sshuserhash.UserIDHash(hash)
|
||||
}
|
||||
|
||||
machineUsers := make(map[string][]uint32)
|
||||
for osUser, indexes := range vncAuth.GetMachineUsers() {
|
||||
machineUsers[osUser] = indexes.GetIndexes()
|
||||
}
|
||||
|
||||
sessionPubKeys := make([]sshauth.SessionPubKey, 0, len(vncAuth.GetSessionPubKeys()))
|
||||
for _, pk := range vncAuth.GetSessionPubKeys() {
|
||||
pub := pk.GetPubKey()
|
||||
if len(pub) != 32 {
|
||||
log.Warnf("VNC session pubkey wrong length %d", len(pub))
|
||||
continue
|
||||
}
|
||||
hash := pk.GetUserIdHash()
|
||||
if len(hash) != 16 {
|
||||
log.Warnf("VNC session user id hash wrong length %d", len(hash))
|
||||
continue
|
||||
}
|
||||
sessionPubKeys = append(sessionPubKeys, sshauth.SessionPubKey{
|
||||
PubKey: pub,
|
||||
UserIDHash: sshuserhash.UserIDHash(hash),
|
||||
DisplayName: pk.GetDisplayName(),
|
||||
})
|
||||
}
|
||||
|
||||
vncSrv.UpdateVNCAuth(&sshauth.Config{
|
||||
AuthorizedUsers: authorizedUsers,
|
||||
MachineUsers: machineUsers,
|
||||
SessionPubKeys: sessionPubKeys,
|
||||
})
|
||||
}
|
||||
|
||||
// GetVNCServerStatus returns whether the VNC server is running and the list
|
||||
// of active VNC sessions. The pointer is captured under syncMsgMux so a
|
||||
// concurrent updateVNC/stopVNCServer cannot swap it out between the nil
|
||||
// check and the ActiveSessions call.
|
||||
func (e *Engine) GetVNCServerStatus() (enabled bool, sessions []vncserver.ActiveSessionInfo) {
|
||||
e.syncMsgMux.Lock()
|
||||
vncSrv := e.vncSrv
|
||||
e.syncMsgMux.Unlock()
|
||||
if vncSrv == nil {
|
||||
return false, nil
|
||||
}
|
||||
return true, vncSrv.ActiveSessions()
|
||||
}
|
||||
|
||||
func (e *Engine) stopVNCServer() error {
|
||||
if e.vncSrv == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := e.cleanupVNCPortRedirection(); err != nil {
|
||||
log.Warnf("cleanup VNC port redirection: %v", err)
|
||||
}
|
||||
|
||||
if e.wgInterface != nil && e.wgInterface.GetNet() != nil {
|
||||
if registrar, ok := e.firewall.(interface {
|
||||
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||
}); ok {
|
||||
registrar.UnregisterNetstackService(nftypes.TCP, vnc.InternalPort)
|
||||
}
|
||||
}
|
||||
|
||||
log.Info("stopping VNC server")
|
||||
err := e.vncSrv.Stop()
|
||||
e.vncSrv = nil
|
||||
if err != nil {
|
||||
return fmt.Errorf("stop VNC server: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// vncApprover adapts the generic approval.Broker for the VNC server.
|
||||
type vncApprover struct {
|
||||
broker *approval.Broker
|
||||
statusRecorder *peer.Status
|
||||
}
|
||||
|
||||
func (a *vncApprover) Request(ctx context.Context, info vncserver.ApprovalInfo) (vncserver.ApprovalDecision, error) {
|
||||
// Resolve the source overlay IP to a peer FQDN for the prompt label.
|
||||
if info.PeerName == "" && info.SourceIP != "" && a.statusRecorder != nil {
|
||||
if fqdn, ok := a.statusRecorder.PeerByIP(info.SourceIP); ok {
|
||||
info.PeerName = fqdn
|
||||
}
|
||||
}
|
||||
subject := fmt.Sprintf("VNC connection from %s", displayPeer(info))
|
||||
meta := map[string]string{
|
||||
"peer_name": info.PeerName,
|
||||
"peer_pubkey": info.PeerPubKey,
|
||||
"source_ip": info.SourceIP,
|
||||
"mode": info.Mode,
|
||||
"username": info.Username,
|
||||
"initiator": info.Initiator,
|
||||
}
|
||||
d, err := a.broker.Request(ctx, approval.Prompt{
|
||||
Kind: approval.KindVNC,
|
||||
Subject: subject,
|
||||
Metadata: meta,
|
||||
})
|
||||
if err != nil {
|
||||
return vncserver.ApprovalDecision{}, err
|
||||
}
|
||||
return vncserver.ApprovalDecision{ViewOnly: d.ViewOnly}, nil
|
||||
}
|
||||
|
||||
func displayPeer(info vncserver.ApprovalInfo) string {
|
||||
if info.Initiator != "" {
|
||||
return info.Initiator
|
||||
}
|
||||
if info.PeerName != "" {
|
||||
return info.PeerName
|
||||
}
|
||||
if info.SourceIP != "" {
|
||||
return info.SourceIP
|
||||
}
|
||||
if info.PeerPubKey != "" {
|
||||
return info.PeerPubKey
|
||||
}
|
||||
return "unknown peer"
|
||||
}
|
||||
@@ -1,31 +0,0 @@
|
||||
//go:build freebsd
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
)
|
||||
|
||||
// newConsoleVNC builds the FreeBSD console fallback: vt(4) framebuffer
|
||||
// for capture, /dev/uinput for input. The uinput device requires the
|
||||
// `uinput` kernel module (`kldload uinput`); without it, input init
|
||||
// fails and we drop to a stub injector so the user still gets a
|
||||
// view-only screen mirror.
|
||||
func newConsoleVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, error) {
|
||||
poller := vncserver.NewFBPoller("")
|
||||
w, h := poller.Width(), poller.Height()
|
||||
if w == 0 || h == 0 {
|
||||
poller.Close()
|
||||
return nil, nil, fmt.Errorf("vt framebuffer init failed (vt may not allow mmap on this driver)")
|
||||
}
|
||||
if inj, err := vncserver.NewUInputInjector(w, h); err == nil {
|
||||
return poller, inj, nil
|
||||
} else {
|
||||
log.Infof("VNC console: uinput unavailable (%v); view-only mode. Run `kldload uinput` to enable input.", err)
|
||||
return poller, &vncserver.StubInputInjector{}, nil
|
||||
}
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
)
|
||||
|
||||
// newConsoleVNC builds a framebuffer + uinput VNC backend for boxes
|
||||
// without a running X server. Used as the auto-fallback when
|
||||
// newPlatformVNC can't reach X. Returns an error when /dev/fb0 or
|
||||
// /dev/uinput aren't usable so the caller can drop back to a stub.
|
||||
func newConsoleVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, error) {
|
||||
poller := vncserver.NewFBPoller("")
|
||||
w, h := poller.Width(), poller.Height()
|
||||
if w == 0 || h == 0 {
|
||||
poller.Close()
|
||||
return nil, nil, fmt.Errorf("framebuffer capturer init failed (is /dev/fb0 readable?)")
|
||||
}
|
||||
inj, err := vncserver.NewUInputInjector(w, h)
|
||||
if err != nil {
|
||||
log.Debugf("uinput unavailable, falling back to view-only VNC: %v", err)
|
||||
return poller, &vncserver.StubInputInjector{}, nil
|
||||
}
|
||||
return poller, inj, nil
|
||||
}
|
||||
@@ -1,34 +0,0 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
)
|
||||
|
||||
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, bool) {
|
||||
capturer := vncserver.NewMacPoller()
|
||||
// Prompt for Screen Recording at server-enable time rather than first
|
||||
// client-connect. The native prompt is far easier for users to act on
|
||||
// in the moment they toggled VNC on than later when "the screen looks
|
||||
// like wallpaper" would otherwise be the only clue.
|
||||
vncserver.PrimeScreenCapturePermission()
|
||||
injector, err := vncserver.NewMacInputInjector()
|
||||
if err != nil {
|
||||
log.Debugf("VNC: macOS input injector: %v", err)
|
||||
return capturer, &vncserver.StubInputInjector{}, true
|
||||
}
|
||||
return capturer, injector, true
|
||||
}
|
||||
|
||||
// vncNeedsServiceMode reports whether the running process is a system
|
||||
// LaunchDaemon (root, parented by launchd). Daemons sit in the global
|
||||
// bootstrap namespace and cannot talk to WindowServer; we route capture
|
||||
// through a per-user agent in that case.
|
||||
func vncNeedsServiceMode() bool {
|
||||
return os.Geteuid() == 0 && os.Getppid() == 1
|
||||
}
|
||||
@@ -1,23 +0,0 @@
|
||||
//go:build js || ios || android
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
type vncServer interface{}
|
||||
|
||||
func (e *Engine) updateVNC() error { return nil }
|
||||
|
||||
func (e *Engine) updateVNCServerAuth(auth *mgmProto.VNCAuth) {
|
||||
if auth == nil {
|
||||
return
|
||||
}
|
||||
log.Debugf("ignoring VNC auth push on platform without a VNC server: %d session pubkeys, %d authorized users",
|
||||
len(auth.GetSessionPubKeys()), len(auth.GetAuthorizedUsers()))
|
||||
}
|
||||
|
||||
func (e *Engine) stopVNCServer() error { return nil }
|
||||
@@ -1,13 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package internal
|
||||
|
||||
import vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
|
||||
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, bool) {
|
||||
return vncserver.NewDesktopCapturer(), vncserver.NewWindowsInputInjector(), true
|
||||
}
|
||||
|
||||
func vncNeedsServiceMode() bool {
|
||||
return vncserver.GetCurrentSessionID() == 0
|
||||
}
|
||||
@@ -1,35 +0,0 @@
|
||||
//go:build (linux && !android) || freebsd
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
vncserver "github.com/netbirdio/netbird/client/vnc/server"
|
||||
)
|
||||
|
||||
func newPlatformVNC() (vncserver.ScreenCapturer, vncserver.InputInjector, bool) {
|
||||
// Prefer X11 when an X server is reachable. NewX11InputInjector probes
|
||||
// DISPLAY (and /proc) eagerly, so a non-nil error here means no X.
|
||||
injector, err := vncserver.NewX11InputInjector("", "", "")
|
||||
if err == nil {
|
||||
return vncserver.NewX11Poller("", ""), injector, true
|
||||
}
|
||||
log.Debugf("VNC: X11 not available: %v", err)
|
||||
|
||||
// Fallback for headless / pre-X states (kernel console, login manager
|
||||
// without X, physical server in recovery): stream the framebuffer and
|
||||
// inject input via /dev/uinput.
|
||||
consoleCap, consoleInj, err := newConsoleVNC()
|
||||
if err == nil {
|
||||
log.Infof("VNC: using framebuffer console capture (%dx%d)", consoleCap.Width(), consoleCap.Height())
|
||||
return consoleCap, consoleInj, true
|
||||
}
|
||||
log.Debugf("VNC: framebuffer console fallback unavailable: %v", err)
|
||||
|
||||
return &vncserver.StubCapturer{}, &vncserver.StubInputInjector{}, false
|
||||
}
|
||||
|
||||
func vncNeedsServiceMode() bool {
|
||||
return false
|
||||
}
|
||||
@@ -120,36 +120,6 @@ func (m *influxDBMetrics) RecordSyncDuration(_ context.Context, agentInfo AgentI
|
||||
m.trimLocked()
|
||||
}
|
||||
|
||||
func (m *influxDBMetrics) RecordVNCSessionTick(_ context.Context, agentInfo AgentInfo, tick VNCSessionTick) {
|
||||
tags := fmt.Sprintf("deployment_type=%s,version=%s,os=%s,arch=%s,peer_id=%s",
|
||||
agentInfo.DeploymentType.String(),
|
||||
agentInfo.Version,
|
||||
agentInfo.OS,
|
||||
agentInfo.Arch,
|
||||
agentInfo.peerID,
|
||||
)
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.samples = append(m.samples, influxSample{
|
||||
measurement: "netbird_vnc_traffic",
|
||||
tags: tags,
|
||||
fields: map[string]float64{
|
||||
"period_seconds": tick.Period.Seconds(),
|
||||
"bytes_out": float64(tick.BytesOut),
|
||||
"writes": float64(tick.Writes),
|
||||
"fbus": float64(tick.FBUs),
|
||||
"max_fbu_bytes": float64(tick.MaxFBUBytes),
|
||||
"max_fbu_rects": float64(tick.MaxFBURects),
|
||||
"max_write_bytes": float64(tick.MaxWriteBytes),
|
||||
"write_time_seconds": float64(tick.WriteNanos) / 1e9,
|
||||
},
|
||||
timestamp: time.Now(),
|
||||
})
|
||||
m.trimLocked()
|
||||
}
|
||||
|
||||
func (m *influxDBMetrics) RecordLoginDuration(_ context.Context, agentInfo AgentInfo, duration time.Duration, success bool) {
|
||||
result := "success"
|
||||
if !success {
|
||||
|
||||
@@ -59,11 +59,6 @@ type metricsImplementation interface {
|
||||
// RecordLoginDuration records how long the login to management took
|
||||
RecordLoginDuration(ctx context.Context, agentInfo AgentInfo, duration time.Duration, success bool)
|
||||
|
||||
// RecordVNCSessionTick records a periodic snapshot of one VNC
|
||||
// session's wire activity. Called once per metricsConn tick interval
|
||||
// (and once at session close), only when the tick saw activity.
|
||||
RecordVNCSessionTick(ctx context.Context, agentInfo AgentInfo, tick VNCSessionTick)
|
||||
|
||||
// Export exports metrics in InfluxDB line protocol format
|
||||
Export(w io.Writer) error
|
||||
|
||||
@@ -83,21 +78,6 @@ type ClientMetrics struct {
|
||||
pushCancel context.CancelFunc
|
||||
}
|
||||
|
||||
// VNCSessionTick is one sampling slice of a VNC session's wire activity.
|
||||
// BytesOut / Writes / FBUs / WriteNanos are deltas observed during this
|
||||
// tick; Max* fields are the high-water marks observed during the tick.
|
||||
// Period is the wall-clock duration the deltas cover.
|
||||
type VNCSessionTick struct {
|
||||
Period time.Duration
|
||||
BytesOut uint64
|
||||
Writes uint64
|
||||
FBUs uint64
|
||||
MaxFBUBytes uint64
|
||||
MaxFBURects uint64
|
||||
MaxWriteBytes uint64
|
||||
WriteNanos uint64
|
||||
}
|
||||
|
||||
// ConnectionStageTimestamps holds timestamps for each connection stage
|
||||
type ConnectionStageTimestamps struct {
|
||||
SignalingReceived time.Time // First signal received from remote peer (both initial and reconnection)
|
||||
@@ -147,17 +127,6 @@ func (c *ClientMetrics) RecordSyncDuration(ctx context.Context, duration time.Du
|
||||
c.impl.RecordSyncDuration(ctx, agentInfo, duration)
|
||||
}
|
||||
|
||||
// RecordVNCSessionTick records a periodic snapshot of one VNC session.
|
||||
func (c *ClientMetrics) RecordVNCSessionTick(ctx context.Context, tick VNCSessionTick) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.mu.RLock()
|
||||
agentInfo := c.agentInfo
|
||||
c.mu.RUnlock()
|
||||
c.impl.RecordVNCSessionTick(ctx, agentInfo, tick)
|
||||
}
|
||||
|
||||
// RecordLoginDuration records how long the login to management server took
|
||||
func (c *ClientMetrics) RecordLoginDuration(ctx context.Context, duration time.Duration, success bool) {
|
||||
if c == nil {
|
||||
|
||||
@@ -73,9 +73,6 @@ func (m *mockMetrics) RecordSyncDuration(_ context.Context, _ AgentInfo, _ time.
|
||||
func (m *mockMetrics) RecordLoginDuration(_ context.Context, _ AgentInfo, _ time.Duration, _ bool) {
|
||||
}
|
||||
|
||||
func (m *mockMetrics) RecordVNCSessionTick(_ context.Context, _ AgentInfo, _ VNCSessionTick) {
|
||||
}
|
||||
|
||||
func (m *mockMetrics) Export(w io.Writer) error {
|
||||
if m.exportData != "" {
|
||||
_, err := w.Write([]byte(m.exportData))
|
||||
|
||||
@@ -26,6 +26,7 @@ type connStatusInputs struct {
|
||||
iceInProgress bool // a negotiation is currently in flight
|
||||
}
|
||||
|
||||
|
||||
// ConnStatus describe the status of a peer's connection
|
||||
type ConnStatus int32
|
||||
|
||||
|
||||
@@ -193,7 +193,6 @@ func (s *StatusChangeSubscription) Events() chan map[string]RouterState {
|
||||
type Status struct {
|
||||
mux sync.RWMutex
|
||||
peers map[string]State
|
||||
ipToKey map[string]string
|
||||
changeNotify map[string]map[string]*StatusChangeSubscription // map[peerID]map[subscriptionID]*StatusChangeSubscription
|
||||
signalState bool
|
||||
signalError error
|
||||
@@ -232,7 +231,6 @@ type Status struct {
|
||||
func NewRecorder(mgmAddress string) *Status {
|
||||
return &Status{
|
||||
peers: make(map[string]State),
|
||||
ipToKey: make(map[string]string),
|
||||
changeNotify: make(map[string]map[string]*StatusChangeSubscription),
|
||||
eventStreams: make(map[string]chan *proto.SystemEvent),
|
||||
eventQueue: NewEventQueue(eventQueueSize),
|
||||
@@ -284,12 +282,6 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string, ip string, ipv6 string)
|
||||
Mux: new(sync.RWMutex),
|
||||
}
|
||||
d.peerListChangedForNotification = true
|
||||
if ipv6 != "" {
|
||||
d.ipToKey[ipv6] = peerPubKey
|
||||
}
|
||||
if ip != "" {
|
||||
d.ipToKey[ip] = peerPubKey
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -319,22 +311,28 @@ func (d *Status) PeerByIP(ip string) (string, bool) {
|
||||
|
||||
// PeerStateByIP returns the full peer State for the given tunnel IP.
|
||||
// Matches against either the IPv4 (State.IP) or IPv6 (State.IPv6) tunnel
|
||||
// address so dual-stack peers are reachable on either family. Only
|
||||
// active peers are matched; peers moved into the offline slice by
|
||||
// ReplaceOfflinePeers are intentionally treated as unknown.
|
||||
// address so dual-stack peers are reachable on either family. Searches
|
||||
// both d.peers and d.offlinePeers — peers that have been moved into
|
||||
// the offline slice by ReplaceOfflinePeers are still part of the
|
||||
// account's roster and callers (DNS filter, embed.Client.IdentityForIP)
|
||||
// need to recognise them rather than treating them as unknown. Returns
|
||||
// the zero State and false when no peer matches or the input is empty.
|
||||
func (d *Status) PeerStateByIP(ip string) (State, bool) {
|
||||
if ip == "" {
|
||||
return State{}, false
|
||||
}
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
key, ok := d.ipToKey[ip]
|
||||
if !ok {
|
||||
return State{}, false
|
||||
|
||||
for _, state := range d.peers {
|
||||
if (state.IP != "" && state.IP == ip) || (state.IPv6 != "" && state.IPv6 == ip) {
|
||||
return state, true
|
||||
}
|
||||
}
|
||||
state, ok := d.peers[key]
|
||||
if ok {
|
||||
return state, true
|
||||
for _, state := range d.offlinePeers {
|
||||
if (state.IP != "" && state.IP == ip) || (state.IPv6 != "" && state.IPv6 == ip) {
|
||||
return state, true
|
||||
}
|
||||
}
|
||||
return State{}, false
|
||||
}
|
||||
@@ -344,18 +342,12 @@ func (d *Status) RemovePeer(peerPubKey string) error {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
p, ok := d.peers[peerPubKey]
|
||||
_, ok := d.peers[peerPubKey]
|
||||
if !ok {
|
||||
return errors.New("no peer with to remove")
|
||||
}
|
||||
|
||||
delete(d.peers, peerPubKey)
|
||||
if mappedKey, exists := d.ipToKey[p.IP]; exists && mappedKey == peerPubKey {
|
||||
delete(d.ipToKey, p.IP)
|
||||
}
|
||||
if mappedKey, exists := d.ipToKey[p.IPv6]; exists && mappedKey == peerPubKey {
|
||||
delete(d.ipToKey, p.IPv6)
|
||||
}
|
||||
d.peerListChangedForNotification = true
|
||||
return nil
|
||||
}
|
||||
@@ -1231,15 +1223,6 @@ func (d *Status) SubscribeToEvents() *EventSubscription {
|
||||
}
|
||||
}
|
||||
|
||||
// HasEventSubscribers reports whether any client is currently subscribed
|
||||
// to the daemon's SystemEvent stream. Used by the VNC approval broker to
|
||||
// fail closed when no UI is connected to prompt the user.
|
||||
func (d *Status) HasEventSubscribers() bool {
|
||||
d.eventMux.Lock()
|
||||
defer d.eventMux.Unlock()
|
||||
return len(d.eventStreams) > 0
|
||||
}
|
||||
|
||||
// UnsubscribeFromEvents removes an event subscription
|
||||
func (d *Status) UnsubscribeFromEvents(sub *EventSubscription) {
|
||||
if sub == nil {
|
||||
|
||||
@@ -90,11 +90,12 @@ func TestStatus_PeerStateByIP_MatchesIPv6(t *testing.T) {
|
||||
req.Equal("pk-1", state.PubKey, "matching state must carry the right pub key")
|
||||
}
|
||||
|
||||
// TestStatus_PeerStateByIP_IgnoresOfflinePeers documents that peers
|
||||
// moved into the offline slice via ReplaceOfflinePeers are intentionally
|
||||
// not resolvable by IP: only active peers can carry traffic, so callers
|
||||
// (DNS filter, embed.Client.IdentityForIP) treat them as unknown.
|
||||
func TestStatus_PeerStateByIP_IgnoresOfflinePeers(t *testing.T) {
|
||||
// TestStatus_PeerStateByIP_MatchesOfflinePeers covers peers that have
|
||||
// been moved into the offline slice via ReplaceOfflinePeers. Callers
|
||||
// (DNS filter, embed.Client.IdentityForIP) need to treat them as known
|
||||
// rather than unknown — otherwise authentication / DNS filtering treats
|
||||
// known-but-offline peers as foreign IPs.
|
||||
func TestStatus_PeerStateByIP_MatchesOfflinePeers(t *testing.T) {
|
||||
status := NewRecorder("https://mgm")
|
||||
req := require.New(t)
|
||||
|
||||
@@ -102,31 +103,13 @@ func TestStatus_PeerStateByIP_IgnoresOfflinePeers(t *testing.T) {
|
||||
{PubKey: "pk-offline", FQDN: "offline.netbird", IP: "100.64.0.20", IPv6: "fd00::20"},
|
||||
})
|
||||
|
||||
_, ok := status.PeerStateByIP("100.64.0.20")
|
||||
req.False(ok, "offline peer must not resolve by IPv4 tunnel address")
|
||||
state, ok := status.PeerStateByIP("100.64.0.20")
|
||||
req.True(ok, "offline peer must resolve by IPv4 tunnel address")
|
||||
req.Equal("pk-offline", state.PubKey, "matching state must carry the offline peer's pub key")
|
||||
|
||||
_, ok = status.PeerStateByIP("fd00::20")
|
||||
req.False(ok, "offline peer must not resolve by IPv6 tunnel address")
|
||||
}
|
||||
|
||||
// TestStatus_PeerStateByIP_RemovedPeer verifies RemovePeer drops the
|
||||
// IP index entries for both address families.
|
||||
func TestStatus_PeerStateByIP_RemovedPeer(t *testing.T) {
|
||||
status := NewRecorder("https://mgm")
|
||||
req := require.New(t)
|
||||
|
||||
req.NoError(status.AddPeer("pk-1", "peer-1.netbird", "100.64.0.10", "fd00::1"))
|
||||
|
||||
_, ok := status.PeerStateByIP("100.64.0.10")
|
||||
req.True(ok, "active peer must resolve before removal")
|
||||
|
||||
req.NoError(status.RemovePeer("pk-1"))
|
||||
|
||||
_, ok = status.PeerStateByIP("100.64.0.10")
|
||||
req.False(ok, "removed peer must not resolve by IPv4 tunnel address")
|
||||
|
||||
_, ok = status.PeerStateByIP("fd00::1")
|
||||
req.False(ok, "removed peer must not resolve by IPv6 tunnel address")
|
||||
state, ok = status.PeerStateByIP("fd00::20")
|
||||
req.True(ok, "offline peer must resolve by IPv6 tunnel address")
|
||||
req.Equal("pk-offline", state.PubKey, "IPv6 match must carry the offline peer's pub key")
|
||||
}
|
||||
|
||||
func TestStatus_UpdatePeerFQDN(t *testing.T) {
|
||||
|
||||
@@ -22,7 +22,6 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||
"github.com/netbirdio/netbird/client/mdm"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
@@ -58,10 +57,6 @@ var DefaultInterfaceBlacklist = []string{
|
||||
"Tailscale", "tailscale", "docker", "veth", "br-", "lo",
|
||||
}
|
||||
|
||||
// loadMDMPolicy is the package-level indirection used by apply() to read the
|
||||
// active MDM policy. Tests override this to inject a fake policy.
|
||||
var loadMDMPolicy = mdm.LoadPolicy
|
||||
|
||||
// ConfigInput carries configuration changes to the client
|
||||
type ConfigInput struct {
|
||||
ManagementURL string
|
||||
@@ -70,8 +65,6 @@ type ConfigInput struct {
|
||||
StateFilePath string
|
||||
PreSharedKey *string
|
||||
ServerSSHAllowed *bool
|
||||
ServerVNCAllowed *bool
|
||||
DisableVNCApproval *bool
|
||||
EnableSSHRoot *bool
|
||||
EnableSSHSFTP *bool
|
||||
EnableSSHLocalPortForwarding *bool
|
||||
@@ -123,8 +116,6 @@ type Config struct {
|
||||
RosenpassEnabled bool
|
||||
RosenpassPermissive bool
|
||||
ServerSSHAllowed *bool
|
||||
ServerVNCAllowed *bool
|
||||
DisableVNCApproval *bool
|
||||
EnableSSHRoot *bool
|
||||
EnableSSHSFTP *bool
|
||||
EnableSSHLocalPortForwarding *bool
|
||||
@@ -183,23 +174,6 @@ type Config struct {
|
||||
LazyConnectionEnabled bool
|
||||
|
||||
MTU uint16
|
||||
|
||||
// policy is the MDM policy that produced the currently-set values for
|
||||
// any MDM-enforced fields. Set by applyMDMPolicy at the tail of apply()
|
||||
// and reset on every apply() invocation. Never persisted to disk.
|
||||
// Callers query enforcement state via Policy() and the mdm.Policy API
|
||||
// (HasKey, ManagedKeys, IsEmpty).
|
||||
policy *mdm.Policy `json:"-"`
|
||||
}
|
||||
|
||||
// Policy returns the MDM policy applied to this Config. Returns a non-nil
|
||||
// empty Policy when MDM enforcement is inactive; callers can always invoke
|
||||
// HasKey / ManagedKeys / IsEmpty without a nil check.
|
||||
func (config *Config) Policy() *mdm.Policy {
|
||||
if config == nil || config.policy == nil {
|
||||
return mdm.NewPolicy(nil)
|
||||
}
|
||||
return config.policy
|
||||
}
|
||||
|
||||
var ConfigDirOverride string
|
||||
@@ -444,33 +418,6 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.ServerVNCAllowed != nil {
|
||||
if config.ServerVNCAllowed == nil || *input.ServerVNCAllowed != *config.ServerVNCAllowed {
|
||||
if *input.ServerVNCAllowed {
|
||||
log.Infof("enabling VNC server")
|
||||
} else {
|
||||
log.Infof("disabling VNC server")
|
||||
}
|
||||
config.ServerVNCAllowed = input.ServerVNCAllowed
|
||||
updated = true
|
||||
}
|
||||
} else if config.ServerVNCAllowed == nil {
|
||||
config.ServerVNCAllowed = util.False()
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.DisableVNCApproval != nil {
|
||||
if config.DisableVNCApproval == nil || *input.DisableVNCApproval != *config.DisableVNCApproval {
|
||||
if *input.DisableVNCApproval {
|
||||
log.Infof("disabling VNC connection approval prompt")
|
||||
} else {
|
||||
log.Infof("enabling VNC connection approval prompt")
|
||||
}
|
||||
config.DisableVNCApproval = input.DisableVNCApproval
|
||||
updated = true
|
||||
}
|
||||
}
|
||||
|
||||
if input.EnableSSHRoot != nil && input.EnableSSHRoot != config.EnableSSHRoot {
|
||||
if *input.EnableSSHRoot {
|
||||
log.Infof("enabling SSH root login")
|
||||
@@ -665,93 +612,10 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
// MDM is the last override layer: any key present in the policy
|
||||
// supersedes defaults, on-disk config, env vars and CLI input.
|
||||
config.applyMDMPolicy(loadMDMPolicy())
|
||||
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
// applyMDMPolicy overlays MDM-supplied values on top of the resolved Config.
|
||||
// The provided Policy is also stored on the Config so callers can later query
|
||||
// which fields are enforced. Invalid values (e.g. malformed URLs) are logged
|
||||
// and skipped to avoid bricking the client; the field keeps its previous
|
||||
// resolved value but is still marked as managed (Policy.HasKey returns true
|
||||
// for the key, so per-field rejection of user writes still applies).
|
||||
func (config *Config) applyMDMPolicy(policy *mdm.Policy) {
|
||||
config.policy = policy
|
||||
if policy.IsEmpty() {
|
||||
return
|
||||
}
|
||||
|
||||
// Helper: log the application of a single MDM-managed key. Values for
|
||||
// keys in mdm.SecretKeys are redacted.
|
||||
logApplied := func(key string, displayValue any) {
|
||||
if _, secret := mdm.SecretKeys[key]; secret {
|
||||
log.Infof("MDM override %s = ********** (secret)", key)
|
||||
return
|
||||
}
|
||||
log.Infof("MDM override %s = %v", key, displayValue)
|
||||
}
|
||||
|
||||
if v, ok := policy.GetString(mdm.KeyManagementURL); ok {
|
||||
if u, err := parseURL("Management URL", v); err != nil {
|
||||
log.Warnf("MDM management URL %q invalid: %v; keeping previous value", v, err)
|
||||
} else {
|
||||
config.ManagementURL = u
|
||||
logApplied(mdm.KeyManagementURL, u.String())
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := policy.GetString(mdm.KeyPreSharedKey); ok {
|
||||
// Defensive: refuse the redaction mask in case it round-tripped
|
||||
// through a manifest by mistake.
|
||||
if !isPreSharedKeyHidden(&v) {
|
||||
config.PreSharedKey = v
|
||||
logApplied(mdm.KeyPreSharedKey, "")
|
||||
}
|
||||
}
|
||||
|
||||
// applyBool collapses the per-key "read + set + log" boilerplate
|
||||
// for every plain bool MDM key into a single helper. Keeps the
|
||||
// outer function's cognitive complexity below SonarCube's
|
||||
// threshold; functional behaviour is identical to the inlined
|
||||
// branches it replaces.
|
||||
applyBool := func(key string, setter func(bool)) {
|
||||
v, ok := policy.GetBool(key)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
setter(v)
|
||||
logApplied(key, v)
|
||||
}
|
||||
|
||||
applyBool(mdm.KeyAllowServerSSH, func(v bool) { bv := v; config.ServerSSHAllowed = &bv })
|
||||
applyBool(mdm.KeyDisableClientRoutes, func(v bool) { config.DisableClientRoutes = v })
|
||||
applyBool(mdm.KeyDisableServerRoutes, func(v bool) { config.DisableServerRoutes = v })
|
||||
applyBool(mdm.KeyBlockInbound, func(v bool) { config.BlockInbound = v })
|
||||
applyBool(mdm.KeyDisableAutoConnect, func(v bool) { config.DisableAutoConnect = v })
|
||||
applyBool(mdm.KeyRosenpassEnabled, func(v bool) { config.RosenpassEnabled = v })
|
||||
applyBool(mdm.KeyRosenpassPermissive, func(v bool) { config.RosenpassPermissive = v })
|
||||
|
||||
if v, ok := policy.GetInt(mdm.KeyWireguardPort); ok {
|
||||
// REG_DWORD is 32-bit; UDP port range is 1-65535. Clamp at the
|
||||
// upper bound and reject obviously-invalid values to avoid the
|
||||
// engine binding to an unusable port if the admin pushes garbage.
|
||||
if v >= 1 && v <= 65535 {
|
||||
config.WgPort = int(v)
|
||||
logApplied(mdm.KeyWireguardPort, v)
|
||||
} else {
|
||||
log.Warnf("MDM wireguard port %d out of range [1,65535]; keeping previous value", v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// parseURL parses and validates the URL for the named service. The URL
|
||||
// must use the http or https scheme; if no port is present, ":443" is
|
||||
// appended for https or ":80" for http. The serviceName parameter is
|
||||
// used to contextualise error messages. On success returns the parsed
|
||||
// *url.URL; on failure returns a non-nil error.
|
||||
// parseURL parses and validates a service URL
|
||||
func parseURL(serviceName, serviceURL string) (*url.URL, error) {
|
||||
parsedMgmtURL, err := url.ParseRequestURI(serviceURL)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,152 +0,0 @@
|
||||
package profilemanager
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/mdm"
|
||||
)
|
||||
|
||||
// withMDMPolicy temporarily overrides the package-level loadMDMPolicy hook so
|
||||
// apply() observes the supplied Policy. The original loader is restored at
|
||||
// test cleanup.
|
||||
func withMDMPolicy(t *testing.T, policy *mdm.Policy) {
|
||||
t.Helper()
|
||||
prev := loadMDMPolicy
|
||||
loadMDMPolicy = func() *mdm.Policy { return policy }
|
||||
t.Cleanup(func() { loadMDMPolicy = prev })
|
||||
}
|
||||
|
||||
func TestApply_MDMEmpty_NoEnforcement(t *testing.T) {
|
||||
withMDMPolicy(t, mdm.NewPolicy(nil))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
assert.True(t, cfg.Policy().IsEmpty(), "no MDM source ⇒ empty Policy")
|
||||
assert.False(t, cfg.Policy().HasKey(mdm.KeyManagementURL))
|
||||
assert.Empty(t, cfg.Policy().ManagedKeys())
|
||||
|
||||
// Default management URL still resolves.
|
||||
assert.Equal(t, DefaultManagementURL, cfg.ManagementURL.String())
|
||||
}
|
||||
|
||||
func TestApply_MDMOnly_OverridesDefaults(t *testing.T) {
|
||||
const mdmURL = "https://corp.mdm.example.com:443"
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: mdmURL,
|
||||
mdm.KeyDisableClientRoutes: true,
|
||||
mdm.KeyBlockInbound: true,
|
||||
}))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
assert.Equal(t, mdmURL, cfg.ManagementURL.String())
|
||||
assert.True(t, cfg.DisableClientRoutes)
|
||||
assert.True(t, cfg.BlockInbound)
|
||||
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyManagementURL))
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyDisableClientRoutes))
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyBlockInbound))
|
||||
assert.False(t, cfg.Policy().HasKey(mdm.KeyAllowServerSSH))
|
||||
}
|
||||
|
||||
func TestApply_MDMBeatsCLIInput(t *testing.T) {
|
||||
const mdmURL = "https://mdm.example.com:443"
|
||||
const cliURL = "https://cli.example.com:443"
|
||||
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: mdmURL,
|
||||
}))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
ManagementURL: cliURL,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
// MDM wins over CLI-supplied management URL.
|
||||
assert.Equal(t, mdmURL, cfg.ManagementURL.String())
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyManagementURL))
|
||||
}
|
||||
|
||||
func TestApply_MDMInvalidURL_KeepsPreviousValue(t *testing.T) {
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: "not-a-url",
|
||||
}))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
// Invalid MDM URL is logged and skipped: default URL stays in place
|
||||
// to keep the client functional.
|
||||
assert.Equal(t, DefaultManagementURL, cfg.ManagementURL.String())
|
||||
|
||||
// But the key is still considered MDM-managed (admin intent is to
|
||||
// enforce, daemon rejects user writes to this field — phase-1 scaffolding
|
||||
// reflects this by keeping Policy.HasKey true even on parse failure).
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyManagementURL))
|
||||
}
|
||||
|
||||
func TestApply_MDMBoolKeysOverrideOnDiskValue(t *testing.T) {
|
||||
tmp := filepath.Join(t.TempDir(), "config.json")
|
||||
|
||||
// Seed without MDM.
|
||||
withMDMPolicy(t, mdm.NewPolicy(nil))
|
||||
_, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: tmp,
|
||||
DisableClientRoutes: boolPtr(false),
|
||||
RosenpassEnabled: boolPtr(false),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now enable MDM enforcement for these keys.
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyDisableClientRoutes: true,
|
||||
mdm.KeyRosenpassEnabled: true,
|
||||
}))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{ConfigPath: tmp})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
assert.True(t, cfg.DisableClientRoutes, "MDM override should flip on-disk false to true")
|
||||
assert.True(t, cfg.RosenpassEnabled)
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyDisableClientRoutes))
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyRosenpassEnabled))
|
||||
}
|
||||
|
||||
func TestApply_MDMPreSharedKeyRedactionSentinelRejected(t *testing.T) {
|
||||
const maskSentinel = "**********"
|
||||
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyPreSharedKey: maskSentinel,
|
||||
}))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
// Mask sentinel must not be persisted as the actual PSK.
|
||||
assert.NotEqual(t, maskSentinel, cfg.PreSharedKey)
|
||||
// Key still marked managed so user writes are still rejected.
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyPreSharedKey))
|
||||
}
|
||||
|
||||
func boolPtr(b bool) *bool { return &b }
|
||||
@@ -700,13 +700,6 @@ func resolveURLsToIPs(urls []string) []net.IP {
|
||||
|
||||
// updateRouteSelectorFromManagement updates the route selector based on the isSelected status from the management server
|
||||
func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HAMap) {
|
||||
// An explicit user "deselect all" must not be overridden by management auto-apply.
|
||||
// Auto-applying an exit node here would call SelectRoutes, which clears the
|
||||
// deselect-all flag and re-enables every route the user turned off.
|
||||
if m.routeSelector.IsDeselectAll() {
|
||||
return
|
||||
}
|
||||
|
||||
exitNodeInfo := m.collectExitNodeInfo(clientRoutes)
|
||||
if len(exitNodeInfo.allIDs) == 0 {
|
||||
return
|
||||
|
||||
@@ -1,71 +0,0 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
func exitNodeRoutes(netID route.NetID, skipAutoApply bool) route.HAMap {
|
||||
haID := route.HAUniqueID(string(netID) + "|0.0.0.0/0")
|
||||
return route.HAMap{
|
||||
haID: []*route.Route{
|
||||
{
|
||||
ID: "r-" + route.ID(netID),
|
||||
NetID: netID,
|
||||
Network: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Enabled: true,
|
||||
SkipAutoApply: skipAutoApply,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateRouteSelectorFromManagement(t *testing.T) {
|
||||
t.Run("management auto-apply selects exit node without user selection", func(t *testing.T) {
|
||||
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
|
||||
routes := exitNodeRoutes("exit1", false)
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
require.True(t, m.routeSelector.IsSelected("exit1"), "auto-apply exit node should be selected")
|
||||
require.Len(t, m.routeSelector.FilterSelectedExitNodes(routes), 1, "selected exit node should pass the filter")
|
||||
})
|
||||
|
||||
t.Run("management SkipAutoApply leaves exit node deselected", func(t *testing.T) {
|
||||
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
|
||||
routes := exitNodeRoutes("exit1", true)
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
require.False(t, m.routeSelector.IsSelected("exit1"), "SkipAutoApply exit node should not be selected")
|
||||
require.Empty(t, m.routeSelector.FilterSelectedExitNodes(routes), "deselected exit node should be filtered out")
|
||||
})
|
||||
|
||||
t.Run("user selection is not overridden by management", func(t *testing.T) {
|
||||
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
|
||||
require.NoError(t, m.routeSelector.SelectRoutes([]route.NetID{"exit1"}, true, []route.NetID{"exit1"}))
|
||||
routes := exitNodeRoutes("exit1", true)
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
require.True(t, m.routeSelector.IsSelected("exit1"), "explicit user selection must survive a management sync that wants to skip auto-apply")
|
||||
require.Len(t, m.routeSelector.FilterSelectedExitNodes(routes), 1, "user-selected exit node should pass the filter")
|
||||
})
|
||||
|
||||
t.Run("deselect-all is preserved across a management sync", func(t *testing.T) {
|
||||
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
|
||||
m.routeSelector.DeselectAllRoutes()
|
||||
routes := exitNodeRoutes("exit1", false)
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
require.True(t, m.routeSelector.IsDeselectAll(), "an explicit deselect-all must not be cleared by management auto-apply")
|
||||
require.Empty(t, m.routeSelector.FilterSelectedExitNodes(routes), "no routes should be selected while deselect-all is set")
|
||||
})
|
||||
}
|
||||
@@ -116,14 +116,6 @@ func (rs *RouteSelector) DeselectAllRoutes() {
|
||||
clear(rs.selectedRoutes)
|
||||
}
|
||||
|
||||
// IsDeselectAll reports whether the user has explicitly deselected all routes.
|
||||
func (rs *RouteSelector) IsDeselectAll() bool {
|
||||
rs.mu.RLock()
|
||||
defer rs.mu.RUnlock()
|
||||
|
||||
return rs.deselectAll
|
||||
}
|
||||
|
||||
// IsSelected checks if a specific route is selected.
|
||||
func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
|
||||
rs.mu.RLock()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build cgo && !osusergo && !windows
|
||||
|
||||
package server
|
||||
package shell
|
||||
|
||||
import "os/user"
|
||||
|
||||
@@ -8,17 +8,22 @@ import "os/user"
|
||||
// When CGO is enabled, os/user uses libc (getpwnam_r) which goes through
|
||||
// the NSS stack natively. If it fails, the user truly doesn't exist and
|
||||
// getent would also fail.
|
||||
func lookupWithGetent(username string) (*user.User, error) {
|
||||
func LookupWithGetent(username string) (*user.User, error) {
|
||||
return user.Lookup(username)
|
||||
}
|
||||
|
||||
// currentUserWithGetent with CGO delegates directly to os/user.Current.
|
||||
func currentUserWithGetent() (*user.User, error) {
|
||||
func CurrentUserWithGetent() (*user.User, error) {
|
||||
return user.Current()
|
||||
}
|
||||
|
||||
// LookupGroupWithGetent returns the resolved group from either a gid or groupname
|
||||
func LookupGroupWithGetent(name string) (*user.Group, error) {
|
||||
return user.LookupGroup(name)
|
||||
}
|
||||
|
||||
// groupIdsWithFallback with CGO delegates directly to user.GroupIds.
|
||||
// libc's getgrouplist handles NSS groups natively.
|
||||
func groupIdsWithFallback(u *user.User) ([]string, error) {
|
||||
func GroupIdsWithFallback(u *user.User) ([]string, error) {
|
||||
return u.GroupIds()
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build (!cgo || osusergo) && !windows
|
||||
|
||||
package server
|
||||
package shell
|
||||
|
||||
import (
|
||||
"os"
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
// lookupWithGetent looks up a user by name, falling back to getent if os/user fails.
|
||||
// Without CGO, os/user only reads /etc/passwd and misses NSS-provided users.
|
||||
// getent goes through the host's NSS stack.
|
||||
func lookupWithGetent(username string) (*user.User, error) {
|
||||
func LookupWithGetent(username string) (*user.User, error) {
|
||||
u, err := user.Lookup(username)
|
||||
if err == nil {
|
||||
return u, nil
|
||||
@@ -22,7 +22,7 @@ func lookupWithGetent(username string) (*user.User, error) {
|
||||
stdErr := err
|
||||
log.Debugf("os/user.Lookup(%q) failed, trying getent: %v", username, err)
|
||||
|
||||
u, _, getentErr := runGetent(username)
|
||||
u, _, getentErr := runGetentPasswd(username)
|
||||
if getentErr != nil {
|
||||
log.Debugf("getent fallback for %q also failed: %v", username, getentErr)
|
||||
return nil, stdErr
|
||||
@@ -31,8 +31,25 @@ func lookupWithGetent(username string) (*user.User, error) {
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// LookupGroupWithGetent returns the resolved group from either a gid or groupname
|
||||
func LookupGroupWithGetent(name string) (*user.Group, error) {
|
||||
g, err := user.LookupGroup(name)
|
||||
if err == nil {
|
||||
return g, nil
|
||||
}
|
||||
|
||||
stdErr := err
|
||||
log.Debugf("os/user.LookupGroup(%q) failed, trying getent: %v", name, err)
|
||||
g, getentErr := runGetentGroup(name)
|
||||
if getentErr != nil {
|
||||
log.Debugf("getent fallback for %q also failed: %v", name, getentErr)
|
||||
return nil, stdErr
|
||||
}
|
||||
return g, nil
|
||||
}
|
||||
|
||||
// currentUserWithGetent gets the current user, falling back to getent if os/user fails.
|
||||
func currentUserWithGetent() (*user.User, error) {
|
||||
func CurrentUserWithGetent() (*user.User, error) {
|
||||
u, err := user.Current()
|
||||
if err == nil {
|
||||
return u, nil
|
||||
@@ -42,7 +59,7 @@ func currentUserWithGetent() (*user.User, error) {
|
||||
uid := strconv.Itoa(os.Getuid())
|
||||
log.Debugf("os/user.Current() failed, trying getent with UID %s: %v", uid, err)
|
||||
|
||||
u, _, getentErr := runGetent(uid)
|
||||
u, _, getentErr := runGetentPasswd(uid)
|
||||
if getentErr != nil {
|
||||
return nil, stdErr
|
||||
}
|
||||
@@ -57,7 +74,7 @@ func currentUserWithGetent() (*user.User, error) {
|
||||
// only reads /etc/group and silently returns incomplete results for NSS users
|
||||
// (no error, just missing groups). The id command goes through NSS and returns
|
||||
// the full set.
|
||||
func groupIdsWithFallback(u *user.User) ([]string, error) {
|
||||
func GroupIdsWithFallback(u *user.User) ([]string, error) {
|
||||
ids, err := runIdGroups(u.Username)
|
||||
if err == nil {
|
||||
return ids, nil
|
||||
@@ -1,4 +1,4 @@
|
||||
package server
|
||||
package shell
|
||||
|
||||
import (
|
||||
"os/user"
|
||||
@@ -15,7 +15,7 @@ func TestLookupWithGetent_CurrentUser(t *testing.T) {
|
||||
current, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
u, err := lookupWithGetent(current.Username)
|
||||
u, err := LookupWithGetent(current.Username)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, current.Username, u.Username)
|
||||
assert.Equal(t, current.Uid, u.Uid)
|
||||
@@ -23,7 +23,7 @@ func TestLookupWithGetent_CurrentUser(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestLookupWithGetent_NonexistentUser(t *testing.T) {
|
||||
_, err := lookupWithGetent("nonexistent_user_xyzzy_12345")
|
||||
_, err := LookupWithGetent("nonexistent_user_xyzzy_12345")
|
||||
require.Error(t, err, "should fail for nonexistent user")
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@ func TestCurrentUserWithGetent(t *testing.T) {
|
||||
stdUser, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
u, err := currentUserWithGetent()
|
||||
u, err := CurrentUserWithGetent()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, stdUser.Uid, u.Uid)
|
||||
assert.Equal(t, stdUser.Username, u.Username)
|
||||
@@ -41,7 +41,7 @@ func TestGroupIdsWithFallback_CurrentUser(t *testing.T) {
|
||||
current, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
groups, err := groupIdsWithFallback(current)
|
||||
groups, err := GroupIdsWithFallback(current)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, groups, "current user should have at least one group")
|
||||
|
||||
@@ -56,7 +56,7 @@ func TestGroupIdsWithFallback_CurrentUser(t *testing.T) {
|
||||
func TestGetShellFromGetent_CurrentUser(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
// Windows stub always returns empty, which is correct
|
||||
shell := getShellFromGetent("1000")
|
||||
shell := GetShellFromGetent("1000")
|
||||
assert.Empty(t, shell, "Windows stub should return empty")
|
||||
return
|
||||
}
|
||||
@@ -65,7 +65,7 @@ func TestGetShellFromGetent_CurrentUser(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// getent may not be available on all systems (e.g., macOS without Homebrew getent)
|
||||
shell := getShellFromGetent(current.Uid)
|
||||
shell := GetShellFromGetent(current.Uid)
|
||||
if shell == "" {
|
||||
t.Log("getShellFromGetent returned empty, getent may not be available")
|
||||
return
|
||||
@@ -78,7 +78,7 @@ func TestLookupWithGetent_RootUser(t *testing.T) {
|
||||
t.Skip("no root user on Windows")
|
||||
}
|
||||
|
||||
u, err := lookupWithGetent("root")
|
||||
u, err := LookupWithGetent("root")
|
||||
if err != nil {
|
||||
t.Skip("root user not available on this system")
|
||||
}
|
||||
@@ -91,20 +91,20 @@ func TestLookupWithGetent_RootUser(t *testing.T) {
|
||||
// consistent and correct results when composed together.
|
||||
func TestIntegration_FullLookupChain(t *testing.T) {
|
||||
// Step 1: currentUserWithGetent must resolve the running user.
|
||||
current, err := currentUserWithGetent()
|
||||
current, err := CurrentUserWithGetent()
|
||||
require.NoError(t, err, "currentUserWithGetent must resolve the running user")
|
||||
require.NotEmpty(t, current.Uid)
|
||||
require.NotEmpty(t, current.Username)
|
||||
|
||||
// Step 2: lookupWithGetent by the same username must return matching identity.
|
||||
byName, err := lookupWithGetent(current.Username)
|
||||
byName, err := LookupWithGetent(current.Username)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, current.Uid, byName.Uid, "lookup by name should return same UID")
|
||||
assert.Equal(t, current.Gid, byName.Gid, "lookup by name should return same GID")
|
||||
assert.Equal(t, current.HomeDir, byName.HomeDir, "lookup by name should return same home")
|
||||
|
||||
// Step 3: groupIdsWithFallback must return at least the primary GID.
|
||||
groups, err := groupIdsWithFallback(current)
|
||||
groups, err := GroupIdsWithFallback(current)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, groups, "user must have at least one group")
|
||||
|
||||
@@ -123,7 +123,7 @@ func TestIntegration_FullLookupChain(t *testing.T) {
|
||||
// Step 4: getShellFromGetent should either return a valid shell path or empty
|
||||
// (empty is OK when getent is not available, e.g. macOS without Homebrew getent).
|
||||
if runtime.GOOS != "windows" {
|
||||
shell := getShellFromGetent(current.Uid)
|
||||
shell := GetShellFromGetent(current.Uid)
|
||||
if shell != "" {
|
||||
assert.True(t, shell[0] == '/', "shell should be an absolute path, got %q", shell)
|
||||
}
|
||||
@@ -138,10 +138,10 @@ func TestIntegration_LookupAndGroupsConsistency(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Simulate the SSH server flow: lookup user, then get their groups.
|
||||
resolved, err := lookupWithGetent(current.Username)
|
||||
resolved, err := LookupWithGetent(current.Username)
|
||||
require.NoError(t, err)
|
||||
|
||||
groups, err := groupIdsWithFallback(resolved)
|
||||
groups, err := GroupIdsWithFallback(resolved)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, groups, "resolved user must have groups")
|
||||
|
||||
@@ -154,19 +154,3 @@ func TestIntegration_LookupAndGroupsConsistency(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntegration_ShellLookupChain tests the full shell resolution chain
|
||||
// (getShellFromPasswd -> getShellFromGetent -> $SHELL -> default) on Unix.
|
||||
func TestIntegration_ShellLookupChain(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Unix shell lookup not applicable on Windows")
|
||||
}
|
||||
|
||||
current, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
// getUserShell is the top-level function used by the SSH server.
|
||||
shell := getUserShell(current.Uid)
|
||||
require.NotEmpty(t, shell, "getUserShell must always return a shell")
|
||||
assert.True(t, shell[0] == '/', "shell should be an absolute path, got %q", shell)
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build !windows
|
||||
|
||||
package server
|
||||
package shell
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -14,19 +14,25 @@ import (
|
||||
|
||||
const getentTimeout = 5 * time.Second
|
||||
|
||||
// getShellFromGetent gets a user's login shell via getent by UID.
|
||||
// GetShellFromGetent gets a user's login shell via getent by UID.
|
||||
// This is needed even with CGO because getShellFromPasswd reads /etc/passwd
|
||||
// directly and won't find NSS-provided users there.
|
||||
func getShellFromGetent(userID string) string {
|
||||
_, shell, err := runGetent(userID)
|
||||
func GetShellFromGetent(userID string) string {
|
||||
_, shell, err := runGetentPasswd(userID)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return shell
|
||||
}
|
||||
|
||||
// runGetent executes `getent passwd <query>` and returns the user and login shell.
|
||||
func runGetent(query string) (*user.User, string, error) {
|
||||
// GetUserFromGetent returns the resolved group from either a uid or username
|
||||
func GetUserFromGetent(user string) (*user.User, error) {
|
||||
u, _, err := runGetentPasswd(user)
|
||||
return u, err
|
||||
}
|
||||
|
||||
// runGetentPasswd executes `getent passwd <query>` and returns the user and login shell.
|
||||
func runGetentPasswd(query string) (*user.User, string, error) {
|
||||
if !validateGetentInput(query) {
|
||||
return nil, "", fmt.Errorf("invalid getent input: %q", query)
|
||||
}
|
||||
@@ -42,6 +48,23 @@ func runGetent(query string) (*user.User, string, error) {
|
||||
return parseGetentPasswd(string(out))
|
||||
}
|
||||
|
||||
// runGetentGroup executes `getent group <query>` and returns the group
|
||||
func runGetentGroup(query string) (*user.Group, error) {
|
||||
if !validateGetentInput(query) {
|
||||
return nil, fmt.Errorf("invalid getent input: %q", query)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), getentTimeout)
|
||||
defer cancel()
|
||||
|
||||
out, err := exec.CommandContext(ctx, "getent", "group", query).Output()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getent passwd%s: %w", query, err)
|
||||
}
|
||||
|
||||
return parseGetentGroup(string(out))
|
||||
}
|
||||
|
||||
// parseGetentPasswd parses getent passwd output: "name:x:uid:gid:gecos:home:shell"
|
||||
func parseGetentPasswd(output string) (*user.User, string, error) {
|
||||
fields := strings.SplitN(strings.TrimSpace(output), ":", 8)
|
||||
@@ -67,6 +90,20 @@ func parseGetentPasswd(output string) (*user.User, string, error) {
|
||||
}, shell, nil
|
||||
}
|
||||
|
||||
// parseGetentGroup parses getent group output: "group:x:gid:user"
|
||||
func parseGetentGroup(output string) (*user.Group, error) {
|
||||
fields := strings.SplitN(strings.TrimSpace(output), ":", 8)
|
||||
if len(fields) < 4 {
|
||||
return nil, fmt.Errorf("unexpected getent output (need 4+ fields): %q", output)
|
||||
}
|
||||
|
||||
if fields[0] == "" || fields[2] == "" {
|
||||
return nil, fmt.Errorf("missing required fields in getent output: %q", output)
|
||||
}
|
||||
|
||||
return &user.Group{Gid: fields[2], Name: fields[0]}, nil
|
||||
}
|
||||
|
||||
// validateGetentInput checks that the input is safe to pass to getent or id.
|
||||
// Allows POSIX usernames, numeric UIDs, and common NSS extensions
|
||||
// (@ for Kerberos, $ for Samba, + for NIS compat).
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build !windows
|
||||
|
||||
package server
|
||||
package shell
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
@@ -195,7 +195,7 @@ func TestRunGetent_RootUser(t *testing.T) {
|
||||
t.Skip("getent not available on this system")
|
||||
}
|
||||
|
||||
u, shell, err := runGetent("root")
|
||||
u, shell, err := runGetentPasswd("root")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "root", u.Username)
|
||||
assert.Equal(t, "0", u.Uid)
|
||||
@@ -208,7 +208,7 @@ func TestRunGetent_ByUID(t *testing.T) {
|
||||
t.Skip("getent not available on this system")
|
||||
}
|
||||
|
||||
u, _, err := runGetent("0")
|
||||
u, _, err := runGetentPasswd("0")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "root", u.Username)
|
||||
assert.Equal(t, "0", u.Uid)
|
||||
@@ -219,15 +219,15 @@ func TestRunGetent_NonexistentUser(t *testing.T) {
|
||||
t.Skip("getent not available on this system")
|
||||
}
|
||||
|
||||
_, _, err := runGetent("nonexistent_user_xyzzy_12345")
|
||||
_, _, err := runGetentPasswd("nonexistent_user_xyzzy_12345")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestRunGetent_InvalidInput(t *testing.T) {
|
||||
_, _, err := runGetent("")
|
||||
_, _, err := runGetentPasswd("")
|
||||
assert.Error(t, err)
|
||||
|
||||
_, _, err = runGetent("user\x00name")
|
||||
_, _, err = runGetentPasswd("user\x00name")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
@@ -236,7 +236,7 @@ func TestRunGetent_NotAvailable(t *testing.T) {
|
||||
t.Skip("getent is available, can't test missing case")
|
||||
}
|
||||
|
||||
_, _, err := runGetent("root")
|
||||
_, _, err := runGetentPasswd("root")
|
||||
assert.Error(t, err, "should fail when getent is not installed")
|
||||
}
|
||||
|
||||
@@ -283,7 +283,7 @@ func TestGetentResultsMatchStdlib(t *testing.T) {
|
||||
current, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
getentUser, _, err := runGetent(current.Username)
|
||||
getentUser, _, err := runGetentPasswd(current.Username)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, current.Username, getentUser.Username, "username should match")
|
||||
@@ -300,7 +300,7 @@ func TestGetentResultsMatchStdlib_ByUID(t *testing.T) {
|
||||
current, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
getentUser, _, err := runGetent(current.Uid)
|
||||
getentUser, _, err := runGetentPasswd(current.Uid)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, current.Username, getentUser.Username, "username should match when looked up by UID")
|
||||
@@ -356,7 +356,7 @@ func TestGetShellFromPasswd_CurrentUser(t *testing.T) {
|
||||
assert.True(t, shell[0] == '/', "shell should be an absolute path, got %q", shell)
|
||||
|
||||
if _, err := exec.LookPath("getent"); err == nil {
|
||||
_, getentShell, getentErr := runGetent(current.Uid)
|
||||
_, getentShell, getentErr := runGetentPasswd(current.Uid)
|
||||
if getentErr == nil && getentShell != "" {
|
||||
assert.Equal(t, getentShell, shell, "shell from /etc/passwd should match getent")
|
||||
}
|
||||
@@ -400,7 +400,7 @@ func TestGetShellFromPasswd_MatchesGetentForKnownUsers(t *testing.T) {
|
||||
continue
|
||||
}
|
||||
|
||||
_, getentShell, err := runGetent(uid)
|
||||
_, getentShell, err := runGetentPasswd(uid)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
@@ -1,26 +1,26 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
package shell
|
||||
|
||||
import "os/user"
|
||||
|
||||
// lookupWithGetent on Windows just delegates to os/user.Lookup.
|
||||
// Windows does not use NSS/getent; its user lookup works without CGO.
|
||||
func lookupWithGetent(username string) (*user.User, error) {
|
||||
func LookupWithGetent(username string) (*user.User, error) {
|
||||
return user.Lookup(username)
|
||||
}
|
||||
|
||||
// currentUserWithGetent on Windows just delegates to os/user.Current.
|
||||
func currentUserWithGetent() (*user.User, error) {
|
||||
func CurrentUserWithGetent() (*user.User, error) {
|
||||
return user.Current()
|
||||
}
|
||||
|
||||
// getShellFromGetent is a no-op on Windows; shell resolution uses PowerShell detection.
|
||||
func getShellFromGetent(_ string) string {
|
||||
func GetShellFromGetent(_ string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// groupIdsWithFallback on Windows just delegates to u.GroupIds().
|
||||
func groupIdsWithFallback(u *user.User) ([]string, error) {
|
||||
func GroupIdsWithFallback(u *user.User) ([]string, error) {
|
||||
return u.GroupIds()
|
||||
}
|
||||
@@ -1,17 +1,14 @@
|
||||
package server
|
||||
package shell
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@@ -24,7 +21,7 @@ const (
|
||||
|
||||
// getUserShell returns the appropriate shell for the given user ID
|
||||
// Handles all platform-specific logic and fallbacks consistently
|
||||
func getUserShell(userID string) string {
|
||||
func GetUserShell(userID string) string {
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
return getWindowsUserShell()
|
||||
@@ -56,7 +53,7 @@ func getUnixUserShell(userID string) string {
|
||||
return shell
|
||||
}
|
||||
|
||||
if shell := getShellFromGetent(userID); shell != "" {
|
||||
if shell := GetShellFromGetent(userID); shell != "" {
|
||||
return shell
|
||||
}
|
||||
|
||||
@@ -101,8 +98,8 @@ func getShellFromPasswd(userID string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// prepareUserEnv prepares environment variables for user execution
|
||||
func prepareUserEnv(user *user.User, shell string) []string {
|
||||
// PrepareUserEnv prepares environment variables for user execution
|
||||
func PrepareUserEnv(user *user.User, shell string) []string {
|
||||
pathValue := "/usr/local/bin:/usr/bin:/bin:/usr/local/games:/usr/games"
|
||||
if runtime.GOOS == "windows" {
|
||||
pathValue = `C:\Windows\System32;C:\Windows;C:\Windows\System32\Wbem;C:\Windows\System32\WindowsPowerShell\v1.0`
|
||||
@@ -119,7 +116,7 @@ func prepareUserEnv(user *user.User, shell string) []string {
|
||||
|
||||
// acceptEnv checks if environment variable from SSH client should be accepted
|
||||
// This is a whitelist of variables that SSH clients can send to the server
|
||||
func acceptEnv(envVar string) bool {
|
||||
func AcceptEnv(envVar string) bool {
|
||||
varName := envVar
|
||||
if idx := strings.Index(envVar, "="); idx != -1 {
|
||||
varName = envVar[:idx]
|
||||
@@ -156,29 +153,3 @@ func acceptEnv(envVar string) bool {
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// prepareSSHEnv prepares SSH protocol-specific environment variables
|
||||
// These variables provide information about the SSH connection itself
|
||||
func prepareSSHEnv(session ssh.Session) []string {
|
||||
remoteAddr := session.RemoteAddr()
|
||||
localAddr := session.LocalAddr()
|
||||
|
||||
remoteHost, remotePort, err := net.SplitHostPort(remoteAddr.String())
|
||||
if err != nil {
|
||||
remoteHost = remoteAddr.String()
|
||||
remotePort = "0"
|
||||
}
|
||||
|
||||
localHost, localPort, err := net.SplitHostPort(localAddr.String())
|
||||
if err != nil {
|
||||
localHost = localAddr.String()
|
||||
localPort = strconv.Itoa(InternalSSHPort)
|
||||
}
|
||||
|
||||
return []string{
|
||||
// SSH_CLIENT format: "client_ip client_port server_port"
|
||||
fmt.Sprintf("SSH_CLIENT=%s %s %s", remoteHost, remotePort, localPort),
|
||||
// SSH_CONNECTION format: "client_ip client_port server_ip server_port"
|
||||
fmt.Sprintf("SSH_CONNECTION=%s %s %s %s", remoteHost, remotePort, localHost, localPort),
|
||||
}
|
||||
}
|
||||
26
client/internal/shell/shell_test.go
Normal file
26
client/internal/shell/shell_test.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package shell
|
||||
|
||||
import (
|
||||
"os/user"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestIntegration_ShellLookupChain tests the full shell resolution chain
|
||||
// (getShellFromPasswd -> getShellFromGetent -> $SHELL -> default) on Unix.
|
||||
func TestIntegration_ShellLookupChain(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Unix shell lookup not applicable on Windows")
|
||||
}
|
||||
|
||||
current, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
// getUserShell is the top-level function used by the SSH server.
|
||||
shell := GetUserShell(current.Uid)
|
||||
require.NotEmpty(t, shell, "getUserShell must always return a shell")
|
||||
assert.True(t, shell[0] == '/', "shell should be an absolute path, got %q", shell)
|
||||
}
|
||||
@@ -74,14 +74,6 @@ func New(filePath string) *Manager {
|
||||
}
|
||||
}
|
||||
|
||||
// FilePath returns the path of the underlying state file.
|
||||
func (m *Manager) FilePath() string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
return m.filePath
|
||||
}
|
||||
|
||||
// Start starts the state manager periodic save routine
|
||||
func (m *Manager) Start() {
|
||||
if m == nil {
|
||||
|
||||
@@ -1,50 +0,0 @@
|
||||
//go:build windows || darwin
|
||||
|
||||
package mdm
|
||||
|
||||
import "strings"
|
||||
|
||||
// allKeys is the set of recognised MDM keys. Unknown keys in a managed
|
||||
// configuration are ignored but logged. Lives in this build-tagged file
|
||||
// (windows || darwin) because only desktop loaders need the
|
||||
// canonicalisation table that consumes it; including it unconditionally
|
||||
// would trigger the `unused` golangci-lint check on platforms that
|
||||
// don't import canonical_loaders.go.
|
||||
var allKeys = []string{
|
||||
KeyManagementURL,
|
||||
KeyDisableUpdateSettings,
|
||||
KeyDisableProfiles,
|
||||
KeyDisableNetworks,
|
||||
KeyDisableClientRoutes,
|
||||
KeyDisableServerRoutes,
|
||||
KeyBlockInbound,
|
||||
KeyDisableMetricsCollection,
|
||||
KeyAllowServerSSH,
|
||||
KeyDisableAutoConnect,
|
||||
KeyPreSharedKey,
|
||||
KeyRosenpassEnabled,
|
||||
KeyRosenpassPermissive,
|
||||
KeyWireguardPort,
|
||||
KeySplitTunnelMode,
|
||||
KeySplitTunnelApps,
|
||||
}
|
||||
|
||||
// canonicalKey maps the lowercase form of a managed-config value name to
|
||||
// its canonical mdm.Key* form. Admins commonly write PascalCase value
|
||||
// names in ADMX / Group Policy ("ManagementURL"); the iOS/AppConfig and
|
||||
// macOS plist conventions are camelCase ("managementURL"); both must
|
||||
// resolve to the same Policy lookup.
|
||||
//
|
||||
// Lives in a desktop-loader-only file (build tag `windows || darwin`)
|
||||
// because no other build path consumes it. Linux / FreeBSD / mobile
|
||||
// builds don't ship a platform loader that reads arbitrary-case key
|
||||
// names, so they don't need the canonicalisation table — and including
|
||||
// the var unconditionally would trigger the `unused` golangci-lint
|
||||
// check on those platforms.
|
||||
var canonicalKey = func() map[string]string {
|
||||
m := make(map[string]string, len(allKeys))
|
||||
for _, k := range allKeys {
|
||||
m[strings.ToLower(k)] = k
|
||||
}
|
||||
return m
|
||||
}()
|
||||
@@ -1,247 +0,0 @@
|
||||
// Package mdm reads MDM-managed configuration from platform-native sources
|
||||
// (plist on macOS, registry on Windows, UserDefaults on iOS,
|
||||
// RestrictionsManager on Android). The returned Policy is consumed by
|
||||
// profilemanager.Config.apply() as the highest-priority override layer.
|
||||
//
|
||||
// An empty Policy (no source present, or source present with zero keys)
|
||||
// means no MDM enforcement is active and the client behaves as if the
|
||||
// feature did not exist.
|
||||
package mdm
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strconv"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Well-known policy keys. Names mirror the corresponding ConfigInput Go field
|
||||
// names (lowerCamelCase) so the daemon can map a Policy key directly to a
|
||||
// configuration field.
|
||||
const (
|
||||
KeyManagementURL = "managementURL"
|
||||
KeyDisableUpdateSettings = "disableUpdateSettings"
|
||||
KeyDisableProfiles = "disableProfiles"
|
||||
KeyDisableNetworks = "disableNetworks"
|
||||
KeyDisableClientRoutes = "disableClientRoutes"
|
||||
KeyDisableServerRoutes = "disableServerRoutes"
|
||||
KeyBlockInbound = "blockInbound"
|
||||
KeyDisableMetricsCollection = "disableMetricsCollection"
|
||||
KeyAllowServerSSH = "allowServerSSH"
|
||||
KeyDisableAutoConnect = "disableAutoConnect"
|
||||
KeyPreSharedKey = "preSharedKey"
|
||||
KeyRosenpassEnabled = "rosenpassEnabled"
|
||||
KeyRosenpassPermissive = "rosenpassPermissive"
|
||||
KeyWireguardPort = "wireguardPort"
|
||||
|
||||
// Split tunnel is modeled as a single conceptual policy with two
|
||||
// registry/plist values. KeySplitTunnelMode is the discriminator
|
||||
// ("allow" or "disallow"); KeySplitTunnelApps is a comma-separated
|
||||
// list of package names. The values are mutually exclusive by
|
||||
// construction — only one mode can be set at a time.
|
||||
KeySplitTunnelMode = "splitTunnelMode"
|
||||
KeySplitTunnelApps = "splitTunnelApps"
|
||||
)
|
||||
|
||||
// Split-tunnel mode literals (KeySplitTunnelMode values).
|
||||
const (
|
||||
SplitTunnelModeAllow = "allow"
|
||||
SplitTunnelModeDisallow = "disallow"
|
||||
)
|
||||
|
||||
// SecretKeys lists keys whose values must be redacted in logs.
|
||||
var SecretKeys = map[string]struct{}{
|
||||
KeyPreSharedKey: {},
|
||||
}
|
||||
|
||||
// boolStringLiterals enumerates the textual boolean encodings the
|
||||
// platform loaders may produce (Windows REG_SZ "true", iOS / Android
|
||||
// managed-config booleans-as-strings, etc.). Lookup keeps GetBool flat
|
||||
// (no nested switch on the string case).
|
||||
var boolStringLiterals = map[string]bool{
|
||||
"true": true,
|
||||
"1": true,
|
||||
"yes": true,
|
||||
"false": false,
|
||||
"0": false,
|
||||
"no": false,
|
||||
}
|
||||
|
||||
|
||||
// Policy holds MDM-managed settings read from the platform source. A nil or
|
||||
// empty Policy means no enforcement is active.
|
||||
type Policy struct {
|
||||
values map[string]any
|
||||
}
|
||||
|
||||
// NewPolicy constructs a Policy from a key→value map. Pass nil or an
|
||||
// empty map to construct an empty (no-enforcement) Policy. The returned
|
||||
// *Policy is always non-nil.
|
||||
func NewPolicy(values map[string]any) *Policy {
|
||||
if values == nil {
|
||||
values = map[string]any{}
|
||||
}
|
||||
return &Policy{values: values}
|
||||
}
|
||||
|
||||
// LoadPolicy reads the platform-native MDM configuration. Returns an
|
||||
// empty (but non-nil) Policy when no source is present, the source is
|
||||
// empty, or the platform is unsupported.
|
||||
//
|
||||
// Diagnostic logging differentiates the three states:
|
||||
// - source absent / unsupported platform: trace log only
|
||||
// - source present, zero keys: info "MDM enrolled (no managed keys)"
|
||||
// - source present, N keys: info "MDM enrolled with N managed keys: [...]"
|
||||
func LoadPolicy() *Policy {
|
||||
values, err := loadPlatformPolicy()
|
||||
if err != nil {
|
||||
log.Tracef("MDM policy load: %v", err)
|
||||
return &Policy{values: map[string]any{}}
|
||||
}
|
||||
if values == nil {
|
||||
return &Policy{values: map[string]any{}}
|
||||
}
|
||||
if len(values) == 0 {
|
||||
log.Info("MDM enrolled (no managed keys)")
|
||||
} else {
|
||||
log.Infof("MDM enrolled with %d managed key(s): %v", len(values), sortedKeys(values))
|
||||
}
|
||||
return &Policy{values: values}
|
||||
}
|
||||
|
||||
// IsEmpty reports whether the Policy has no managed keys.
|
||||
func (p *Policy) IsEmpty() bool {
|
||||
return p == nil || len(p.values) == 0
|
||||
}
|
||||
|
||||
// HasKey reports whether the given key is MDM-managed.
|
||||
func (p *Policy) HasKey(key string) bool {
|
||||
if p == nil {
|
||||
return false
|
||||
}
|
||||
_, ok := p.values[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ManagedKeys returns the sorted list of managed key names. Returns an empty
|
||||
// slice (not nil) on an empty Policy.
|
||||
func (p *Policy) ManagedKeys() []string {
|
||||
if p == nil {
|
||||
return []string{}
|
||||
}
|
||||
return sortedKeys(p.values)
|
||||
}
|
||||
|
||||
// GetString returns the managed value for key coerced to string, and whether
|
||||
// the key was set. A non-string value returns ("", false).
|
||||
func (p *Policy) GetString(key string) (string, bool) {
|
||||
if p == nil {
|
||||
return "", false
|
||||
}
|
||||
v, ok := p.values[key]
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
s, ok := v.(string)
|
||||
if !ok || s == "" {
|
||||
return "", false
|
||||
}
|
||||
return s, true
|
||||
}
|
||||
|
||||
// GetBool returns the managed value for key coerced to bool, and whether the
|
||||
// key was set. Accepts native bool and string literals "true"/"false"/"1"/"0".
|
||||
func (p *Policy) GetBool(key string) (bool, bool) {
|
||||
if p == nil {
|
||||
return false, false
|
||||
}
|
||||
v, ok := p.values[key]
|
||||
if !ok {
|
||||
return false, false
|
||||
}
|
||||
switch t := v.(type) {
|
||||
case bool:
|
||||
return t, true
|
||||
case string:
|
||||
b, known := boolStringLiterals[t]
|
||||
return b, known
|
||||
case int:
|
||||
return t != 0, true
|
||||
case int64:
|
||||
return t != 0, true
|
||||
}
|
||||
return false, false
|
||||
}
|
||||
|
||||
// GetInt returns the managed value for key as int64, and whether the key
|
||||
// was set. Accepts native int / int64 (as produced by the Windows registry
|
||||
// loader for REG_DWORD/REG_QWORD) and numeric strings (decimal).
|
||||
func (p *Policy) GetInt(key string) (int64, bool) {
|
||||
if p == nil {
|
||||
return 0, false
|
||||
}
|
||||
v, ok := p.values[key]
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
switch t := v.(type) {
|
||||
case int64:
|
||||
return t, true
|
||||
case int:
|
||||
return int64(t), true
|
||||
case int32:
|
||||
return int64(t), true
|
||||
case uint64:
|
||||
return int64(t), true
|
||||
case float64:
|
||||
return int64(t), true
|
||||
case string:
|
||||
if n, err := strconv.ParseInt(t, 10, 64); err == nil {
|
||||
return n, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetStringSlice returns the managed value for key as []string, and whether
|
||||
// the key was set. Accepts []string, []any (of strings), and a single string
|
||||
// (treated as a one-element list).
|
||||
func (p *Policy) GetStringSlice(key string) ([]string, bool) {
|
||||
if p == nil {
|
||||
return nil, false
|
||||
}
|
||||
v, ok := p.values[key]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
switch t := v.(type) {
|
||||
case []string:
|
||||
return append([]string(nil), t...), true
|
||||
case []any:
|
||||
out := make([]string, 0, len(t))
|
||||
for _, item := range t {
|
||||
s, ok := item.(string)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
out = append(out, s)
|
||||
}
|
||||
return out, true
|
||||
case string:
|
||||
return []string{t}, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// sortedKeys returns the keys of m as a deterministic, lexicographically
|
||||
// sorted slice. Used internally by Policy.ManagedKeys and LoadPolicy's
|
||||
// diagnostic log line so callers see a stable key order across runs
|
||||
// regardless of Go's randomised map iteration.
|
||||
func sortedKeys(m map[string]any) []string {
|
||||
out := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
out = append(out, k)
|
||||
}
|
||||
sort.Strings(out)
|
||||
return out
|
||||
}
|
||||
@@ -1,90 +0,0 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package mdm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"howett.net/plist"
|
||||
)
|
||||
|
||||
// policyPlistPath is the well-known location where macOS writes the
|
||||
// device-level mandatory MDM payload for NetBird. The path is fixed by
|
||||
// Apple convention: when an MDM provider (Jamf / Kandji / Mosyle /
|
||||
// Intune for Mac / Workspace ONE) pushes a Configuration Profile that
|
||||
// contains a com.apple.ManagedClient.preferences payload targeting the
|
||||
// bundle id io.netbird.client, the OS materializes the payload here.
|
||||
//
|
||||
// Read-only — only the OS (root) is supposed to write this file. The
|
||||
// loader sanity-checks the file mode and refuses to honour a world-
|
||||
// writable plist, as a defense against tampered installs.
|
||||
const policyPlistPath = "/Library/Managed Preferences/io.netbird.client.plist"
|
||||
|
||||
// loadPlatformPolicy reads the MDM-managed configuration from the macOS
|
||||
// managed-preferences plist at policyPlistPath. Returns:
|
||||
// - (nil, nil) when the plist is absent (device not MDM-enrolled for
|
||||
// NetBird, or admin has not yet pushed a payload)
|
||||
// - (map, nil) with N entries when N managed values are present
|
||||
// (N may be 0 — empty plist still signals enrollment to the caller)
|
||||
// - (nil, err) on permission / parse / safety errors (including
|
||||
// refusal to read a world-writable plist)
|
||||
//
|
||||
// Top-level plist keys are canonicalised case-insensitively to the
|
||||
// package's internal mdm.Key* names; unknown keys are logged and
|
||||
// skipped so a stray entry in the payload does not block startup.
|
||||
// Native plist value types map naturally onto the Policy accessor
|
||||
// expectations (GetString / GetBool / GetInt / GetStringSlice).
|
||||
func loadPlatformPolicy() (map[string]any, error) {
|
||||
f, err := os.Open(policyPlistPath)
|
||||
if err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
// Not enrolled for NetBird. Caller treats nil as
|
||||
// "no MDM source present".
|
||||
//nolint:nilnil // (nil, nil) is the documented platform-absent sentinel; see LoadPolicy.
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("open %s: %w", policyPlistPath, err)
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := f.Close(); closeErr != nil {
|
||||
log.Warnf("MDM close plist %s: %v", policyPlistPath, closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
info, err := f.Stat()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("stat %s: %w", policyPlistPath, err)
|
||||
}
|
||||
// World-writable plist => tampered install. Refuse rather than
|
||||
// honour potentially attacker-controlled policy values.
|
||||
if info.Mode().Perm()&0o002 != 0 {
|
||||
return nil, fmt.Errorf("refusing to read world-writable MDM source %s (mode %o)",
|
||||
policyPlistPath, info.Mode().Perm())
|
||||
}
|
||||
|
||||
raw := make(map[string]any)
|
||||
if err := plist.NewDecoder(f).Decode(&raw); err != nil {
|
||||
return nil, fmt.Errorf("decode plist %s: %w", policyPlistPath, err)
|
||||
}
|
||||
|
||||
out := make(map[string]any, len(raw))
|
||||
for name, val := range raw {
|
||||
// macOS / AppConfig conventions both use camelCase for managed
|
||||
// preferences keys; canonicalize to the mdm.Key* form so a key
|
||||
// written as "ManagementURL" (PascalCase, rare on macOS but
|
||||
// possible if the admin reused an ADMX-style name) still
|
||||
// resolves.
|
||||
canonical, known := canonicalKey[strings.ToLower(name)]
|
||||
if !known {
|
||||
log.Warnf("MDM ignoring unknown plist key %s: %s", policyPlistPath, name)
|
||||
continue
|
||||
}
|
||||
out[canonical] = val
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
//go:build ios || android
|
||||
|
||||
package mdm
|
||||
|
||||
// loadPlatformPolicy is unused on mobile: the native layer (Swift on iOS,
|
||||
// Kotlin/Java on Android) reads the OS managed-config store and pushes the
|
||||
// resulting dictionary in-process via a gomobile entry point that lands in
|
||||
// Phase 5 / Phase 6. The stub keeps the package compilable for mobile
|
||||
// builds and returns (nil, nil) — the platform-absent sentinel that
|
||||
// LoadPolicy in policy.go treats as "no MDM source present".
|
||||
func loadPlatformPolicy() (map[string]any, error) {
|
||||
//nolint:nilnil // (nil, nil) is the documented platform-absent sentinel; see LoadPolicy.
|
||||
return nil, nil
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
//go:build !windows && !darwin && !ios && !android
|
||||
|
||||
package mdm
|
||||
|
||||
// loadPlatformPolicy returns no policy on platforms without an MDM channel
|
||||
// (Linux, FreeBSD). MDM enforcement is off and the client behaves as if
|
||||
// the feature did not exist. Returns (nil, nil) — the platform-absent
|
||||
// sentinel the caller (LoadPolicy in policy.go) treats as "no MDM
|
||||
// source present"; an error here would just translate to the same
|
||||
// outcome with an extra log line.
|
||||
func loadPlatformPolicy() (map[string]any, error) {
|
||||
//nolint:nilnil // (nil, nil) is the documented platform-absent sentinel; see LoadPolicy.
|
||||
return nil, nil
|
||||
}
|
||||
@@ -1,160 +0,0 @@
|
||||
package mdm
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPolicy_NilSafe(t *testing.T) {
|
||||
var p *Policy
|
||||
assert.True(t, p.IsEmpty())
|
||||
assert.False(t, p.HasKey(KeyManagementURL))
|
||||
assert.Empty(t, p.ManagedKeys())
|
||||
|
||||
_, ok := p.GetString(KeyManagementURL)
|
||||
assert.False(t, ok)
|
||||
_, ok = p.GetBool(KeyDisableProfiles)
|
||||
assert.False(t, ok)
|
||||
_, ok = p.GetStringSlice(KeySplitTunnelApps)
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestPolicy_Empty(t *testing.T) {
|
||||
p := NewPolicy(nil)
|
||||
require.NotNil(t, p)
|
||||
assert.True(t, p.IsEmpty())
|
||||
assert.False(t, p.HasKey(KeyManagementURL))
|
||||
assert.Empty(t, p.ManagedKeys())
|
||||
}
|
||||
|
||||
func TestPolicy_HasKey(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{
|
||||
KeyManagementURL: "https://corp.example.com",
|
||||
KeyDisableProfiles: true,
|
||||
})
|
||||
assert.False(t, p.IsEmpty())
|
||||
assert.True(t, p.HasKey(KeyManagementURL))
|
||||
assert.True(t, p.HasKey(KeyDisableProfiles))
|
||||
assert.False(t, p.HasKey(KeyPreSharedKey))
|
||||
}
|
||||
|
||||
func TestPolicy_ManagedKeysSorted(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{
|
||||
KeyDisableProfiles: true,
|
||||
KeyManagementURL: "https://x",
|
||||
KeyAllowServerSSH: false,
|
||||
})
|
||||
got := p.ManagedKeys()
|
||||
assert.Equal(t, []string{KeyAllowServerSSH, KeyDisableProfiles, KeyManagementURL}, got)
|
||||
}
|
||||
|
||||
func TestPolicy_GetString(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{
|
||||
KeyManagementURL: "https://corp.example.com",
|
||||
KeyDisableProfiles: true, // wrong type for GetString
|
||||
KeyPreSharedKey: "", // empty rejected
|
||||
})
|
||||
v, ok := p.GetString(KeyManagementURL)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "https://corp.example.com", v)
|
||||
|
||||
_, ok = p.GetString(KeyDisableProfiles)
|
||||
assert.False(t, ok, "non-string value must not be reported as string")
|
||||
|
||||
_, ok = p.GetString(KeyPreSharedKey)
|
||||
assert.False(t, ok, "empty string treated as unset")
|
||||
|
||||
_, ok = p.GetString("nonexistent")
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestPolicy_GetBool(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
raw any
|
||||
want bool
|
||||
ok bool
|
||||
}{
|
||||
{"native true", true, true, true},
|
||||
{"native false", false, false, true},
|
||||
{"string true", "true", true, true},
|
||||
{"string false", "false", false, true},
|
||||
{"string 1", "1", true, true},
|
||||
{"string 0", "0", false, true},
|
||||
{"string yes", "yes", true, true},
|
||||
{"string no", "no", false, true},
|
||||
{"int nonzero", 1, true, true},
|
||||
{"int zero", 0, false, true},
|
||||
{"int64 nonzero", int64(2), true, true},
|
||||
{"int64 zero", int64(0), false, true},
|
||||
{"string garbage", "maybe", false, false},
|
||||
{"float unsupported", 1.0, false, false},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{KeyDisableProfiles: c.raw})
|
||||
got, ok := p.GetBool(KeyDisableProfiles)
|
||||
assert.Equal(t, c.ok, ok)
|
||||
if c.ok {
|
||||
assert.Equal(t, c.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
_, ok := NewPolicy(nil).GetBool(KeyDisableProfiles)
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestPolicy_GetStringSlice(t *testing.T) {
|
||||
t.Run("native string slice", func(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{
|
||||
KeySplitTunnelApps: []string{"com.a", "com.b"},
|
||||
})
|
||||
got, ok := p.GetStringSlice(KeySplitTunnelApps)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, []string{"com.a", "com.b"}, got)
|
||||
})
|
||||
|
||||
t.Run("any slice of strings", func(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{
|
||||
KeySplitTunnelApps: []any{"com.a", "com.b"},
|
||||
})
|
||||
got, ok := p.GetStringSlice(KeySplitTunnelApps)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, []string{"com.a", "com.b"}, got)
|
||||
})
|
||||
|
||||
t.Run("single string lifts to one-element slice", func(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{
|
||||
KeySplitTunnelApps: "com.a",
|
||||
})
|
||||
got, ok := p.GetStringSlice(KeySplitTunnelApps)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, []string{"com.a"}, got)
|
||||
})
|
||||
|
||||
t.Run("mixed any slice rejected", func(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{
|
||||
KeySplitTunnelApps: []any{"com.a", 1},
|
||||
})
|
||||
_, ok := p.GetStringSlice(KeySplitTunnelApps)
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("missing key", func(t *testing.T) {
|
||||
p := NewPolicy(nil)
|
||||
_, ok := p.GetStringSlice(KeySplitTunnelApps)
|
||||
assert.False(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoadPolicy_PlatformStubReturnsEmpty(t *testing.T) {
|
||||
// loadPlatformPolicy is a stub on every OS for Phase 1. LoadPolicy must
|
||||
// degrade gracefully and never return nil.
|
||||
p := LoadPolicy()
|
||||
require.NotNil(t, p)
|
||||
assert.True(t, p.IsEmpty())
|
||||
assert.Empty(t, p.ManagedKeys())
|
||||
}
|
||||
@@ -1,108 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package mdm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
)
|
||||
|
||||
// policyRegistryPath is the well-known MDM policy registry key for NetBird.
|
||||
// Admins push values here through Group Policy, Intune ADMX ingestion, an
|
||||
// Intune custom Registry CSP profile, or `reg add` during MSI deployment.
|
||||
// Listed in the project's docs/mdm/netbird.admx schema.
|
||||
const policyRegistryPath = `Software\Policies\NetBird`
|
||||
|
||||
// readRegistryValue reads a single value under policyRegistryPath and,
|
||||
// on success, stores the type-coerced result in out[canonical]. Type
|
||||
// coercion mirrors loadPlatformPolicy's documented mapping:
|
||||
// - REG_SZ / REG_EXPAND_SZ -> string (REG_EXPAND_SZ is expanded by the API)
|
||||
// - REG_DWORD / REG_QWORD -> int64
|
||||
// - REG_MULTI_SZ -> []string
|
||||
//
|
||||
// Unsupported value types and per-value read failures are logged at
|
||||
// warn level and skipped — one malformed value must not block the
|
||||
// surrounding loop. Extracted from loadPlatformPolicy to keep that
|
||||
// function's cognitive complexity in check.
|
||||
func readRegistryValue(k registry.Key, name, canonical string, out map[string]any) {
|
||||
_, valType, err := k.GetValue(name, nil)
|
||||
if err != nil {
|
||||
log.Warnf("MDM stat %s\\%s: %v", policyRegistryPath, name, err)
|
||||
return
|
||||
}
|
||||
switch valType {
|
||||
case registry.SZ, registry.EXPAND_SZ:
|
||||
if v, _, err := k.GetStringValue(name); err == nil {
|
||||
out[canonical] = v
|
||||
} else {
|
||||
log.Warnf("MDM read string %s\\%s: %v", policyRegistryPath, name, err)
|
||||
}
|
||||
case registry.DWORD, registry.QWORD:
|
||||
if v, _, err := k.GetIntegerValue(name); err == nil {
|
||||
// uint64 from the registry API; Policy.GetBool / GetInt
|
||||
// helpers consume int64, so narrow safely.
|
||||
out[canonical] = int64(v)
|
||||
} else {
|
||||
log.Warnf("MDM read int %s\\%s: %v", policyRegistryPath, name, err)
|
||||
}
|
||||
case registry.MULTI_SZ:
|
||||
if v, _, err := k.GetStringsValue(name); err == nil {
|
||||
out[canonical] = v
|
||||
} else {
|
||||
log.Warnf("MDM read multi-string %s\\%s: %v", policyRegistryPath, name, err)
|
||||
}
|
||||
default:
|
||||
log.Warnf("MDM ignoring unsupported registry value type %d at %s\\%s",
|
||||
valType, policyRegistryPath, name)
|
||||
}
|
||||
}
|
||||
|
||||
// loadPlatformPolicy reads the MDM-managed configuration from the
|
||||
// Windows registry under HKLM\Software\Policies\NetBird. Returns:
|
||||
// - (nil, nil) when the key is absent (device not MDM-enrolled for NetBird)
|
||||
// - (map, nil) with N entries when N managed values are set (N may be 0)
|
||||
// - (nil, err) on open / enumerate registry errors
|
||||
//
|
||||
// Per-value type coercion + skip-on-error is delegated to
|
||||
// readRegistryValue. Unknown value names are logged and skipped so a
|
||||
// malformed deployment does not block startup.
|
||||
func loadPlatformPolicy() (map[string]any, error) {
|
||||
k, err := registry.OpenKey(registry.LOCAL_MACHINE, policyRegistryPath, registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
if errors.Is(err, registry.ErrNotExist) {
|
||||
// Not enrolled. Caller treats nil as "no MDM source present".
|
||||
//nolint:nilnil // (nil, nil) is the documented platform-absent sentinel; see LoadPolicy.
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("open %s: %w", policyRegistryPath, err)
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := k.Close(); closeErr != nil {
|
||||
log.Warnf("MDM close registry key %s: %v", policyRegistryPath, closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
names, err := k.ReadValueNames(-1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("enumerate values of %s: %w", policyRegistryPath, err)
|
||||
}
|
||||
|
||||
out := make(map[string]any, len(names))
|
||||
for _, name := range names {
|
||||
// Canonicalize the registry value name against the known MDM key
|
||||
// set so Policy.HasKey lookups (which use the canonical names)
|
||||
// succeed regardless of the casing used by the admin's ADMX or
|
||||
// `reg add` command.
|
||||
canonical, known := canonicalKey[strings.ToLower(name)]
|
||||
if !known {
|
||||
log.Warnf("MDM ignoring unknown registry value %s\\%s", policyRegistryPath, name)
|
||||
continue
|
||||
}
|
||||
readRegistryValue(k, name, canonical, out)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
@@ -1,129 +0,0 @@
|
||||
package mdm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// DefaultReloadInterval is the production cadence at which the desktop daemon
|
||||
// re-reads the OS-native MDM policy. Picked to balance responsiveness against
|
||||
// registry/plist I/O overhead. Mobile builds use OS-side notifications
|
||||
// instead, hence anticipating the ticker mechanism entirely.
|
||||
const DefaultReloadInterval = 1 * time.Minute
|
||||
|
||||
// policyLoader is the indirection through which the ticker reads the
|
||||
// OS-native policy, both for the initial observation and on every tick.
|
||||
// Production points it at LoadPolicy; tests in this package override it to
|
||||
// feed a scripted sequence of policies without touching the real OS store.
|
||||
var policyLoader = LoadPolicy
|
||||
|
||||
// Ticker periodically re-reads the OS-native MDM policy via LoadPolicy and
|
||||
// invokes the onChange callback (supplied to Run) whenever the observed
|
||||
// Policy diverges from the last observation (added / removed / changed
|
||||
// keys). Launch with Run from a goroutine; cancel the supplied context
|
||||
// to stop.
|
||||
type Ticker struct {
|
||||
interval time.Duration
|
||||
prev *Policy
|
||||
}
|
||||
|
||||
// NewTicker constructs a Ticker that will re-read the OS-native policy
|
||||
// every reloadInterval once Run is called.
|
||||
// The initial snapshot is populated by calling policyLoader at
|
||||
// construction time so the first tick only fires
|
||||
// onChange when the policy actually changed since boot — without
|
||||
// this baseline the first tick would report every currently-managed
|
||||
// key as "added" and trigger a spurious engine restart.
|
||||
func NewTicker(reloadInterval time.Duration) *Ticker {
|
||||
return &Ticker{
|
||||
interval: reloadInterval,
|
||||
prev: policyLoader(),
|
||||
}
|
||||
}
|
||||
|
||||
// Run blocks until ctx is cancelled, polling the OS-native policy store at
|
||||
// the configured cadence and emitting log lines + onChange callback on
|
||||
// every observed diff. onChange must be non-nil.
|
||||
func (t *Ticker) Run(ctx context.Context, onChange func(prev, curr *Policy) error) {
|
||||
tk := time.NewTicker(t.interval)
|
||||
defer tk.Stop()
|
||||
log.Infof("MDM policy reload ticker started (interval=%s)", t.interval)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Info("MDM policy reload ticker stopped")
|
||||
return
|
||||
case <-tk.C:
|
||||
curr := policyLoader()
|
||||
if policiesEqual(t.prev, curr) {
|
||||
continue
|
||||
}
|
||||
added, removed, changed := diffPolicies(t.prev, curr)
|
||||
log.Infof("MDM policy changed: added=%v removed=%v changed=%v",
|
||||
added, removed, changed)
|
||||
prev := t.prev
|
||||
if err := onChange(prev, curr); err != nil {
|
||||
log.Errorf("MDM policy change handler failed (retrying in 1 minute): %v", err)
|
||||
continue
|
||||
}
|
||||
t.prev = curr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// policiesEqual reports whether two Policy instances carry the same
|
||||
// managed key set with identical values. Nil and empty policies
|
||||
// compare equal; one-nil/one-non-empty compare not equal; otherwise
|
||||
// the underlying values maps are compared with reflect.DeepEqual.
|
||||
func policiesEqual(a, b *Policy) bool {
|
||||
if a.IsEmpty() && b.IsEmpty() {
|
||||
return true
|
||||
}
|
||||
if a == nil || b == nil {
|
||||
return false
|
||||
}
|
||||
return reflect.DeepEqual(a.values, b.values)
|
||||
}
|
||||
|
||||
// diffPolicies returns the keys added in curr, removed from prev, and
|
||||
// whose values changed between prev and curr. Each slice is sorted
|
||||
// lexicographically for stable log output; value differences are
|
||||
// determined with reflect.DeepEqual.
|
||||
func diffPolicies(prev, curr *Policy) (added, removed, changed []string) {
|
||||
prevKVs := mapOf(prev)
|
||||
currKVs := mapOf(curr)
|
||||
for k := range currKVs {
|
||||
if _, ok := prevKVs[k]; !ok {
|
||||
added = append(added, k)
|
||||
} else if !reflect.DeepEqual(prevKVs[k], currKVs[k]) {
|
||||
changed = append(changed, k)
|
||||
}
|
||||
}
|
||||
for k := range prevKVs {
|
||||
if _, ok := currKVs[k]; !ok {
|
||||
removed = append(removed, k)
|
||||
}
|
||||
}
|
||||
sort.Strings(added)
|
||||
sort.Strings(removed)
|
||||
sort.Strings(changed)
|
||||
return added, removed, changed
|
||||
}
|
||||
|
||||
// mapOf returns a (possibly empty, never nil) copy of the underlying
|
||||
// values map of a Policy so callers outside this package can compare
|
||||
// keys/values across the type boundary. Returns an empty map on nil p.
|
||||
func mapOf(p *Policy) map[string]any {
|
||||
if p == nil {
|
||||
return map[string]any{}
|
||||
}
|
||||
out := make(map[string]any, len(p.values))
|
||||
for k, v := range p.values {
|
||||
out[k] = v
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -1,100 +0,0 @@
|
||||
package mdm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// testReloadInterval for speeding up the ticker cadence under `go test`
|
||||
const testReloadInterval = 1 * time.Second
|
||||
|
||||
// withPolicyLoader overrides the package-level policyLoader for the duration
|
||||
// of the test so the ticker observes a scripted policy instead of the real
|
||||
// OS-native store. The original loader is restored on cleanup.
|
||||
func withPolicyLoader(t *testing.T, fn func() *Policy) {
|
||||
t.Helper()
|
||||
prev := policyLoader
|
||||
policyLoader = fn
|
||||
t.Cleanup(func() { policyLoader = prev })
|
||||
}
|
||||
|
||||
func TestTicker_FiresOnChangeWithDelta(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
current := NewPolicy(nil) // initial observation: empty (no enforcement)
|
||||
withPolicyLoader(t, func() *Policy {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return current
|
||||
})
|
||||
|
||||
type change struct{ prev, curr *Policy }
|
||||
changes := make(chan change, 1)
|
||||
tk := NewTicker(testReloadInterval)
|
||||
require.Equal(t, testReloadInterval, tk.interval)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
tk.Run(ctx, func(prev, curr *Policy) error {
|
||||
select {
|
||||
case changes <- change{prev, curr}:
|
||||
default:
|
||||
}
|
||||
return nil
|
||||
})
|
||||
close(done)
|
||||
}()
|
||||
// Stop Run and wait for it to exit before returning, so the policyLoader
|
||||
// restore in t.Cleanup can't race the ticker goroutine still reading it.
|
||||
defer func() { cancel(); <-done }()
|
||||
|
||||
// Flip the OS-observed policy from empty to one managed key. The next
|
||||
// tick must detect the diff and invoke onChange.
|
||||
mu.Lock()
|
||||
current = NewPolicy(map[string]any{KeyManagementURL: "https://mdm.example.com:443"})
|
||||
mu.Unlock()
|
||||
|
||||
select {
|
||||
case c := <-changes:
|
||||
assert.True(t, c.prev.IsEmpty(), "prev should be the initial empty policy")
|
||||
assert.True(t, c.curr.HasKey(KeyManagementURL), "curr should carry the newly-pushed managed key")
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("onChange not invoked within 5s; ticker should fire every 1s under test")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTicker_NoCallbackWhenPolicyUnchanged(t *testing.T) {
|
||||
withPolicyLoader(t, func() *Policy {
|
||||
return NewPolicy(map[string]any{KeyBlockInbound: true})
|
||||
})
|
||||
|
||||
fired := make(chan struct{}, 1)
|
||||
tk := NewTicker(testReloadInterval)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
tk.Run(ctx, func(_, _ *Policy) error {
|
||||
select {
|
||||
case fired <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
return nil
|
||||
})
|
||||
close(done)
|
||||
}()
|
||||
defer func() { cancel(); <-done }()
|
||||
|
||||
// Over ~2 ticks at the 1s test cadence the policy never changes, so the
|
||||
// diff guard must suppress the callback entirely.
|
||||
select {
|
||||
case <-fired:
|
||||
t.Fatal("onChange fired despite an unchanged policy")
|
||||
case <-time.After(2500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -119,14 +119,6 @@ service DaemonService {
|
||||
|
||||
// ExposeService exposes a local port via the NetBird reverse proxy
|
||||
rpc ExposeService(ExposeServiceRequest) returns (stream ExposeServiceEvent) {}
|
||||
|
||||
// RespondApproval delivers the user's accept/deny decision for a
|
||||
// pending user-approval prompt. The daemon pushes the prompt as a
|
||||
// SystemEvent with category APPROVAL and metadata key "request_id";
|
||||
// the UI calls this RPC with the same request_id to unblock whichever
|
||||
// subsystem (VNC, SSH, ...) is waiting. The "kind" metadata key tells
|
||||
// the UI which subsystem the prompt belongs to.
|
||||
rpc RespondApproval(RespondApprovalRequest) returns (RespondApprovalResponse) {}
|
||||
}
|
||||
|
||||
|
||||
@@ -213,10 +205,6 @@ message LoginRequest {
|
||||
optional bool disableSSHAuth = 38;
|
||||
optional int32 sshJWTCacheTTL = 39;
|
||||
optional bool disable_ipv6 = 40;
|
||||
|
||||
optional bool serverVNCAllowed = 41;
|
||||
|
||||
optional bool disableVNCApproval = 42;
|
||||
}
|
||||
|
||||
message LoginResponse {
|
||||
@@ -326,17 +314,6 @@ message GetConfigResponse {
|
||||
int32 sshJWTCacheTTL = 26;
|
||||
|
||||
bool disable_ipv6 = 27;
|
||||
|
||||
bool serverVNCAllowed = 28;
|
||||
|
||||
bool disableVNCApproval = 29;
|
||||
|
||||
// mDMManagedFields lists the names of configuration keys whose value is
|
||||
// currently enforced by an MDM policy. Names match mdm.Key* constants
|
||||
// (e.g. "managementURL", "disableClientRoutes"). UI/CLI clients should
|
||||
// render the corresponding inputs as read-only and display a "managed
|
||||
// by MDM" indicator.
|
||||
repeated string mDMManagedFields = 30;
|
||||
}
|
||||
|
||||
// PeerState contains the latest state of a peer
|
||||
@@ -418,25 +395,6 @@ message SSHServerState {
|
||||
repeated SSHSessionInfo sessions = 2;
|
||||
}
|
||||
|
||||
// VNCSessionInfo contains information about an active VNC session
|
||||
message VNCSessionInfo {
|
||||
string remoteAddress = 1;
|
||||
string mode = 2;
|
||||
string username = 3;
|
||||
// userID is the Noise-verified session identity (hashed user ID from
|
||||
// the ACL session-key entry), empty when auth is disabled.
|
||||
string userID = 4;
|
||||
// initiator is the human-readable display name of the dashboard user
|
||||
// who minted the SessionPubKey, when known.
|
||||
string initiator = 5;
|
||||
}
|
||||
|
||||
// VNCServerState contains the latest state of the VNC server
|
||||
message VNCServerState {
|
||||
bool enabled = 1;
|
||||
repeated VNCSessionInfo sessions = 2;
|
||||
}
|
||||
|
||||
// FullStatus contains the full state held by the Status instance
|
||||
message FullStatus {
|
||||
ManagementState managementState = 1;
|
||||
@@ -451,7 +409,6 @@ message FullStatus {
|
||||
|
||||
bool lazyConnectionEnabled = 9;
|
||||
SSHServerState sshServerState = 10;
|
||||
VNCServerState vncServerState = 11;
|
||||
}
|
||||
|
||||
// Networks
|
||||
@@ -640,7 +597,6 @@ message SystemEvent {
|
||||
AUTHENTICATION = 2;
|
||||
CONNECTIVITY = 3;
|
||||
SYSTEM = 4;
|
||||
APPROVAL = 5;
|
||||
}
|
||||
|
||||
string id = 1;
|
||||
@@ -724,10 +680,6 @@ message SetConfigRequest {
|
||||
optional bool disableSSHAuth = 33;
|
||||
optional int32 sshJWTCacheTTL = 34;
|
||||
optional bool disable_ipv6 = 35;
|
||||
|
||||
optional bool serverVNCAllowed = 36;
|
||||
|
||||
optional bool disableVNCApproval = 37;
|
||||
}
|
||||
|
||||
message SetConfigResponse{}
|
||||
@@ -781,15 +733,6 @@ message GetFeaturesResponse{
|
||||
bool disable_networks = 3;
|
||||
}
|
||||
|
||||
// MDMManagedFieldsViolation is attached as a gRPC error detail on a
|
||||
// FailedPrecondition status returned from SetConfig (and similar mutating
|
||||
// RPCs) when the caller tries to modify one or more MDM-enforced fields.
|
||||
// The fields list contains the offending key names; the entire request is
|
||||
// rejected (no partial apply).
|
||||
message MDMManagedFieldsViolation {
|
||||
repeated string fields = 1;
|
||||
}
|
||||
|
||||
message TriggerUpdateRequest {}
|
||||
|
||||
message TriggerUpdateResponse {
|
||||
@@ -931,18 +874,3 @@ message StartBundleCaptureRequest {
|
||||
message StartBundleCaptureResponse {}
|
||||
message StopBundleCaptureRequest {}
|
||||
message StopBundleCaptureResponse {}
|
||||
|
||||
message RespondApprovalRequest {
|
||||
// request_id matches the SystemEvent metadata key emitted by the daemon
|
||||
// when a subsystem awaits user approval for an inbound connection.
|
||||
string request_id = 1;
|
||||
// accept is true if the user approved the request, false if they
|
||||
// denied it. A missing or unknown request_id is treated as a no-op.
|
||||
bool accept = 2;
|
||||
// view_only signals that the user granted the connection but withheld
|
||||
// input control. Only meaningful when accept is true; ignored when
|
||||
// accept is false.
|
||||
bool view_only = 3;
|
||||
}
|
||||
|
||||
message RespondApprovalResponse {}
|
||||
|
||||
@@ -58,7 +58,6 @@ const (
|
||||
DaemonService_StopCPUProfile_FullMethodName = "/daemon.DaemonService/StopCPUProfile"
|
||||
DaemonService_GetInstallerResult_FullMethodName = "/daemon.DaemonService/GetInstallerResult"
|
||||
DaemonService_ExposeService_FullMethodName = "/daemon.DaemonService/ExposeService"
|
||||
DaemonService_RespondApproval_FullMethodName = "/daemon.DaemonService/RespondApproval"
|
||||
)
|
||||
|
||||
// DaemonServiceClient is the client API for DaemonService service.
|
||||
@@ -135,13 +134,6 @@ type DaemonServiceClient interface {
|
||||
GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error)
|
||||
// ExposeService exposes a local port via the NetBird reverse proxy
|
||||
ExposeService(ctx context.Context, in *ExposeServiceRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[ExposeServiceEvent], error)
|
||||
// RespondApproval delivers the user's accept/deny decision for a
|
||||
// pending user-approval prompt. The daemon pushes the prompt as a
|
||||
// SystemEvent with category APPROVAL and metadata key "request_id";
|
||||
// the UI calls this RPC with the same request_id to unblock whichever
|
||||
// subsystem (VNC, SSH, ...) is waiting. The "kind" metadata key tells
|
||||
// the UI which subsystem the prompt belongs to.
|
||||
RespondApproval(ctx context.Context, in *RespondApprovalRequest, opts ...grpc.CallOption) (*RespondApprovalResponse, error)
|
||||
}
|
||||
|
||||
type daemonServiceClient struct {
|
||||
@@ -569,16 +561,6 @@ func (c *daemonServiceClient) ExposeService(ctx context.Context, in *ExposeServi
|
||||
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
|
||||
type DaemonService_ExposeServiceClient = grpc.ServerStreamingClient[ExposeServiceEvent]
|
||||
|
||||
func (c *daemonServiceClient) RespondApproval(ctx context.Context, in *RespondApprovalRequest, opts ...grpc.CallOption) (*RespondApprovalResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(RespondApprovalResponse)
|
||||
err := c.cc.Invoke(ctx, DaemonService_RespondApproval_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// DaemonServiceServer is the server API for DaemonService service.
|
||||
// All implementations must embed UnimplementedDaemonServiceServer
|
||||
// for forward compatibility.
|
||||
@@ -653,13 +635,6 @@ type DaemonServiceServer interface {
|
||||
GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error)
|
||||
// ExposeService exposes a local port via the NetBird reverse proxy
|
||||
ExposeService(*ExposeServiceRequest, grpc.ServerStreamingServer[ExposeServiceEvent]) error
|
||||
// RespondApproval delivers the user's accept/deny decision for a
|
||||
// pending user-approval prompt. The daemon pushes the prompt as a
|
||||
// SystemEvent with category APPROVAL and metadata key "request_id";
|
||||
// the UI calls this RPC with the same request_id to unblock whichever
|
||||
// subsystem (VNC, SSH, ...) is waiting. The "kind" metadata key tells
|
||||
// the UI which subsystem the prompt belongs to.
|
||||
RespondApproval(context.Context, *RespondApprovalRequest) (*RespondApprovalResponse, error)
|
||||
mustEmbedUnimplementedDaemonServiceServer()
|
||||
}
|
||||
|
||||
@@ -787,9 +762,6 @@ func (UnimplementedDaemonServiceServer) GetInstallerResult(context.Context, *Ins
|
||||
func (UnimplementedDaemonServiceServer) ExposeService(*ExposeServiceRequest, grpc.ServerStreamingServer[ExposeServiceEvent]) error {
|
||||
return status.Error(codes.Unimplemented, "method ExposeService not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) RespondApproval(context.Context, *RespondApprovalRequest) (*RespondApprovalResponse, error) {
|
||||
return nil, status.Error(codes.Unimplemented, "method RespondApproval not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
|
||||
func (UnimplementedDaemonServiceServer) testEmbeddedByValue() {}
|
||||
|
||||
@@ -1492,24 +1464,6 @@ func _DaemonService_ExposeService_Handler(srv interface{}, stream grpc.ServerStr
|
||||
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name.
|
||||
type DaemonService_ExposeServiceServer = grpc.ServerStreamingServer[ExposeServiceEvent]
|
||||
|
||||
func _DaemonService_RespondApproval_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(RespondApprovalRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DaemonServiceServer).RespondApproval(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: DaemonService_RespondApproval_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).RespondApproval(ctx, req.(*RespondApprovalRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
// and not to be introspected or modified (even as a copy)
|
||||
@@ -1661,10 +1615,6 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
|
||||
MethodName: "GetInstallerResult",
|
||||
Handler: _DaemonService_GetInstallerResult_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "RespondApproval",
|
||||
Handler: _DaemonService_RespondApproval_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{
|
||||
{
|
||||
|
||||
@@ -111,7 +111,7 @@ func (s *Server) StartCapture(req *proto.StartCaptureRequest, stream proto.Daemo
|
||||
return status.Errorf(codes.Internal, "create capture session: %v", err)
|
||||
}
|
||||
|
||||
engine, err := s.claimCapture(sess, func() { pw.Close() })
|
||||
engine, err := s.claimCapture(sess)
|
||||
if err != nil {
|
||||
sess.Stop()
|
||||
pw.Close()
|
||||
@@ -190,7 +190,10 @@ func (s *Server) StartBundleCapture(_ context.Context, req *proto.StartBundleCap
|
||||
|
||||
s.stopBundleCaptureLocked()
|
||||
s.cleanupBundleCapture()
|
||||
s.evictActiveCaptureLocked()
|
||||
|
||||
if s.activeCapture != nil {
|
||||
return nil, status.Error(codes.FailedPrecondition, "another capture is already running")
|
||||
}
|
||||
|
||||
engine, err := s.getCaptureEngineLocked()
|
||||
if err != nil {
|
||||
@@ -301,58 +304,29 @@ func (s *Server) cleanupBundleCapture() {
|
||||
s.bundleCapture = nil
|
||||
}
|
||||
|
||||
// claimCapture reserves the engine's capture slot for sess. If another
|
||||
// capture is already running it is evicted: a previous streaming session
|
||||
// whose gRPC client died and never freed the slot stays stuck otherwise,
|
||||
// and a bundle capture is just informational state.
|
||||
func (s *Server) claimCapture(sess *capture.Session, cancel func()) (*internal.Engine, error) {
|
||||
// claimCapture reserves the engine's capture slot for sess. Returns
|
||||
// FailedPrecondition if another capture is already active.
|
||||
func (s *Server) claimCapture(sess *capture.Session) (*internal.Engine, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
s.evictActiveCaptureLocked()
|
||||
if s.activeCapture != nil {
|
||||
return nil, status.Error(codes.FailedPrecondition, "another capture is already running")
|
||||
}
|
||||
engine, err := s.getCaptureEngineLocked()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.activeCapture = sess
|
||||
s.activeCaptureCancel = cancel
|
||||
return engine, nil
|
||||
}
|
||||
|
||||
// evictActiveCaptureLocked tears down whatever capture currently owns
|
||||
// the engine slot so a fresh claim can succeed. Caller must hold mutex.
|
||||
func (s *Server) evictActiveCaptureLocked() {
|
||||
if s.activeCapture == nil {
|
||||
return
|
||||
}
|
||||
if s.bundleCapture != nil && s.bundleCapture.sess == s.activeCapture {
|
||||
log.Infof("evicting running bundle capture to start a new capture")
|
||||
s.stopBundleCaptureLocked()
|
||||
return
|
||||
}
|
||||
log.Infof("evicting previous streaming capture to start a new one")
|
||||
prev := s.activeCapture
|
||||
cancel := s.activeCaptureCancel
|
||||
if engine, err := s.getCaptureEngineLocked(); err == nil {
|
||||
if err := engine.SetCapture(nil); err != nil {
|
||||
log.Debugf("clear previous capture: %v", err)
|
||||
}
|
||||
}
|
||||
s.activeCapture = nil
|
||||
s.activeCaptureCancel = nil
|
||||
prev.Stop()
|
||||
if cancel != nil {
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
|
||||
// releaseCapture clears the active-capture owner if it still matches sess.
|
||||
func (s *Server) releaseCapture(sess *capture.Session) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
if s.activeCapture == sess {
|
||||
s.activeCapture = nil
|
||||
s.activeCaptureCancel = nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -367,7 +341,6 @@ func (s *Server) clearCaptureIfOwner(sess *capture.Session, engine *internal.Eng
|
||||
log.Debugf("clear capture: %v", err)
|
||||
}
|
||||
s.activeCapture = nil
|
||||
s.activeCaptureCancel = nil
|
||||
}
|
||||
|
||||
func (s *Server) getCaptureEngineLocked() (*internal.Engine, error) {
|
||||
|
||||
@@ -1,419 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/mdm"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
// preSharedKeyRedactedSentinel is the value GetConfig returns in place
|
||||
// of an actual PSK, so a UI that round-trips the field back to the
|
||||
// daemon (via SetConfig / Login) can be distinguished from a deliberate
|
||||
// override. Any incoming PSK that equals this sentinel is treated as
|
||||
// a no-op echo, never as a conflict with the policy.
|
||||
const preSharedKeyRedactedSentinel = "**********"
|
||||
|
||||
// loadMDMPolicy is the indirection used by server handlers to read the
|
||||
// active MDM policy. Tests override this to inject a fake policy.
|
||||
var loadMDMPolicy = mdm.LoadPolicy
|
||||
|
||||
// conflictCheck is a value-aware comparison between a single field in
|
||||
// the incoming request and the corresponding MDM-enforced value. It
|
||||
// runs only when the field was actually set in the request (presence
|
||||
// already filtered upstream); ok=true reports the policy value, ok=false
|
||||
// means the policy is silent on the key — both are treated as conflicts
|
||||
// to be safe (an MDM key declared as managed must hold a value).
|
||||
type conflictCheck struct {
|
||||
key string
|
||||
check func(*mdm.Policy) (match bool)
|
||||
}
|
||||
|
||||
// onMDMPolicyChange is invoked by the MDM reload ticker every time the
|
||||
// OS-native managed-config store reports a diff vs the last observation.
|
||||
//
|
||||
// Restart sequence:
|
||||
// 1. Cancel the active engine context (terminates connectWithRetryRuns).
|
||||
// 2. Wait briefly for that goroutine to exit (giveUpChan is closed on exit).
|
||||
// 3. Re-resolve Config from disk + MDM policy (Config.apply re-runs
|
||||
// applyMDMPolicy with the freshly loaded Policy).
|
||||
// 4. Spawn a fresh connectWithRetryRuns with the new context and config.
|
||||
// 5. Broadcast a SystemEvent so any GUI / CLI subscriber (SubscribeEvents
|
||||
// RPC) can refresh its cached config view without polling.
|
||||
//
|
||||
// The callback runs in the ticker's own goroutine. Ticker has already
|
||||
// logged the per-key diff before invoking this hook.
|
||||
func (s *Server) onMDMPolicyChange(_, _ *mdm.Policy) error {
|
||||
log.Warn("MDM policy changed; restarting engine to apply new configuration")
|
||||
|
||||
// Hold s.mutex for the entire restart sequence (cancel + quiescence
|
||||
// wait + re-spawn). Any concurrent Up/Down/Status arriving while
|
||||
// MDM is restarting blocks on the Lock until we are done — they
|
||||
// then observe the post-restart state coherently. This is safe
|
||||
// because the connectWithRetryRuns goroutine no longer acquires
|
||||
// s.mutex in its defer (intent vs. goroutine-alive concerns are
|
||||
// fully separated; see the connectionGoroutineRunning helper).
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if !s.clientRunning {
|
||||
// The client is not running, so there's no engine to restart.
|
||||
return nil
|
||||
}
|
||||
if s.actCancel != nil {
|
||||
s.actCancel()
|
||||
}
|
||||
|
||||
// Wait for previous connectWithRetryRuns to exit so we don't end up
|
||||
// with two goroutines fighting over the same status recorder + engine.
|
||||
// The teardown engages a fan-out of engine goroutines (peer workers,
|
||||
// signal handler, route manager, ...). close(clientGiveUpChan)
|
||||
// happens in the function-scope defer of connectWithRetryRuns, on
|
||||
// every exit path (ctx cancel, backoff exhausted, panic) — see the
|
||||
// defer in server.go.
|
||||
if s.clientGiveUpChan != nil {
|
||||
select {
|
||||
case <-s.clientGiveUpChan:
|
||||
case <-time.After(10 * time.Second):
|
||||
return fmt.Errorf("failed to restart the engine due to timeout")
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.restartEngineForMDMLocked(); err != nil {
|
||||
log.Errorf("MDM restart failed: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// publishConfigChangedEvent has already fired inside
|
||||
// restartEngineForMDMLocked with source="mdm". Emit an MDM-specific
|
||||
// user-visible toast so the operator knows their IT policy was
|
||||
// applied (UserMessage != "" triggers the GUI notifier).
|
||||
s.statusRecorder.PublishEvent(
|
||||
proto.SystemEvent_INFO,
|
||||
proto.SystemEvent_SYSTEM,
|
||||
"MDM policy applied",
|
||||
"NetBird configuration was updated by your IT policy.",
|
||||
map[string]string{"source": "mdm", "type": "policy_applied"},
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
// publishConfigChangedEvent broadcasts a SystemEvent informing any active
|
||||
// SubscribeEvents subscriber (typically the GUI tray) that the daemon's
|
||||
// effective Config has been replaced and any cached client-side view
|
||||
// should be refreshed. Callers pass a stable `source` label so the GUI
|
||||
// can distinguish a startup spawn from a user-triggered Up or an
|
||||
// MDM-driven restart. Reusing the SYSTEM category keeps the proto enum
|
||||
// stable; metadata.type="config_changed" routes to the GUI's refresh
|
||||
// handler. UserMessage is left empty so the system tray does not toast
|
||||
// for every internal restart; the MDM path emits a separate
|
||||
// "policy_applied" event (with UserMessage) for that purpose.
|
||||
func (s *Server) publishConfigChangedEvent(source string) {
|
||||
if s.statusRecorder == nil {
|
||||
return
|
||||
}
|
||||
s.statusRecorder.PublishEvent(
|
||||
proto.SystemEvent_INFO,
|
||||
proto.SystemEvent_SYSTEM,
|
||||
fmt.Sprintf("daemon config changed (source=%s)", source),
|
||||
"",
|
||||
map[string]string{
|
||||
"source": source,
|
||||
"type": "config_changed",
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// restartEngineForMDMLocked re-resolves the active profile config
|
||||
// (re-running applyMDMPolicy via Config.apply) and re-spawns
|
||||
// connectWithRetryRuns. Mirrors the tail of Server.Start so a runtime
|
||||
// MDM change behaves identically to a fresh boot under the new policy.
|
||||
//
|
||||
// MUST be called with s.mutex held — onMDMPolicyChange holds the lock
|
||||
// for the entire restart sequence (cancel + quiescence wait + re-spawn)
|
||||
// so concurrent Up/Down/Status RPCs observe a coherent post-restart
|
||||
// state.
|
||||
func (s *Server) restartEngineForMDMLocked() error {
|
||||
activeProf, err := s.profileManager.GetActiveProfileState()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get active profile state: %w", err)
|
||||
}
|
||||
config, _, err := s.getConfig(activeProf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get active profile config: %w", err)
|
||||
}
|
||||
|
||||
s.config = config
|
||||
s.statusRecorder.UpdateManagementAddress(config.ManagementURL.String())
|
||||
s.statusRecorder.UpdateRosenpass(config.RosenpassEnabled, config.RosenpassPermissive)
|
||||
s.statusRecorder.UpdateLazyConnection(config.LazyConnectionEnabled)
|
||||
|
||||
ctx, cancel := context.WithCancel(s.rootCtx)
|
||||
s.actCancel = cancel
|
||||
s.clientRunning = true
|
||||
s.clientRunningChan = make(chan struct{})
|
||||
s.clientGiveUpChan = make(chan struct{})
|
||||
log.Info("MDM restart: spawning connectWithRetryRuns with re-resolved config")
|
||||
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
|
||||
s.publishConfigChangedEvent("mdm")
|
||||
return nil
|
||||
}
|
||||
|
||||
// conflictBool builds a conflictCheck for a boolean MDM key. If p is nil
|
||||
// the field is treated as matching (no override requested); otherwise the
|
||||
// check returns true only when the policy contains the key and its
|
||||
// boolean value equals *p.
|
||||
func conflictBool(key string, p *bool) conflictCheck {
|
||||
return conflictCheck{
|
||||
key: key,
|
||||
check: func(pol *mdm.Policy) bool {
|
||||
if p == nil {
|
||||
return true // absent → match by definition
|
||||
}
|
||||
want, ok := pol.GetBool(key)
|
||||
return ok && want == *p
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// conflictString builds a conflictCheck for a string MDM key. An empty
|
||||
// `got` is treated as "field not set" (no override requested); otherwise
|
||||
// the check returns true only when the policy contains the key and its
|
||||
// value equals got.
|
||||
func conflictString(key, got string) conflictCheck {
|
||||
return conflictCheck{
|
||||
key: key,
|
||||
check: func(pol *mdm.Policy) bool {
|
||||
if got == "" {
|
||||
return true
|
||||
}
|
||||
want, ok := pol.GetString(key)
|
||||
return ok && want == got
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// conflictInt64 builds a conflictCheck for an integer MDM key. If p is
|
||||
// nil the field is treated as matching; otherwise the check returns
|
||||
// true only when the policy contains the key and its int value equals *p.
|
||||
func conflictInt64(key string, p *int64) conflictCheck {
|
||||
return conflictCheck{
|
||||
key: key,
|
||||
check: func(pol *mdm.Policy) bool {
|
||||
if p == nil {
|
||||
return true
|
||||
}
|
||||
want, ok := pol.GetInt(key)
|
||||
return ok && want == *p
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// resolveConflicts walks the per-field checks against the active MDM
|
||||
// policy and returns the names of keys whose requested value diverges
|
||||
// from the policy-enforced value. Keys not present in the policy are
|
||||
// skipped silently (the gate fires only for keys the admin has
|
||||
// actually pushed). Returns nil for an empty policy.
|
||||
func resolveConflicts(policy *mdm.Policy, checks []conflictCheck) []string {
|
||||
if policy.IsEmpty() {
|
||||
return nil
|
||||
}
|
||||
var conflicts []string
|
||||
for _, c := range checks {
|
||||
if !policy.HasKey(c.key) {
|
||||
continue
|
||||
}
|
||||
if !c.check(policy) {
|
||||
conflicts = append(conflicts, c.key)
|
||||
}
|
||||
}
|
||||
return conflicts
|
||||
}
|
||||
|
||||
// mdmManagedFieldConflicts returns the names of MDM-managed keys whose
|
||||
// requested value in the SetConfigRequest differs from the MDM-enforced
|
||||
// value. A field set to the same value the policy already enforces is
|
||||
// treated as a no-op echo (the GUI tray sends a full Config snapshot on
|
||||
// every toggle, so most fields in a typical request match the policy
|
||||
// exactly and must NOT be flagged as conflicts). The redacted PSK
|
||||
// sentinel ("**********") returned by GetConfig is recognised and
|
||||
// treated as no-op so the UI can safely round-trip it.
|
||||
func mdmManagedFieldConflicts(msg *proto.SetConfigRequest, policy *mdm.Policy) []string {
|
||||
if msg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// PSK round-trip echo: collapse the sentinel to empty so the
|
||||
// shared check treats it as "field not set".
|
||||
pskGot := ""
|
||||
if msg.OptionalPreSharedKey != nil && *msg.OptionalPreSharedKey != preSharedKeyRedactedSentinel {
|
||||
pskGot = *msg.OptionalPreSharedKey
|
||||
}
|
||||
|
||||
return resolveConflicts(policy, []conflictCheck{
|
||||
conflictString(mdm.KeyManagementURL, msg.ManagementUrl),
|
||||
conflictString(mdm.KeyPreSharedKey, pskGot),
|
||||
conflictBool(mdm.KeyRosenpassEnabled, msg.RosenpassEnabled),
|
||||
conflictBool(mdm.KeyRosenpassPermissive, msg.RosenpassPermissive),
|
||||
conflictBool(mdm.KeyDisableAutoConnect, msg.DisableAutoConnect),
|
||||
conflictBool(mdm.KeyAllowServerSSH, msg.ServerSSHAllowed),
|
||||
conflictBool(mdm.KeyDisableClientRoutes, msg.DisableClientRoutes),
|
||||
conflictBool(mdm.KeyDisableServerRoutes, msg.DisableServerRoutes),
|
||||
conflictBool(mdm.KeyBlockInbound, msg.BlockInbound),
|
||||
conflictInt64(mdm.KeyWireguardPort, msg.WireguardPort),
|
||||
})
|
||||
}
|
||||
|
||||
// setConfigRequestHasConfigOverrides reports whether the SetConfigRequest
|
||||
// carries ANY field that would actually mutate the persisted config.
|
||||
// The CLI builds a SetConfigRequest unconditionally on every
|
||||
// `netbird up` (see setupSetConfigReq in cmd/up.go) — a plain
|
||||
// `netbird up` produces a request with every field at its zero value;
|
||||
// the gate must skip such no-op invocations or it would always fire
|
||||
// even when the user did not pass any --flag. Returns false on a nil
|
||||
// msg; true when any management/admin URL, PSK, DNS/NAT list+clean
|
||||
// flag, interface/port/MTU, or any optional bool/duration field is set.
|
||||
func setConfigRequestHasConfigOverrides(msg *proto.SetConfigRequest) bool {
|
||||
if msg == nil {
|
||||
return false
|
||||
}
|
||||
return msg.ManagementUrl != "" ||
|
||||
msg.AdminURL != "" ||
|
||||
msg.OptionalPreSharedKey != nil ||
|
||||
len(msg.CustomDNSAddress) > 0 ||
|
||||
len(msg.NatExternalIPs) > 0 || msg.CleanNATExternalIPs ||
|
||||
len(msg.ExtraIFaceBlacklist) > 0 ||
|
||||
len(msg.DnsLabels) > 0 || msg.CleanDNSLabels ||
|
||||
msg.DnsRouteInterval != nil ||
|
||||
msg.RosenpassEnabled != nil ||
|
||||
msg.RosenpassPermissive != nil ||
|
||||
msg.InterfaceName != nil ||
|
||||
msg.WireguardPort != nil ||
|
||||
msg.Mtu != nil ||
|
||||
msg.DisableAutoConnect != nil ||
|
||||
msg.ServerSSHAllowed != nil ||
|
||||
msg.NetworkMonitor != nil ||
|
||||
msg.DisableClientRoutes != nil ||
|
||||
msg.DisableServerRoutes != nil ||
|
||||
msg.DisableDns != nil ||
|
||||
msg.DisableFirewall != nil ||
|
||||
msg.BlockLanAccess != nil ||
|
||||
msg.DisableNotifications != nil ||
|
||||
msg.LazyConnectionEnabled != nil ||
|
||||
msg.BlockInbound != nil ||
|
||||
msg.DisableIpv6 != nil ||
|
||||
msg.EnableSSHRoot != nil ||
|
||||
msg.EnableSSHSFTP != nil ||
|
||||
msg.EnableSSHLocalPortForwarding != nil ||
|
||||
msg.EnableSSHRemotePortForwarding != nil ||
|
||||
msg.DisableSSHAuth != nil ||
|
||||
msg.SshJWTCacheTTL != nil
|
||||
}
|
||||
|
||||
// loginRequestHasConfigOverrides reports whether the LoginRequest
|
||||
// carries ANY field that would mutate persisted daemon configuration
|
||||
// (as opposed to pure-auth fields like setupKey, hostname, hint,
|
||||
// profileName, username). Used by the Login handler to decide whether
|
||||
// the `--disable-update-settings` / MDM gates must run: a re-auth that
|
||||
// changes nothing about the configuration is always allowed.
|
||||
func loginRequestHasConfigOverrides(msg *proto.LoginRequest) bool {
|
||||
if msg == nil {
|
||||
return false
|
||||
}
|
||||
return msg.ManagementUrl != "" ||
|
||||
msg.AdminURL != "" ||
|
||||
msg.PreSharedKey != "" || //nolint:staticcheck // SA1019: legacy proto field still accepted by Login
|
||||
msg.OptionalPreSharedKey != nil ||
|
||||
len(msg.CustomDNSAddress) > 0 ||
|
||||
len(msg.NatExternalIPs) > 0 || msg.CleanNATExternalIPs ||
|
||||
msg.RosenpassEnabled != nil ||
|
||||
msg.InterfaceName != nil ||
|
||||
msg.WireguardPort != nil ||
|
||||
msg.DisableAutoConnect != nil ||
|
||||
msg.ServerSSHAllowed != nil ||
|
||||
msg.RosenpassPermissive != nil ||
|
||||
len(msg.ExtraIFaceBlacklist) > 0 ||
|
||||
msg.NetworkMonitor != nil ||
|
||||
msg.DnsRouteInterval != nil ||
|
||||
msg.DisableClientRoutes != nil ||
|
||||
msg.DisableServerRoutes != nil ||
|
||||
msg.DisableDns != nil ||
|
||||
msg.DisableFirewall != nil ||
|
||||
msg.BlockLanAccess != nil ||
|
||||
msg.DisableNotifications != nil ||
|
||||
len(msg.DnsLabels) > 0 || msg.CleanDNSLabels ||
|
||||
msg.LazyConnectionEnabled != nil ||
|
||||
msg.BlockInbound != nil
|
||||
}
|
||||
|
||||
// loginRequestMDMConflicts mirrors mdmManagedFieldConflicts but for the
|
||||
// LoginRequest surface. Same value-aware semantics: a field set to the
|
||||
// MDM-enforced value is a no-op echo, not a conflict; only a divergent
|
||||
// value is flagged. PSK has two proto fields — PreSharedKey (deprecated)
|
||||
// and OptionalPreSharedKey (current); either route trips the gate if it
|
||||
// diverges from the MDM-enforced PSK. OptionalPreSharedKey wins when
|
||||
// both are set; the redaction sentinel ("**********") is accepted as
|
||||
// a no-op echo.
|
||||
func loginRequestMDMConflicts(msg *proto.LoginRequest, policy *mdm.Policy) []string {
|
||||
if msg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Collapse the two PSK fields + the redaction sentinel down to a
|
||||
// single "got" string the shared check can compare against the
|
||||
// policy: OptionalPreSharedKey wins if set; PreSharedKey (deprecated)
|
||||
// is the fallback; sentinel echo is treated as "field not set".
|
||||
pskGot := ""
|
||||
if msg.OptionalPreSharedKey != nil {
|
||||
pskGot = *msg.OptionalPreSharedKey
|
||||
} else if msg.PreSharedKey != "" { //nolint:staticcheck // SA1019: legacy proto field still accepted by Login
|
||||
pskGot = msg.PreSharedKey //nolint:staticcheck // SA1019
|
||||
}
|
||||
if pskGot == preSharedKeyRedactedSentinel {
|
||||
pskGot = ""
|
||||
}
|
||||
|
||||
return resolveConflicts(policy, []conflictCheck{
|
||||
conflictString(mdm.KeyManagementURL, msg.ManagementUrl),
|
||||
conflictString(mdm.KeyPreSharedKey, pskGot),
|
||||
conflictBool(mdm.KeyRosenpassEnabled, msg.RosenpassEnabled),
|
||||
conflictBool(mdm.KeyRosenpassPermissive, msg.RosenpassPermissive),
|
||||
conflictBool(mdm.KeyDisableAutoConnect, msg.DisableAutoConnect),
|
||||
conflictBool(mdm.KeyAllowServerSSH, msg.ServerSSHAllowed),
|
||||
conflictBool(mdm.KeyDisableClientRoutes, msg.DisableClientRoutes),
|
||||
conflictBool(mdm.KeyDisableServerRoutes, msg.DisableServerRoutes),
|
||||
conflictBool(mdm.KeyBlockInbound, msg.BlockInbound),
|
||||
conflictInt64(mdm.KeyWireguardPort, msg.WireguardPort),
|
||||
})
|
||||
}
|
||||
|
||||
// rejectMDMManagedFieldConflicts returns a FailedPrecondition gRPC error
|
||||
// with an MDMManagedFieldsViolation detail when any of the requested
|
||||
// fields tries to change an MDM-enforced value to something else, and
|
||||
// nil otherwise. The whole request is rejected on any conflict; non-
|
||||
// conflicting fields in the same request are not applied either (no
|
||||
// partial apply).
|
||||
func rejectMDMManagedFieldConflicts(conflicts []string) error {
|
||||
if len(conflicts) == 0 {
|
||||
return nil
|
||||
}
|
||||
log.Warnf("MDM rejected request: tried to modify %d managed key(s): %v",
|
||||
len(conflicts), conflicts)
|
||||
st := gstatus.New(
|
||||
codes.FailedPrecondition,
|
||||
fmt.Sprintf("fields managed by MDM cannot be modified: %v", conflicts),
|
||||
)
|
||||
detailed, err := st.WithDetails(&proto.MDMManagedFieldsViolation{Fields: conflicts})
|
||||
if err != nil {
|
||||
// Detail attachment is best-effort; fall back to the plain status
|
||||
// so the caller still gets a usable FailedPrecondition.
|
||||
return st.Err()
|
||||
}
|
||||
return detailed.Err()
|
||||
}
|
||||
@@ -30,7 +30,7 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.checkNetworksDisabled() {
|
||||
if s.networksDisabled {
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
|
||||
}
|
||||
|
||||
@@ -143,7 +143,7 @@ func (s *Server) SelectNetworks(_ context.Context, req *proto.SelectNetworksRequ
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.checkNetworksDisabled() {
|
||||
if s.networksDisabled {
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
|
||||
}
|
||||
|
||||
@@ -195,7 +195,7 @@ func (s *Server) DeselectNetworks(_ context.Context, req *proto.SelectNetworksRe
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.checkNetworksDisabled() {
|
||||
if s.networksDisabled {
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
|
||||
}
|
||||
|
||||
|
||||
@@ -24,7 +24,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/expose"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
sleephandler "github.com/netbirdio/netbird/client/internal/sleep/handler"
|
||||
"github.com/netbirdio/netbird/client/mdm"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
@@ -72,13 +71,7 @@ type Server struct {
|
||||
mutex sync.Mutex
|
||||
config *profilemanager.Config
|
||||
proto.UnimplementedDaemonServiceServer
|
||||
// clientRunning tracks "the daemon wants to be connected" — set true by
|
||||
// Start / Up, cleared by Down / Logout. Persists across retry
|
||||
// loops, signal disconnects, and ErrResetConnection cycles. NOT
|
||||
// changed by connectWithRetryRuns goroutine exit — for that
|
||||
// (goroutine-still-alive) check, see connectionGoroutineRunning() which
|
||||
// derives from clientGiveUpChan close state. Protected by s.mutex.
|
||||
clientRunning bool
|
||||
clientRunning bool // protected by mutex
|
||||
clientRunningChan chan struct{}
|
||||
clientGiveUpChan chan struct{} // closed when connectWithRetryRuns goroutine exits
|
||||
|
||||
@@ -100,20 +93,11 @@ type Server struct {
|
||||
captureEnabled bool
|
||||
bundleCapture *bundleCapture
|
||||
// activeCapture is the session currently installed on the engine; guarded by s.mutex.
|
||||
activeCapture *capture.Session
|
||||
// activeCaptureCancel tears down the streaming pipe/cancel for the
|
||||
// active streaming capture so eviction unblocks the StartCapture RPC
|
||||
// handler. Nil for bundle captures (they own their own context).
|
||||
activeCaptureCancel func()
|
||||
networksDisabled bool
|
||||
activeCapture *capture.Session
|
||||
networksDisabled bool
|
||||
|
||||
sleepHandler *sleephandler.SleepHandler
|
||||
|
||||
// mdmTicker periodically re-reads the OS-native MDM policy and triggers
|
||||
// an engine restart when the policy changes. Launched once by Start;
|
||||
// stopped by the rootCtx cancellation.
|
||||
mdmTicker *mdm.Ticker
|
||||
|
||||
updateManager *updater.Manager
|
||||
|
||||
jwtCache *jwtCache
|
||||
@@ -171,17 +155,6 @@ func (s *Server) Start() error {
|
||||
s.updateManager.CheckUpdateSuccess(s.rootCtx)
|
||||
}
|
||||
|
||||
// MDM policy reload ticker: every minute the desktop daemon re-reads
|
||||
// the OS-native managed-config store and, on diff vs the previous
|
||||
// observation, cancels the active engine context so connectWithRetry-
|
||||
// Runs re-resolves Config (re-running profilemanager.Config.apply which
|
||||
// applies the freshly-read MDM policy as the last layer) and brings
|
||||
// the engine back with the new values.
|
||||
if s.mdmTicker == nil {
|
||||
s.mdmTicker = mdm.NewTicker(mdm.DefaultReloadInterval)
|
||||
go s.mdmTicker.Run(s.rootCtx, s.onMDMPolicyChange)
|
||||
}
|
||||
|
||||
// if current state contains any error, return it
|
||||
// in all other cases we can continue execution only if status is idle and up command was
|
||||
// not in the progress or already successfully established connection.
|
||||
@@ -240,27 +213,17 @@ func (s *Server) Start() error {
|
||||
s.clientRunningChan = make(chan struct{})
|
||||
s.clientGiveUpChan = make(chan struct{})
|
||||
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
|
||||
s.publishConfigChangedEvent("startup")
|
||||
return nil
|
||||
}
|
||||
|
||||
// connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional
|
||||
// mechanism to keep the client connected even when the connection is lost.
|
||||
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
|
||||
//
|
||||
// The goroutine's exit is signalled to the daemon via close(giveUpChan)
|
||||
// — placed in the function-scope defer so every return path (panic,
|
||||
// DisableAutoConnect early-exit, backoff exhausted, ctx cancel) closes
|
||||
// it. Callers that need to observe "is the goroutine still alive?" use
|
||||
// Server.connectionGoroutineRunning() which non-blockingly checks the close state
|
||||
// of clientGiveUpChan. The defer does NOT touch s.mutex; the daemon's
|
||||
// "intent" (clientRunning) is maintained by the RPC handlers, not by this
|
||||
// goroutine.
|
||||
func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}, giveUpChan chan struct{}) {
|
||||
defer func() {
|
||||
if giveUpChan != nil {
|
||||
close(giveUpChan)
|
||||
}
|
||||
s.mutex.Lock()
|
||||
s.clientRunning = false
|
||||
s.mutex.Unlock()
|
||||
}()
|
||||
|
||||
if s.config.DisableAutoConnect {
|
||||
@@ -306,26 +269,9 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil
|
||||
if err := backoff.Retry(runOperation, backOff); err != nil {
|
||||
log.Errorf("operation failed: %v", err)
|
||||
}
|
||||
// giveUpChan is closed by the function-scope defer.
|
||||
}
|
||||
|
||||
// connectionGoroutineRunning reports whether the connectWithRetryRuns goroutine is
|
||||
// still running. Returns false when no goroutine has ever been started
|
||||
// AND when the most recent one has already closed clientGiveUpChan on
|
||||
// exit (whether due to ctx cancel, DisableAutoConnect single-shot
|
||||
// completion, or backoff retry exhaustion).
|
||||
//
|
||||
// MUST be called with s.mutex held — accesses s.clientGiveUpChan which
|
||||
// is written by Start/Up under the same lock.
|
||||
func (s *Server) connectionGoroutineRunning() bool {
|
||||
if s.clientGiveUpChan == nil {
|
||||
return false
|
||||
}
|
||||
select {
|
||||
case <-s.clientGiveUpChan:
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
if giveUpChan != nil {
|
||||
close(giveUpChan)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -358,85 +304,54 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
// Skip the update-settings gate when the request carries no actual
|
||||
// overrides: the CLI builds a SetConfigRequest unconditionally on
|
||||
// every `netbird up` (setupSetConfigReq in cmd/up.go), so a plain
|
||||
// `netbird up` would otherwise always trip the gate and surface a
|
||||
// misleading "setConfig method is not available" warning, even when
|
||||
// the user did not pass any config flag.
|
||||
if setConfigRequestHasConfigOverrides(msg) {
|
||||
if s.checkUpdateSettingsDisabled() {
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errUpdateSettingsDisabled)
|
||||
}
|
||||
if s.checkUpdateSettingsDisabled() {
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errUpdateSettingsDisabled)
|
||||
}
|
||||
|
||||
// MDM gate: refuse the whole request if any of its fields is enforced
|
||||
// by the active MDM policy. The error carries an MDMManagedFields-
|
||||
// Violation detail listing the offending key names. Non-conflicting
|
||||
// fields in the same request are not applied either.
|
||||
policy := loadMDMPolicy()
|
||||
if err := rejectMDMManagedFieldConflicts(mdmManagedFieldConflicts(msg, policy)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
config, err := setConfigInputFromRequest(msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err := profilemanager.UpdateConfig(config); err != nil {
|
||||
log.Errorf("failed to update profile config: %v", err)
|
||||
return nil, fmt.Errorf("failed to update profile config: %w", err)
|
||||
}
|
||||
|
||||
return &proto.SetConfigResponse{}, nil
|
||||
}
|
||||
|
||||
// setConfigInputFromRequest translates a SetConfigRequest into the
|
||||
// profilemanager.ConfigInput that profilemanager.UpdateConfig consumes.
|
||||
// Pure mapping with no business logic beyond presence-aware copying of
|
||||
// optional fields and the "empty / clean" semantics for the two slice
|
||||
// fields (DNS labels, NAT external IPs). Extracted from SetConfig to
|
||||
// keep the handler's cognitive complexity below the SonarCube
|
||||
// threshold; the body is intentionally linear because each proto
|
||||
// field is its own optional case. Returns the resolved ConfigInput
|
||||
// and a non-nil error only when the active profile file path cannot
|
||||
// be determined.
|
||||
func setConfigInputFromRequest(msg *proto.SetConfigRequest) (profilemanager.ConfigInput, error) {
|
||||
var config profilemanager.ConfigInput
|
||||
|
||||
profState := profilemanager.ActiveProfileState{
|
||||
Name: msg.ProfileName,
|
||||
Username: msg.Username,
|
||||
}
|
||||
|
||||
profPath, err := profState.FilePath()
|
||||
if err != nil {
|
||||
log.Errorf("failed to get active profile file path: %v", err)
|
||||
return config, fmt.Errorf("failed to get active profile file path: %w", err)
|
||||
return nil, fmt.Errorf("failed to get active profile file path: %w", err)
|
||||
}
|
||||
|
||||
var config profilemanager.ConfigInput
|
||||
|
||||
config.ConfigPath = profPath
|
||||
|
||||
if msg.ManagementUrl != "" {
|
||||
config.ManagementURL = msg.ManagementUrl
|
||||
}
|
||||
|
||||
if msg.AdminURL != "" {
|
||||
config.AdminURL = msg.AdminURL
|
||||
}
|
||||
|
||||
if msg.InterfaceName != nil {
|
||||
config.InterfaceName = msg.InterfaceName
|
||||
}
|
||||
|
||||
if msg.WireguardPort != nil {
|
||||
wgPort := int(*msg.WireguardPort)
|
||||
config.WireguardPort = &wgPort
|
||||
}
|
||||
if msg.OptionalPreSharedKey != nil && *msg.OptionalPreSharedKey != "" {
|
||||
config.PreSharedKey = msg.OptionalPreSharedKey
|
||||
|
||||
if msg.OptionalPreSharedKey != nil {
|
||||
if *msg.OptionalPreSharedKey != "" {
|
||||
config.PreSharedKey = msg.OptionalPreSharedKey
|
||||
}
|
||||
}
|
||||
|
||||
if msg.CleanDNSLabels {
|
||||
config.DNSLabels = domain.List{}
|
||||
|
||||
} else if msg.DnsLabels != nil {
|
||||
config.DNSLabels = domain.FromPunycodeList(msg.DnsLabels)
|
||||
dnsLabels := domain.FromPunycodeList(msg.DnsLabels)
|
||||
config.DNSLabels = dnsLabels
|
||||
}
|
||||
|
||||
if msg.CleanNATExternalIPs {
|
||||
@@ -449,6 +364,7 @@ func setConfigInputFromRequest(msg *proto.SetConfigRequest) (profilemanager.Conf
|
||||
if string(msg.CustomDNSAddress) == "empty" {
|
||||
config.CustomDNSAddress = []byte{}
|
||||
}
|
||||
|
||||
config.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist
|
||||
|
||||
if msg.DnsRouteInterval != nil {
|
||||
@@ -460,8 +376,6 @@ func setConfigInputFromRequest(msg *proto.SetConfigRequest) (profilemanager.Conf
|
||||
config.RosenpassPermissive = msg.RosenpassPermissive
|
||||
config.DisableAutoConnect = msg.DisableAutoConnect
|
||||
config.ServerSSHAllowed = msg.ServerSSHAllowed
|
||||
config.ServerVNCAllowed = msg.ServerVNCAllowed
|
||||
config.DisableVNCApproval = msg.DisableVNCApproval
|
||||
config.NetworkMonitor = msg.NetworkMonitor
|
||||
config.DisableClientRoutes = msg.DisableClientRoutes
|
||||
config.DisableServerRoutes = msg.DisableServerRoutes
|
||||
@@ -483,31 +397,22 @@ func setConfigInputFromRequest(msg *proto.SetConfigRequest) (profilemanager.Conf
|
||||
ttl := int(*msg.SshJWTCacheTTL)
|
||||
config.SSHJWTCacheTTL = &ttl
|
||||
}
|
||||
|
||||
if msg.Mtu != nil {
|
||||
mtu := uint16(*msg.Mtu)
|
||||
config.MTU = &mtu
|
||||
}
|
||||
return config, nil
|
||||
|
||||
if _, err := profilemanager.UpdateConfig(config); err != nil {
|
||||
log.Errorf("failed to update profile config: %v", err)
|
||||
return nil, fmt.Errorf("failed to update profile config: %w", err)
|
||||
}
|
||||
|
||||
return &proto.SetConfigResponse{}, nil
|
||||
}
|
||||
|
||||
// Login uses setup key to prepare configuration for the daemon.
|
||||
func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*proto.LoginResponse, error) {
|
||||
// Config-override gates. LoginRequest carries the same surface as
|
||||
// SetConfigRequest (managementUrl, PSK, ssh/rosenpass/port toggles,
|
||||
// ...), so the same protections must apply. Without these the CLI
|
||||
// command `netbird up --management-url=X` (which falls through to
|
||||
// Login when SetConfig is rejected — see cmd/up.go) would silently
|
||||
// bypass `--disable-update-settings` and any MDM policy.
|
||||
if loginRequestHasConfigOverrides(msg) {
|
||||
if s.checkUpdateSettingsDisabled() {
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errUpdateSettingsDisabled)
|
||||
}
|
||||
policy := loadMDMPolicy()
|
||||
if err := rejectMDMManagedFieldConflicts(loginRequestMDMConflicts(msg, policy)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
if s.actCancel != nil {
|
||||
s.actCancel()
|
||||
@@ -747,13 +652,7 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
|
||||
// Up starts engine work in the daemon.
|
||||
func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpResponse, error) {
|
||||
s.mutex.Lock()
|
||||
// clientRunning is the daemon-intent flag (set by previous Up/Start, cleared
|
||||
// by Down). connectionGoroutineRunning() reports whether the previous retry-loop
|
||||
// goroutine is still trying. When intent is up AND goroutine is alive,
|
||||
// the existing engine is on the job — just wait for it. When intent
|
||||
// is up but the goroutine has given up (backoff exhausted) OR when
|
||||
// intent is down, fall through to spawn a fresh retry loop.
|
||||
if s.clientRunning && s.connectionGoroutineRunning() {
|
||||
if s.clientRunning {
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
status, err := state.Status()
|
||||
if err != nil {
|
||||
@@ -844,7 +743,6 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
|
||||
s.clientGiveUpChan = make(chan struct{})
|
||||
|
||||
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
|
||||
s.publishConfigChangedEvent("up_rpc")
|
||||
|
||||
s.mutex.Unlock()
|
||||
return s.waitForUp(callerCtx)
|
||||
@@ -973,12 +871,6 @@ func (s *Server) cleanupConnection() error {
|
||||
return ErrServiceNotUp
|
||||
}
|
||||
|
||||
// Daemon intent flips to "down" — all callers (Down RPC,
|
||||
// Logout RPC handlers) tear down the connection because the user
|
||||
// explicitly asked for it. MDM restart does NOT go through this
|
||||
// path, so its clientRunning stays true.
|
||||
s.clientRunning = false
|
||||
|
||||
// Capture the engine reference before cancelling the context.
|
||||
// After actCancel(), the connectWithRetryRuns goroutine wakes up
|
||||
// and sets connectClient.engine = nil, causing connectClient.Stop()
|
||||
@@ -1182,14 +1074,10 @@ func (s *Server) Status(
|
||||
msg *proto.StatusRequest,
|
||||
) (*proto.StatusResponse, error) {
|
||||
s.mutex.Lock()
|
||||
// Only wait if the retry-loop goroutine is alive and making
|
||||
// progress. clientRunning=true with connectionGoroutineRunning=false means the
|
||||
// backoff has given up — there is nothing to wait for; let the
|
||||
// caller observe the failed status directly.
|
||||
alive := s.connectionGoroutineRunning()
|
||||
clientRunning := s.clientRunning
|
||||
s.mutex.Unlock()
|
||||
|
||||
if msg.WaitForReady != nil && *msg.WaitForReady && alive {
|
||||
if msg.WaitForReady != nil && *msg.WaitForReady && clientRunning {
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
status, err := state.Status()
|
||||
if err != nil {
|
||||
@@ -1248,7 +1136,6 @@ func (s *Server) Status(
|
||||
pbFullStatus := fullStatus.ToProto()
|
||||
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
|
||||
pbFullStatus.SshServerState = s.getSSHServerState()
|
||||
pbFullStatus.VncServerState = s.getVNCServerState()
|
||||
statusResponse.FullStatus = pbFullStatus
|
||||
}
|
||||
|
||||
@@ -1288,38 +1175,6 @@ func (s *Server) getSSHServerState() *proto.SSHServerState {
|
||||
return sshServerState
|
||||
}
|
||||
|
||||
// getVNCServerState retrieves the current VNC server state.
|
||||
func (s *Server) getVNCServerState() *proto.VNCServerState {
|
||||
s.mutex.Lock()
|
||||
connectClient := s.connectClient
|
||||
s.mutex.Unlock()
|
||||
|
||||
if connectClient == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
engine := connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
enabled, sessions := engine.GetVNCServerStatus()
|
||||
pbSessions := make([]*proto.VNCSessionInfo, 0, len(sessions))
|
||||
for _, sess := range sessions {
|
||||
pbSessions = append(pbSessions, &proto.VNCSessionInfo{
|
||||
RemoteAddress: sess.RemoteAddress,
|
||||
Mode: sess.Mode,
|
||||
Username: sess.Username,
|
||||
UserID: sess.UserID,
|
||||
Initiator: sess.Initiator,
|
||||
})
|
||||
}
|
||||
return &proto.VNCServerState{
|
||||
Enabled: enabled,
|
||||
Sessions: pbSessions,
|
||||
}
|
||||
}
|
||||
|
||||
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
|
||||
func (s *Server) GetPeerSSHHostKey(
|
||||
ctx context.Context,
|
||||
@@ -1560,27 +1415,6 @@ func (s *Server) ExposeService(req *proto.ExposeServiceRequest, srv proto.Daemon
|
||||
return nil
|
||||
}
|
||||
|
||||
// RespondApproval relays the user's accept/deny decision for a pending
|
||||
// approval prompt to the engine's broker. Unknown or already-resolved
|
||||
// request_ids are silently no-op'd so a slow UI cannot deny a prompt the
|
||||
// user already handled (or that already timed out).
|
||||
func (s *Server) RespondApproval(_ context.Context, msg *proto.RespondApprovalRequest) (*proto.RespondApprovalResponse, error) {
|
||||
s.mutex.Lock()
|
||||
connectClient := s.connectClient
|
||||
s.mutex.Unlock()
|
||||
if connectClient == nil {
|
||||
return nil, gstatus.Errorf(codes.FailedPrecondition, "client not initialized")
|
||||
}
|
||||
engine := connectClient.Engine()
|
||||
if engine == nil {
|
||||
return nil, gstatus.Errorf(codes.FailedPrecondition, "engine not running")
|
||||
}
|
||||
if !engine.RespondApproval(msg.GetRequestId(), msg.GetAccept(), msg.GetViewOnly()) {
|
||||
log.Debugf("approval response for unknown request_id %s", msg.GetRequestId())
|
||||
}
|
||||
return &proto.RespondApprovalResponse{}, nil
|
||||
}
|
||||
|
||||
func isUnixRunningDesktop() bool {
|
||||
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||
return false
|
||||
@@ -1697,8 +1531,6 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
|
||||
Mtu: int64(cfg.MTU),
|
||||
DisableAutoConnect: cfg.DisableAutoConnect,
|
||||
ServerSSHAllowed: *cfg.ServerSSHAllowed,
|
||||
ServerVNCAllowed: cfg.ServerVNCAllowed != nil && *cfg.ServerVNCAllowed,
|
||||
DisableVNCApproval: cfg.DisableVNCApproval != nil && *cfg.DisableVNCApproval,
|
||||
RosenpassEnabled: cfg.RosenpassEnabled,
|
||||
RosenpassPermissive: cfg.RosenpassPermissive,
|
||||
LazyConnectionEnabled: cfg.LazyConnectionEnabled,
|
||||
@@ -1716,7 +1548,6 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
|
||||
EnableSSHRemotePortForwarding: enableSSHRemotePortForwarding,
|
||||
DisableSSHAuth: disableSSHAuth,
|
||||
SshJWTCacheTTL: sshJWTCacheTTL,
|
||||
MDMManagedFields: cfg.Policy().ManagedKeys(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1815,7 +1646,7 @@ func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest)
|
||||
features := &proto.GetFeaturesResponse{
|
||||
DisableProfiles: s.checkProfilesDisabled(),
|
||||
DisableUpdateSettings: s.checkUpdateSettingsDisabled(),
|
||||
DisableNetworks: s.checkNetworksDisabled(),
|
||||
DisableNetworks: s.networksDisabled,
|
||||
}
|
||||
|
||||
return features, nil
|
||||
@@ -1837,46 +1668,22 @@ func (s *Server) connect(ctx context.Context, config *profilemanager.Config, sta
|
||||
return nil
|
||||
}
|
||||
|
||||
// MDM authority: when the platform-native MDM source sets a kill switch
|
||||
// key (regardless of true/false value), that value wins. The CLI flag
|
||||
// supplied at service install time is the fallback used only when the
|
||||
// MDM source is silent on the key. This honors the "MDM decides
|
||||
// everything" semantic agreed for NET-1214 — an admin pushing
|
||||
// disableX=false via MDM explicitly re-enables the feature even on a
|
||||
// box installed with --disable-X.
|
||||
func (s *Server) checkProfilesDisabled() bool {
|
||||
if s.config != nil {
|
||||
if v, ok := s.config.Policy().GetBool(mdm.KeyDisableProfiles); ok {
|
||||
return v
|
||||
}
|
||||
// Check if the environment variable is set to disable profiles
|
||||
if s.profilesDisabled {
|
||||
return true
|
||||
}
|
||||
return s.profilesDisabled
|
||||
}
|
||||
|
||||
// checkNetworksDisabled reports whether the networks/exit-node feature
|
||||
// is disabled on this daemon instance. Resolved MDM-first: when the
|
||||
// active policy declares mdm.KeyDisableNetworks the policy value wins
|
||||
// (regardless of true/false), so an admin can re-enable the feature
|
||||
// via MDM even on a host that was installed with --disable-networks.
|
||||
// Falls back to the s.networksDisabled CLI flag when the policy is
|
||||
// silent on the key. Mirrors checkProfilesDisabled and
|
||||
// checkUpdateSettingsDisabled.
|
||||
func (s *Server) checkNetworksDisabled() bool {
|
||||
if s.config != nil {
|
||||
if v, ok := s.config.Policy().GetBool(mdm.KeyDisableNetworks); ok {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return s.networksDisabled
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *Server) checkUpdateSettingsDisabled() bool {
|
||||
if s.config != nil {
|
||||
if v, ok := s.config.Policy().GetBool(mdm.KeyDisableUpdateSettings); ok {
|
||||
return v
|
||||
}
|
||||
// Check if the environment variable is set to disable profiles
|
||||
if s.updateSettingsDisabled {
|
||||
return true
|
||||
}
|
||||
return s.updateSettingsDisabled
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *Server) startUpdateManagerForGUI() {
|
||||
|
||||
@@ -101,7 +101,6 @@ func TestCleanupConnection_ClearsConnectClient(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Nil(t, s.connectClient, "connectClient should be nil after cleanup")
|
||||
assert.False(t, s.clientRunning, "clientRunning should be cleared after cleanup (intent = down)")
|
||||
}
|
||||
|
||||
// TestCleanState_NilConnectClient validates that CleanState doesn't panic
|
||||
@@ -145,20 +144,17 @@ func TestDownThenUp_StaleRunningChan(t *testing.T) {
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
s.actCancel = cancel
|
||||
|
||||
// Simulate Down(): cleanupConnection sets connectClient = nil and
|
||||
// flips clientRunning to false (intent = down). The connectionGoroutineRunning state
|
||||
// remains independent of intent — derived from clientGiveUpChan.
|
||||
// Simulate Down(): cleanupConnection sets connectClient = nil
|
||||
s.mutex.Lock()
|
||||
err := s.cleanupConnection()
|
||||
s.mutex.Unlock()
|
||||
require.NoError(t, err)
|
||||
|
||||
// After cleanup: connectClient is nil, clientRunning is false (intent
|
||||
// cleared by cleanupConnection), connectionGoroutineRunning may still be true
|
||||
// (goroutine teardown is independent of the intent flag).
|
||||
// After cleanup: connectClient is nil, clientRunning still true
|
||||
// (goroutine hasn't exited yet)
|
||||
s.mutex.Lock()
|
||||
assert.Nil(t, s.connectClient, "connectClient should be nil after cleanup")
|
||||
assert.False(t, s.clientRunning, "clientRunning should be cleared by cleanupConnection (intent = down)")
|
||||
assert.True(t, s.clientRunning, "clientRunning still true until goroutine exits")
|
||||
s.mutex.Unlock()
|
||||
|
||||
// waitForUp() returns immediately due to stale closed clientRunningChan
|
||||
|
||||
@@ -1,198 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/mdm"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
// withMDMPolicy temporarily overrides the server-package loadMDMPolicy hook
|
||||
// so SetConfig observes the supplied Policy. Restores the original loader
|
||||
// at test cleanup.
|
||||
func withMDMPolicy(t *testing.T, policy *mdm.Policy) {
|
||||
t.Helper()
|
||||
prev := loadMDMPolicy
|
||||
loadMDMPolicy = func() *mdm.Policy { return policy }
|
||||
t.Cleanup(func() { loadMDMPolicy = prev })
|
||||
}
|
||||
|
||||
// setupServerWithProfile mirrors the boilerplate of TestSetConfig_AllFieldsSaved:
|
||||
// overrides profilemanager paths to a temp dir, seeds a profile, sets it
|
||||
// active, and constructs a Server instance. Returns the constructed server
|
||||
// plus context + profile name + username + cfgPath for the seeded profile.
|
||||
func setupServerWithProfile(t *testing.T) (s *Server, ctx context.Context, profName, username, cfgPath string) {
|
||||
t.Helper()
|
||||
tempDir := t.TempDir()
|
||||
|
||||
origDefaultProfileDir := profilemanager.DefaultConfigPathDir
|
||||
origDefaultConfigPath := profilemanager.DefaultConfigPath
|
||||
origActiveProfileStatePath := profilemanager.ActiveProfileStatePath
|
||||
profilemanager.ConfigDirOverride = tempDir
|
||||
profilemanager.DefaultConfigPathDir = tempDir
|
||||
profilemanager.ActiveProfileStatePath = tempDir + "/active_profile.json"
|
||||
profilemanager.DefaultConfigPath = filepath.Join(tempDir, "default.json")
|
||||
t.Cleanup(func() {
|
||||
profilemanager.DefaultConfigPathDir = origDefaultProfileDir
|
||||
profilemanager.ActiveProfileStatePath = origActiveProfileStatePath
|
||||
profilemanager.DefaultConfigPath = origDefaultConfigPath
|
||||
profilemanager.ConfigDirOverride = ""
|
||||
})
|
||||
|
||||
currUser, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
profName = "test-profile-mdm"
|
||||
cfgPath = filepath.Join(tempDir, profName+".json")
|
||||
|
||||
_, err = profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||
ConfigPath: cfgPath,
|
||||
ManagementURL: "https://api.netbird.io:443",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
pm := profilemanager.ServiceManager{}
|
||||
require.NoError(t, pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
Name: profName,
|
||||
Username: currUser.Username,
|
||||
}))
|
||||
|
||||
ctx = context.Background()
|
||||
s = New(ctx, "console", "", false, false, false, false)
|
||||
return s, ctx, profName, currUser.Username, cfgPath
|
||||
}
|
||||
|
||||
// extractViolation pulls the MDMManagedFieldsViolation detail from a
|
||||
// FailedPrecondition error. Fails the test if absent or malformed.
|
||||
func extractViolation(t *testing.T, err error) *proto.MDMManagedFieldsViolation {
|
||||
t.Helper()
|
||||
require.Error(t, err)
|
||||
st, ok := gstatus.FromError(err)
|
||||
require.True(t, ok, "error must be a gRPC status: %v", err)
|
||||
require.Equal(t, codes.FailedPrecondition, st.Code(), "expected FailedPrecondition, got %s", st.Code())
|
||||
for _, d := range st.Details() {
|
||||
if v, ok := d.(*proto.MDMManagedFieldsViolation); ok {
|
||||
return v
|
||||
}
|
||||
}
|
||||
t.Fatalf("MDMManagedFieldsViolation detail not found on status; details: %v", st.Details())
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestSetConfig_MDMReject_SingleField(t *testing.T) {
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: "https://mdm.example.com:443",
|
||||
}))
|
||||
|
||||
s, ctx, profName, username, _ := setupServerWithProfile(t)
|
||||
|
||||
_, err := s.SetConfig(ctx, &proto.SetConfigRequest{
|
||||
ProfileName: profName,
|
||||
Username: username,
|
||||
ManagementUrl: "https://user.tried.this.com:443",
|
||||
})
|
||||
|
||||
v := extractViolation(t, err)
|
||||
assert.Equal(t, []string{mdm.KeyManagementURL}, v.GetFields())
|
||||
}
|
||||
|
||||
func TestSetConfig_MDMReject_MultipleFields(t *testing.T) {
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: "https://mdm.example.com:443",
|
||||
mdm.KeyBlockInbound: true,
|
||||
mdm.KeyRosenpassEnabled: true,
|
||||
}))
|
||||
|
||||
s, ctx, profName, username, _ := setupServerWithProfile(t)
|
||||
|
||||
blockInbound := false
|
||||
rosenpassEnabled := false
|
||||
_, err := s.SetConfig(ctx, &proto.SetConfigRequest{
|
||||
ProfileName: profName,
|
||||
Username: username,
|
||||
ManagementUrl: "https://user.tried.this.com:443",
|
||||
BlockInbound: &blockInbound,
|
||||
RosenpassEnabled: &rosenpassEnabled,
|
||||
})
|
||||
|
||||
v := extractViolation(t, err)
|
||||
assert.ElementsMatch(t, []string{
|
||||
mdm.KeyManagementURL,
|
||||
mdm.KeyBlockInbound,
|
||||
mdm.KeyRosenpassEnabled,
|
||||
}, v.GetFields())
|
||||
}
|
||||
|
||||
func TestSetConfig_MDMReject_AllOrNothing(t *testing.T) {
|
||||
// MDM enforces ManagementURL only; user request touches both the
|
||||
// enforced field AND a non-enforced field (RosenpassEnabled).
|
||||
// The whole request must be rejected — non-conflicting fields are not
|
||||
// applied either.
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: "https://mdm.example.com:443",
|
||||
}))
|
||||
|
||||
s, ctx, profName, username, cfgPath := setupServerWithProfile(t)
|
||||
|
||||
rosenpassEnabled := true
|
||||
_, err := s.SetConfig(ctx, &proto.SetConfigRequest{
|
||||
ProfileName: profName,
|
||||
Username: username,
|
||||
ManagementUrl: "https://user.tried.this.com:443",
|
||||
RosenpassEnabled: &rosenpassEnabled,
|
||||
})
|
||||
|
||||
v := extractViolation(t, err)
|
||||
assert.Equal(t, []string{mdm.KeyManagementURL}, v.GetFields())
|
||||
|
||||
// Confirm RosenpassEnabled was NOT applied even though it was not
|
||||
// in the conflict list: the request was rejected as a whole.
|
||||
reloaded, err := profilemanager.GetConfig(cfgPath)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, reloaded.RosenpassEnabled, "non-conflicting field must not be applied when request is rejected")
|
||||
}
|
||||
|
||||
func TestSetConfig_MDMAllow_NonManagedFields(t *testing.T) {
|
||||
// MDM enforces ManagementURL but the user only writes RosenpassEnabled.
|
||||
// Request must succeed.
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: "https://mdm.example.com:443",
|
||||
}))
|
||||
|
||||
s, ctx, profName, username, _ := setupServerWithProfile(t)
|
||||
|
||||
rosenpassEnabled := true
|
||||
resp, err := s.SetConfig(ctx, &proto.SetConfigRequest{
|
||||
ProfileName: profName,
|
||||
Username: username,
|
||||
RosenpassEnabled: &rosenpassEnabled,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
}
|
||||
|
||||
func TestSetConfig_MDMEmpty_NoEnforcement(t *testing.T) {
|
||||
// No MDM policy active: any field can be written.
|
||||
withMDMPolicy(t, mdm.NewPolicy(nil))
|
||||
|
||||
s, ctx, profName, username, _ := setupServerWithProfile(t)
|
||||
|
||||
resp, err := s.SetConfig(ctx, &proto.SetConfigRequest{
|
||||
ProfileName: profName,
|
||||
Username: username,
|
||||
ManagementUrl: "https://user.changed.url.com:443",
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
}
|
||||
@@ -58,8 +58,6 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
rosenpassEnabled := true
|
||||
rosenpassPermissive := true
|
||||
serverSSHAllowed := true
|
||||
serverVNCAllowed := true
|
||||
disableVNCApproval := true
|
||||
interfaceName := "utun100"
|
||||
wireguardPort := int64(51820)
|
||||
preSharedKey := "test-psk"
|
||||
@@ -85,8 +83,6 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
RosenpassEnabled: &rosenpassEnabled,
|
||||
RosenpassPermissive: &rosenpassPermissive,
|
||||
ServerSSHAllowed: &serverSSHAllowed,
|
||||
ServerVNCAllowed: &serverVNCAllowed,
|
||||
DisableVNCApproval: &disableVNCApproval,
|
||||
InterfaceName: &interfaceName,
|
||||
WireguardPort: &wireguardPort,
|
||||
OptionalPreSharedKey: &preSharedKey,
|
||||
@@ -131,10 +127,6 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
require.Equal(t, rosenpassPermissive, cfg.RosenpassPermissive)
|
||||
require.NotNil(t, cfg.ServerSSHAllowed)
|
||||
require.Equal(t, serverSSHAllowed, *cfg.ServerSSHAllowed)
|
||||
require.NotNil(t, cfg.ServerVNCAllowed)
|
||||
require.Equal(t, serverVNCAllowed, *cfg.ServerVNCAllowed)
|
||||
require.NotNil(t, cfg.DisableVNCApproval)
|
||||
require.Equal(t, disableVNCApproval, *cfg.DisableVNCApproval)
|
||||
require.Equal(t, interfaceName, cfg.WgIface)
|
||||
require.Equal(t, int(wireguardPort), cfg.WgPort)
|
||||
require.Equal(t, preSharedKey, cfg.PreSharedKey)
|
||||
@@ -187,8 +179,6 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
|
||||
"RosenpassEnabled": true,
|
||||
"RosenpassPermissive": true,
|
||||
"ServerSSHAllowed": true,
|
||||
"ServerVNCAllowed": true,
|
||||
"DisableVNCApproval": true,
|
||||
"InterfaceName": true,
|
||||
"WireguardPort": true,
|
||||
"OptionalPreSharedKey": true,
|
||||
@@ -250,8 +240,6 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) {
|
||||
"enable-rosenpass": "RosenpassEnabled",
|
||||
"rosenpass-permissive": "RosenpassPermissive",
|
||||
"allow-server-ssh": "ServerSSHAllowed",
|
||||
"allow-server-vnc": "ServerVNCAllowed",
|
||||
"disable-vnc-approval": "DisableVNCApproval",
|
||||
"interface-name": "InterfaceName",
|
||||
"wireguard-port": "WireguardPort",
|
||||
"preshared-key": "OptionalPreSharedKey",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package sessionauth
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
@@ -15,8 +15,6 @@ const (
|
||||
DefaultUserIDClaim = "sub"
|
||||
// Wildcard is a special user ID that matches all users
|
||||
Wildcard = "*"
|
||||
// sessionPubKeyLen is the size of an X25519 static public key in bytes.
|
||||
sessionPubKeyLen = 32
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -24,7 +22,6 @@ var (
|
||||
ErrUserNotAuthorized = errors.New("user is not authorized to access this peer")
|
||||
ErrNoMachineUserMapping = errors.New("no authorization mapping for OS user")
|
||||
ErrUserNotMappedToOSUser = errors.New("user is not authorized to login as OS user")
|
||||
ErrSessionKeyNotKnown = errors.New("session pubkey not registered")
|
||||
)
|
||||
|
||||
// Authorizer handles SSH fine-grained access control authorization
|
||||
@@ -38,17 +35,6 @@ type Authorizer struct {
|
||||
// machineUsers maps OS login usernames to lists of authorized user indexes
|
||||
machineUsers map[string][]uint32
|
||||
|
||||
// sessionPubKeys maps an X25519 static public key (as map-safe
|
||||
// array) to the hashed user identity that key authenticates as.
|
||||
// Populated from management's temporary-access flow; used by VNC to
|
||||
// authenticate via the Noise_IK handshake.
|
||||
sessionPubKeys map[[sessionPubKeyLen]byte]sshuserhash.UserIDHash
|
||||
// sessionDisplayNames mirrors sessionPubKeys with the optional
|
||||
// human-readable display name management associated with each
|
||||
// session key. Used by the per-connection UI approval prompt; not
|
||||
// consulted by any authorization decision.
|
||||
sessionDisplayNames map[[sessionPubKeyLen]byte]string
|
||||
|
||||
// mu protects the list of users
|
||||
mu sync.RWMutex
|
||||
}
|
||||
@@ -64,29 +50,13 @@ type Config struct {
|
||||
// MachineUsers maps OS login usernames to indexes in AuthorizedUsers
|
||||
// If a user wants to login as a specific OS user, their index must be in the corresponding list
|
||||
MachineUsers map[string][]uint32
|
||||
|
||||
// SessionPubKeys binds ephemeral X25519 static public keys to hashed
|
||||
// user identities. Populated for VNC; ignored on the SSH side.
|
||||
SessionPubKeys []SessionPubKey
|
||||
}
|
||||
|
||||
// SessionPubKey is a single ephemeral-key entry: the 32-byte X25519
|
||||
// static public key plus the hashed user identity it authenticates as,
|
||||
// optionally plus a human-readable display name for the UI approval
|
||||
// prompt to identify the requester.
|
||||
type SessionPubKey struct {
|
||||
PubKey []byte
|
||||
UserIDHash sshuserhash.UserIDHash
|
||||
DisplayName string
|
||||
}
|
||||
|
||||
// NewAuthorizer creates a new SSH authorizer with empty configuration
|
||||
func NewAuthorizer() *Authorizer {
|
||||
a := &Authorizer{
|
||||
userIDClaim: DefaultUserIDClaim,
|
||||
machineUsers: make(map[string][]uint32),
|
||||
sessionPubKeys: make(map[[sessionPubKeyLen]byte]sshuserhash.UserIDHash),
|
||||
sessionDisplayNames: make(map[[sessionPubKeyLen]byte]string),
|
||||
userIDClaim: DefaultUserIDClaim,
|
||||
machineUsers: make(map[string][]uint32),
|
||||
}
|
||||
|
||||
return a
|
||||
@@ -102,8 +72,6 @@ func (a *Authorizer) Update(config *Config) {
|
||||
a.userIDClaim = DefaultUserIDClaim
|
||||
a.authorizedUsers = []sshuserhash.UserIDHash{}
|
||||
a.machineUsers = make(map[string][]uint32)
|
||||
a.sessionPubKeys = make(map[[sessionPubKeyLen]byte]sshuserhash.UserIDHash)
|
||||
a.sessionDisplayNames = make(map[[sessionPubKeyLen]byte]string)
|
||||
log.Info("SSH authorization cleared")
|
||||
return
|
||||
}
|
||||
@@ -126,35 +94,8 @@ func (a *Authorizer) Update(config *Config) {
|
||||
}
|
||||
a.machineUsers = machineUsers
|
||||
|
||||
sessionPubKeys := make(map[[sessionPubKeyLen]byte]sshuserhash.UserIDHash, len(config.SessionPubKeys))
|
||||
sessionDisplayNames := make(map[[sessionPubKeyLen]byte]string, len(config.SessionPubKeys))
|
||||
conflicted := make(map[[sessionPubKeyLen]byte]struct{})
|
||||
for _, e := range config.SessionPubKeys {
|
||||
if len(e.PubKey) != sessionPubKeyLen {
|
||||
continue
|
||||
}
|
||||
var key [sessionPubKeyLen]byte
|
||||
copy(key[:], e.PubKey)
|
||||
if _, bad := conflicted[key]; bad {
|
||||
continue
|
||||
}
|
||||
if existing, ok := sessionPubKeys[key]; ok && existing != e.UserIDHash {
|
||||
log.Warnf("SSH auth: session pubkey bound to conflicting user hashes; dropping binding")
|
||||
delete(sessionPubKeys, key)
|
||||
delete(sessionDisplayNames, key)
|
||||
conflicted[key] = struct{}{}
|
||||
continue
|
||||
}
|
||||
sessionPubKeys[key] = e.UserIDHash
|
||||
if e.DisplayName != "" {
|
||||
sessionDisplayNames[key] = e.DisplayName
|
||||
}
|
||||
}
|
||||
a.sessionPubKeys = sessionPubKeys
|
||||
a.sessionDisplayNames = sessionDisplayNames
|
||||
|
||||
log.Debugf("SSH auth: updated with %d authorized users, %d machine user mappings, %d session pubkeys",
|
||||
len(config.AuthorizedUsers), len(machineUsers), len(sessionPubKeys))
|
||||
log.Debugf("SSH auth: updated with %d authorized users, %d machine user mappings",
|
||||
len(config.AuthorizedUsers), len(machineUsers))
|
||||
}
|
||||
|
||||
// Authorize validates if a user is authorized to login as the specified OS user.
|
||||
@@ -214,54 +155,6 @@ func (a *Authorizer) GetUserIDClaim() string {
|
||||
return a.userIDClaim
|
||||
}
|
||||
|
||||
// LookupSessionKey resolves a Noise-verified static public key to the
|
||||
// hashed user identity registered with it. Fails closed when the key is
|
||||
// unknown.
|
||||
func (a *Authorizer) LookupSessionKey(pubKey []byte) (sshuserhash.UserIDHash, error) {
|
||||
var zero sshuserhash.UserIDHash
|
||||
if len(pubKey) != sessionPubKeyLen {
|
||||
return zero, fmt.Errorf("session pubkey wrong length: %d", len(pubKey))
|
||||
}
|
||||
var key [sessionPubKeyLen]byte
|
||||
copy(key[:], pubKey)
|
||||
a.mu.RLock()
|
||||
hash, ok := a.sessionPubKeys[key]
|
||||
a.mu.RUnlock()
|
||||
if !ok {
|
||||
return zero, ErrSessionKeyNotKnown
|
||||
}
|
||||
return hash, nil
|
||||
}
|
||||
|
||||
// LookupSessionDisplayName returns the human-readable display name
|
||||
// management associated with a session pubkey, or empty string when none
|
||||
// is recorded. Never returns an error: a missing/unknown key reports as
|
||||
// "" and the caller falls back to other identifiers.
|
||||
func (a *Authorizer) LookupSessionDisplayName(pubKey []byte) string {
|
||||
if len(pubKey) != sessionPubKeyLen {
|
||||
return ""
|
||||
}
|
||||
var key [sessionPubKeyLen]byte
|
||||
copy(key[:], pubKey)
|
||||
a.mu.RLock()
|
||||
name := a.sessionDisplayNames[key]
|
||||
a.mu.RUnlock()
|
||||
return name
|
||||
}
|
||||
|
||||
// AuthorizeOSUserBySessionKey resolves the OS-user mapping for a session
|
||||
// key. Mirrors Authorize but skips the JWT-hash step since the key has
|
||||
// already been verified and the user identity hash is in hand.
|
||||
func (a *Authorizer) AuthorizeOSUserBySessionKey(userIDHash sshuserhash.UserIDHash, osUsername string) (string, error) {
|
||||
a.mu.RLock()
|
||||
defer a.mu.RUnlock()
|
||||
userIndex, found := a.findUserIndex(userIDHash)
|
||||
if !found {
|
||||
return "", fmt.Errorf("session user (hash: %s) not in authorized list for OS user %q: %w", userIDHash, osUsername, ErrUserNotAuthorized)
|
||||
}
|
||||
return a.checkMachineUserMapping("session", osUsername, userIndex)
|
||||
}
|
||||
|
||||
// findUserIndex finds the index of a hashed user ID in the authorized users list
|
||||
// Returns the index and true if found, 0 and false if not found
|
||||
func (a *Authorizer) findUserIndex(hashedUserID sshuserhash.UserIDHash) (int, bool) {
|
||||
@@ -1,7 +1,6 @@
|
||||
package sessionauth
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -611,61 +610,3 @@ func TestAuthorizer_Wildcard_WithPartialIndexes_AllowsAllUsers(t *testing.T) {
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, ErrUserNotAuthorized, "unauthorized user should be denied")
|
||||
}
|
||||
|
||||
func TestAuthorizer_LookupSessionKey_Valid(t *testing.T) {
|
||||
pub := bytesRepeat(0x11, sessionPubKeyLen)
|
||||
userHash, err := sshauth.HashUserID("alice")
|
||||
require.NoError(t, err)
|
||||
|
||||
a := NewAuthorizer()
|
||||
a.Update(&Config{
|
||||
AuthorizedUsers: []sshauth.UserIDHash{userHash},
|
||||
MachineUsers: map[string][]uint32{Wildcard: {0}},
|
||||
SessionPubKeys: []SessionPubKey{{PubKey: pub, UserIDHash: userHash}},
|
||||
})
|
||||
|
||||
got, err := a.LookupSessionKey(pub)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, userHash, got)
|
||||
|
||||
if _, err := a.AuthorizeOSUserBySessionKey(got, "alice"); err != nil {
|
||||
t.Fatalf("AuthorizeOSUserBySessionKey: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorizer_LookupSessionKey_UnknownPub(t *testing.T) {
|
||||
a := NewAuthorizer()
|
||||
a.Update(&Config{})
|
||||
_, err := a.LookupSessionKey(bytesRepeat(0x22, sessionPubKeyLen))
|
||||
require.ErrorIs(t, err, ErrSessionKeyNotKnown)
|
||||
}
|
||||
|
||||
func TestAuthorizer_LookupSessionKey_WrongLength(t *testing.T) {
|
||||
a := NewAuthorizer()
|
||||
_, err := a.LookupSessionKey([]byte("short"))
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestAuthorizer_LookupSessionKey_UpdateClears(t *testing.T) {
|
||||
pub := bytesRepeat(0x33, sessionPubKeyLen)
|
||||
userHash, err := sshauth.HashUserID("alice")
|
||||
require.NoError(t, err)
|
||||
|
||||
a := NewAuthorizer()
|
||||
a.Update(&Config{SessionPubKeys: []SessionPubKey{{PubKey: pub, UserIDHash: userHash}}})
|
||||
if _, err := a.LookupSessionKey(pub); err != nil {
|
||||
t.Fatalf("setup lookup: %v", err)
|
||||
}
|
||||
a.Update(&Config{})
|
||||
if _, err := a.LookupSessionKey(pub); !errors.Is(err, ErrSessionKeyNotKnown) {
|
||||
t.Fatalf("expected ErrSessionKeyNotKnown, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func bytesRepeat(b byte, n int) []byte {
|
||||
out := make([]byte, n)
|
||||
for i := range out {
|
||||
out[i] = b
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -28,10 +28,10 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||
"github.com/netbirdio/netbird/client/ssh/server"
|
||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
sshauth "github.com/netbirdio/netbird/shared/sessionauth"
|
||||
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
|
||||
)
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
|
||||
"github.com/creack/pty"
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/netbirdio/netbird/client/internal/shell"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@@ -146,10 +147,10 @@ func (s *Server) createShellCommand(ctx context.Context, shell string, args []st
|
||||
|
||||
// prepareCommandEnv prepares environment variables for command execution on Unix
|
||||
func (s *Server) prepareCommandEnv(_ *log.Entry, localUser *user.User, session ssh.Session) []string {
|
||||
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
||||
env := shell.PrepareUserEnv(localUser, shell.GetUserShell(localUser.Uid))
|
||||
env = append(env, prepareSSHEnv(session)...)
|
||||
for _, v := range session.Environ() {
|
||||
if acceptEnv(v) {
|
||||
if shell.AcceptEnv(v) {
|
||||
env = append(env, v)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -247,10 +247,10 @@ func (s *Server) prepareCommandEnv(logger *log.Entry, localUser *user.User, sess
|
||||
userEnv, err := s.getUserEnvironment(logger, username, domain)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get user environment for %s\\%s, using fallback: %v", domain, username, err)
|
||||
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
||||
env := shell.PrepareUserEnv(localUser, shell.GetUserShell(localUser.Uid))
|
||||
env = append(env, prepareSSHEnv(session)...)
|
||||
for _, v := range session.Environ() {
|
||||
if acceptEnv(v) {
|
||||
if shell.AcceptEnv(v) {
|
||||
env = append(env, v)
|
||||
}
|
||||
}
|
||||
@@ -260,7 +260,7 @@ func (s *Server) prepareCommandEnv(logger *log.Entry, localUser *user.User, sess
|
||||
env := userEnv
|
||||
env = append(env, prepareSSHEnv(session)...)
|
||||
for _, v := range session.Environ() {
|
||||
if acceptEnv(v) {
|
||||
if shell.AcceptEnv(v) {
|
||||
env = append(env, v)
|
||||
}
|
||||
}
|
||||
@@ -273,7 +273,7 @@ func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, privileg
|
||||
return false
|
||||
}
|
||||
|
||||
shell := getUserShell(privilegeResult.User.Uid)
|
||||
shell := shell.GetUserShell(privilegeResult.User.Uid)
|
||||
logger.Infof("starting interactive shell: %s", shell)
|
||||
|
||||
s.executeCommandWithPty(logger, session, nil, privilegeResult, ptyReq, nil)
|
||||
@@ -384,7 +384,7 @@ func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, _
|
||||
}
|
||||
|
||||
username, domain := s.parseUsername(localUser.Username)
|
||||
shell := getUserShell(localUser.Uid)
|
||||
shell := shell.GetUserShell(localUser.Uid)
|
||||
|
||||
req := PtyExecutionRequest{
|
||||
Shell: shell,
|
||||
|
||||
@@ -23,11 +23,11 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||
"github.com/netbirdio/netbird/client/ssh/client"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
sshauth "github.com/netbirdio/netbird/shared/sessionauth"
|
||||
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
|
||||
)
|
||||
|
||||
|
||||
@@ -23,10 +23,10 @@ import (
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
sshauth "github.com/netbirdio/netbird/shared/sessionauth"
|
||||
"github.com/netbirdio/netbird/util/netrelay"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
@@ -197,14 +197,6 @@ type Config struct {
|
||||
|
||||
// HostKey is the SSH server host key in PEM format
|
||||
HostKeyPEM []byte
|
||||
|
||||
// NetstackNet, when non-nil, makes the SSH server listen via the
|
||||
// supplied userspace network stack instead of an OS socket.
|
||||
NetstackNet *netstack.Net
|
||||
|
||||
// NetworkValidation, when non-zero, restricts inbound connections to
|
||||
// peers inside the NetBird overlay defined by this WireGuard address.
|
||||
NetworkValidation wgaddr.Address
|
||||
}
|
||||
|
||||
// SessionInfo contains information about an active SSH session
|
||||
@@ -216,15 +208,12 @@ type SessionInfo struct {
|
||||
PortForwards []string
|
||||
}
|
||||
|
||||
// New creates an SSH server instance from the supplied Config. Fields are
|
||||
// read once at construction; mutating Config afterwards has no effect.
|
||||
// JWT == nil disables JWT authentication.
|
||||
// New creates an SSH server instance with the provided host key and optional JWT configuration
|
||||
// If jwtConfig is nil, JWT authentication is disabled
|
||||
func New(config *Config) *Server {
|
||||
s := &Server{
|
||||
mu: sync.RWMutex{},
|
||||
hostKeyPEM: config.HostKeyPEM,
|
||||
netstackNet: config.NetstackNet,
|
||||
wgAddress: config.NetworkValidation,
|
||||
sessions: make(map[sessionKey]*sessionState),
|
||||
pendingAuthJWT: make(map[authKey]string),
|
||||
remoteForwardListeners: make(map[forwardKey]net.Listener),
|
||||
@@ -445,6 +434,20 @@ func (s *Server) buildSessionInfo(state *sessionState) SessionInfo {
|
||||
return info
|
||||
}
|
||||
|
||||
// SetNetstackNet sets the netstack network for userspace networking
|
||||
func (s *Server) SetNetstackNet(net *netstack.Net) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.netstackNet = net
|
||||
}
|
||||
|
||||
// SetNetworkValidation configures network-based connection filtering
|
||||
func (s *Server) SetNetworkValidation(addr wgaddr.Address) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.wgAddress = addr
|
||||
}
|
||||
|
||||
// UpdateSSHAuth updates the SSH fine-grained access control configuration
|
||||
// This should be called when network map updates include new SSH auth configuration
|
||||
func (s *Server) UpdateSSHAuth(config *sshauth.Config) {
|
||||
|
||||
@@ -3,11 +3,15 @@ package server
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/netbirdio/netbird/client/internal/shell"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@@ -23,8 +27,8 @@ func isPlatformUnix() bool {
|
||||
|
||||
// Dependency injection variables for testing - allows mocking dynamic runtime checks
|
||||
var (
|
||||
getCurrentUser = currentUserWithGetent
|
||||
lookupUser = lookupWithGetent
|
||||
getCurrentUser = shell.CurrentUserWithGetent
|
||||
lookupUser = shell.LookupWithGetent
|
||||
getCurrentOS = func() string { return runtime.GOOS }
|
||||
getIsProcessPrivileged = isCurrentProcessPrivileged
|
||||
|
||||
@@ -409,3 +413,29 @@ func isWindowsElevated() bool {
|
||||
log.Debugf("Windows user switching not supported: not running as privileged user (current: %s)", currentUser.Uid)
|
||||
return false
|
||||
}
|
||||
|
||||
// prepareSSHEnv prepares SSH protocol-specific environment variables
|
||||
// These variables provide information about the SSH connection itself
|
||||
func prepareSSHEnv(session ssh.Session) []string {
|
||||
remoteAddr := session.RemoteAddr()
|
||||
localAddr := session.LocalAddr()
|
||||
|
||||
remoteHost, remotePort, err := net.SplitHostPort(remoteAddr.String())
|
||||
if err != nil {
|
||||
remoteHost = remoteAddr.String()
|
||||
remotePort = "0"
|
||||
}
|
||||
|
||||
localHost, localPort, err := net.SplitHostPort(localAddr.String())
|
||||
if err != nil {
|
||||
localHost = localAddr.String()
|
||||
localPort = strconv.Itoa(InternalSSHPort)
|
||||
}
|
||||
|
||||
return []string{
|
||||
// SSH_CLIENT format: "client_ip client_port server_port"
|
||||
fmt.Sprintf("SSH_CLIENT=%s %s %s", remoteHost, remotePort, localPort),
|
||||
// SSH_CONNECTION format: "client_ip client_port server_ip server_port"
|
||||
fmt.Sprintf("SSH_CONNECTION=%s %s %s %s", remoteHost, remotePort, localHost, localPort),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"strconv"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/netbirdio/netbird/client/internal/shell"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@@ -160,7 +161,7 @@ func (s *Server) parseUserCredentials(localUser *user.User) (uint32, uint32, []u
|
||||
// getSupplementaryGroups retrieves supplementary group IDs for a user.
|
||||
// Uses id/getent fallback for NSS users in CGO_ENABLED=0 builds.
|
||||
func (s *Server) getSupplementaryGroups(u *user.User) ([]uint32, error) {
|
||||
groupIDStrings, err := groupIdsWithFallback(u)
|
||||
groupIDStrings, err := shell.GroupIdsWithFallback(u)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group IDs for user %s: %w", u.Username, err)
|
||||
}
|
||||
@@ -196,7 +197,7 @@ func (s *Server) createExecutorCommand(logger *log.Entry, session ssh.Session, l
|
||||
GID: gid,
|
||||
Groups: groups,
|
||||
WorkingDir: localUser.HomeDir,
|
||||
Shell: getUserShell(localUser.Uid),
|
||||
Shell: shell.GetUserShell(localUser.Uid),
|
||||
Command: session.RawCommand(),
|
||||
PTY: hasPty,
|
||||
}
|
||||
@@ -228,7 +229,7 @@ func (s *Server) createPtyCommand(privilegeResult PrivilegeCheckResult, ptyReq s
|
||||
func (s *Server) createDirectPtyCommand(session ssh.Session, localUser *user.User, ptyReq ssh.Pty) *exec.Cmd {
|
||||
log.Debugf("creating direct Pty command for user %s (no user switching needed)", localUser.Username)
|
||||
|
||||
shell := getUserShell(localUser.Uid)
|
||||
shell := shell.GetUserShell(localUser.Uid)
|
||||
args := s.getShellCommandArgs(shell, session.RawCommand())
|
||||
|
||||
cmd := s.createShellCommand(session.Context(), shell, args)
|
||||
@@ -245,12 +246,12 @@ func (s *Server) preparePtyEnv(localUser *user.User, ptyReq ssh.Pty, session ssh
|
||||
termType = "xterm-256color"
|
||||
}
|
||||
|
||||
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
||||
env := shell.PrepareUserEnv(localUser, shell.GetUserShell(localUser.Uid))
|
||||
env = append(env, prepareSSHEnv(session)...)
|
||||
env = append(env, fmt.Sprintf("TERM=%s", termType))
|
||||
|
||||
for _, v := range session.Environ() {
|
||||
if acceptEnv(v) {
|
||||
if shell.AcceptEnv(v) {
|
||||
env = append(env, v)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,8 @@ import (
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/shell"
|
||||
)
|
||||
|
||||
// validateUsername validates Windows usernames according to SAM Account Name rules
|
||||
@@ -104,7 +106,7 @@ func (s *Server) createExecutorCommand(logger *log.Entry, session ssh.Session, l
|
||||
func (s *Server) createUserSwitchCommand(logger *log.Entry, session ssh.Session, localUser *user.User) (*exec.Cmd, func(), error) {
|
||||
username, domain := s.parseUsername(localUser.Username)
|
||||
|
||||
shell := getUserShell(localUser.Uid)
|
||||
sh := shell.GetUserShell(localUser.Uid)
|
||||
|
||||
rawCmd := session.RawCommand()
|
||||
var command string
|
||||
@@ -116,7 +118,7 @@ func (s *Server) createUserSwitchCommand(logger *log.Entry, session ssh.Session,
|
||||
Username: username,
|
||||
Domain: domain,
|
||||
WorkingDir: localUser.HomeDir,
|
||||
Shell: shell,
|
||||
Shell: sh,
|
||||
Command: command,
|
||||
}
|
||||
|
||||
|
||||
@@ -131,19 +131,6 @@ type SSHServerStateOutput struct {
|
||||
Sessions []SSHSessionOutput `json:"sessions" yaml:"sessions"`
|
||||
}
|
||||
|
||||
type VNCSessionOutput struct {
|
||||
RemoteAddress string `json:"remoteAddress" yaml:"remoteAddress"`
|
||||
Mode string `json:"mode" yaml:"mode"`
|
||||
Username string `json:"username,omitempty" yaml:"username,omitempty"`
|
||||
UserID string `json:"userID,omitempty" yaml:"userID,omitempty"`
|
||||
Initiator string `json:"initiator,omitempty" yaml:"initiator,omitempty"`
|
||||
}
|
||||
|
||||
type VNCServerStateOutput struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
Sessions []VNCSessionOutput `json:"sessions" yaml:"sessions"`
|
||||
}
|
||||
|
||||
type OutputOverview struct {
|
||||
Peers PeersStateOutput `json:"peers" yaml:"peers"`
|
||||
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
|
||||
@@ -167,7 +154,6 @@ type OutputOverview struct {
|
||||
LazyConnectionEnabled bool `json:"lazyConnectionEnabled" yaml:"lazyConnectionEnabled"`
|
||||
ProfileName string `json:"profileName" yaml:"profileName"`
|
||||
SSHServerState SSHServerStateOutput `json:"sshServer" yaml:"sshServer"`
|
||||
VNCServerState VNCServerStateOutput `json:"vncServer" yaml:"vncServer"`
|
||||
}
|
||||
|
||||
// ConvertToStatusOutputOverview converts protobuf status to the output overview.
|
||||
@@ -188,7 +174,6 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, opts ConvertO
|
||||
|
||||
relayOverview := mapRelays(pbFullStatus.GetRelays())
|
||||
sshServerOverview := mapSSHServer(pbFullStatus.GetSshServerState())
|
||||
vncServerOverview := mapVNCServer(pbFullStatus.GetVncServerState())
|
||||
peersOverview := mapPeers(pbFullStatus.GetPeers(), opts.StatusFilter, opts.PrefixNamesFilter, opts.PrefixNamesFilterMap, opts.IPsFilter, opts.ConnectionTypeFilter)
|
||||
|
||||
overview := OutputOverview{
|
||||
@@ -214,7 +199,6 @@ func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, opts ConvertO
|
||||
LazyConnectionEnabled: pbFullStatus.GetLazyConnectionEnabled(),
|
||||
ProfileName: opts.ProfileName,
|
||||
SSHServerState: sshServerOverview,
|
||||
VNCServerState: vncServerOverview,
|
||||
}
|
||||
|
||||
if opts.Anonymize {
|
||||
@@ -289,26 +273,6 @@ func mapSSHServer(sshServerState *proto.SSHServerState) SSHServerStateOutput {
|
||||
}
|
||||
}
|
||||
|
||||
func mapVNCServer(state *proto.VNCServerState) VNCServerStateOutput {
|
||||
if state == nil {
|
||||
return VNCServerStateOutput{Sessions: []VNCSessionOutput{}}
|
||||
}
|
||||
sessions := make([]VNCSessionOutput, 0, len(state.GetSessions()))
|
||||
for _, sess := range state.GetSessions() {
|
||||
sessions = append(sessions, VNCSessionOutput{
|
||||
RemoteAddress: sess.GetRemoteAddress(),
|
||||
Mode: sess.GetMode(),
|
||||
Username: sess.GetUsername(),
|
||||
UserID: sess.GetUserID(),
|
||||
Initiator: sess.GetInitiator(),
|
||||
})
|
||||
}
|
||||
return VNCServerStateOutput{
|
||||
Enabled: state.GetEnabled(),
|
||||
Sessions: sessions,
|
||||
}
|
||||
}
|
||||
|
||||
func mapPeers(
|
||||
peers []*proto.PeerState,
|
||||
statusFilter string,
|
||||
@@ -571,26 +535,6 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
||||
}
|
||||
}
|
||||
|
||||
vncServerStatus := "Disabled"
|
||||
if o.VNCServerState.Enabled {
|
||||
vncSessionCount := len(o.VNCServerState.Sessions)
|
||||
if vncSessionCount > 0 {
|
||||
sessionWord := "session"
|
||||
if vncSessionCount > 1 {
|
||||
sessionWord = "sessions"
|
||||
}
|
||||
vncServerStatus = fmt.Sprintf("Enabled (%d active %s)", vncSessionCount, sessionWord)
|
||||
} else {
|
||||
vncServerStatus = "Enabled"
|
||||
}
|
||||
|
||||
if showSSHSessions && vncSessionCount > 0 {
|
||||
for _, sess := range o.VNCServerState.Sessions {
|
||||
vncServerStatus += "\n " + formatVNCSessionLine(sess)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
peersCountString := fmt.Sprintf("%d/%d Connected", o.Peers.Connected, o.Peers.Total)
|
||||
|
||||
var forwardingRulesString string
|
||||
@@ -637,7 +581,6 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
||||
"Quantum resistance: %s\n"+
|
||||
"Lazy connection: %s\n"+
|
||||
"SSH Server: %s\n"+
|
||||
"VNC Server: %s\n"+
|
||||
"Networks: %s\n"+
|
||||
"%s"+
|
||||
"Peers count: %s\n",
|
||||
@@ -657,7 +600,6 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
|
||||
rosenpassEnabledStatus,
|
||||
lazyConnectionEnabledStatus,
|
||||
sshServerStatus,
|
||||
vncServerStatus,
|
||||
networks,
|
||||
forwardingRulesString,
|
||||
peersCountString,
|
||||
@@ -1017,26 +959,6 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *PeerStateDetailOutput) {
|
||||
}
|
||||
}
|
||||
|
||||
// formatVNCSessionLine renders a single VNC session row for the detailed
|
||||
// status output. The leading slot identifies the initiator (display name
|
||||
// when known, hashed UserID otherwise); the post-arrow slot is the OS
|
||||
// user the session targets and is omitted in attach mode where the
|
||||
// destination is the current console user (unknown to the daemon).
|
||||
func formatVNCSessionLine(sess VNCSessionOutput) string {
|
||||
who := sess.Initiator
|
||||
if who == "" {
|
||||
who = sess.UserID
|
||||
}
|
||||
prefix := sess.RemoteAddress
|
||||
if who != "" {
|
||||
prefix = fmt.Sprintf("%s@%s", who, sess.RemoteAddress)
|
||||
}
|
||||
if sess.Username != "" {
|
||||
return fmt.Sprintf("[%s -> %s] mode=%s", prefix, sess.Username, sess.Mode)
|
||||
}
|
||||
return fmt.Sprintf("[%s] mode=%s", prefix, sess.Mode)
|
||||
}
|
||||
|
||||
func anonymizeOverview(a *anonymize.Anonymizer, overview *OutputOverview) {
|
||||
for i, peer := range overview.Peers.Details {
|
||||
peer := peer
|
||||
@@ -1057,19 +979,6 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *OutputOverview) {
|
||||
overview.Relays.Details[i] = detail
|
||||
}
|
||||
|
||||
anonymizeNSServerGroups(a, overview)
|
||||
|
||||
for i, route := range overview.Networks {
|
||||
overview.Networks[i] = a.AnonymizeRoute(route)
|
||||
}
|
||||
|
||||
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
|
||||
|
||||
anonymizeEvents(a, overview)
|
||||
anonymizeServerSessions(a, overview)
|
||||
}
|
||||
|
||||
func anonymizeNSServerGroups(a *anonymize.Anonymizer, overview *OutputOverview) {
|
||||
for i, nsGroup := range overview.NSServerGroups {
|
||||
for j, domain := range nsGroup.Domains {
|
||||
overview.NSServerGroups[i].Domains[j] = a.AnonymizeDomain(domain)
|
||||
@@ -1081,9 +990,13 @@ func anonymizeNSServerGroups(a *anonymize.Anonymizer, overview *OutputOverview)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func anonymizeEvents(a *anonymize.Anonymizer, overview *OutputOverview) {
|
||||
for i, route := range overview.Networks {
|
||||
overview.Networks[i] = a.AnonymizeRoute(route)
|
||||
}
|
||||
|
||||
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
|
||||
|
||||
for i, event := range overview.Events {
|
||||
overview.Events[i].Message = a.AnonymizeString(event.Message)
|
||||
overview.Events[i].UserMessage = a.AnonymizeString(event.UserMessage)
|
||||
@@ -1092,24 +1005,13 @@ func anonymizeEvents(a *anonymize.Anonymizer, overview *OutputOverview) {
|
||||
event.Metadata[k] = a.AnonymizeString(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func anonymizeRemoteAddress(a *anonymize.Anonymizer, addr string) string {
|
||||
if host, port, err := net.SplitHostPort(addr); err == nil {
|
||||
return fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port)
|
||||
}
|
||||
return a.AnonymizeIPString(addr)
|
||||
}
|
||||
|
||||
func anonymizeServerSessions(a *anonymize.Anonymizer, overview *OutputOverview) {
|
||||
for i, session := range overview.SSHServerState.Sessions {
|
||||
overview.SSHServerState.Sessions[i].RemoteAddress = anonymizeRemoteAddress(a, session.RemoteAddress)
|
||||
if host, port, err := net.SplitHostPort(session.RemoteAddress); err == nil {
|
||||
overview.SSHServerState.Sessions[i].RemoteAddress = fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port)
|
||||
} else {
|
||||
overview.SSHServerState.Sessions[i].RemoteAddress = a.AnonymizeIPString(session.RemoteAddress)
|
||||
}
|
||||
overview.SSHServerState.Sessions[i].Command = a.AnonymizeString(session.Command)
|
||||
}
|
||||
for i, sess := range overview.VNCServerState.Sessions {
|
||||
overview.VNCServerState.Sessions[i].RemoteAddress = anonymizeRemoteAddress(a, sess.RemoteAddress)
|
||||
overview.VNCServerState.Sessions[i].Username = a.AnonymizeString(sess.Username)
|
||||
overview.VNCServerState.Sessions[i].UserID = a.AnonymizeString(sess.UserID)
|
||||
overview.VNCServerState.Sessions[i].Initiator = a.AnonymizeString(sess.Initiator)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -242,10 +242,6 @@ var overview = OutputOverview{
|
||||
Enabled: false,
|
||||
Sessions: []SSHSessionOutput{},
|
||||
},
|
||||
VNCServerState: VNCServerStateOutput{
|
||||
Enabled: false,
|
||||
Sessions: []VNCSessionOutput{},
|
||||
},
|
||||
}
|
||||
|
||||
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
|
||||
@@ -411,10 +407,6 @@ func TestParsingToJSON(t *testing.T) {
|
||||
"sshServer":{
|
||||
"enabled":false,
|
||||
"sessions":[]
|
||||
},
|
||||
"vncServer":{
|
||||
"enabled":false,
|
||||
"sessions":[]
|
||||
}
|
||||
}`
|
||||
// @formatter:on
|
||||
@@ -525,9 +517,6 @@ profileName: ""
|
||||
sshServer:
|
||||
enabled: false
|
||||
sessions: []
|
||||
vncServer:
|
||||
enabled: false
|
||||
sessions: []
|
||||
`
|
||||
|
||||
assert.Equal(t, expectedYAML, yaml)
|
||||
@@ -598,7 +587,6 @@ Wireguard port: %d
|
||||
Quantum resistance: false
|
||||
Lazy connection: false
|
||||
SSH Server: Disabled
|
||||
VNC Server: Disabled
|
||||
Networks: 10.10.0.0/24
|
||||
Peers count: 2/2 Connected
|
||||
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion, overview.WgPort)
|
||||
@@ -625,7 +613,6 @@ Wireguard port: 51820
|
||||
Quantum resistance: false
|
||||
Lazy connection: false
|
||||
SSH Server: Disabled
|
||||
VNC Server: Disabled
|
||||
Networks: 10.10.0.0/24
|
||||
Peers count: 2/2 Connected
|
||||
`
|
||||
|
||||
@@ -62,7 +62,6 @@ type Info struct {
|
||||
RosenpassEnabled bool
|
||||
RosenpassPermissive bool
|
||||
ServerSSHAllowed bool
|
||||
ServerVNCAllowed bool
|
||||
|
||||
DisableClientRoutes bool
|
||||
DisableServerRoutes bool
|
||||
@@ -84,7 +83,6 @@ type Info struct {
|
||||
func (i *Info) SetFlags(
|
||||
rosenpassEnabled, rosenpassPermissive bool,
|
||||
serverSSHAllowed *bool,
|
||||
serverVNCAllowed *bool,
|
||||
disableClientRoutes, disableServerRoutes,
|
||||
disableDNS, disableFirewall, blockLANAccess, blockInbound, disableIPv6, lazyConnectionEnabled bool,
|
||||
enableSSHRoot, enableSSHSFTP, enableSSHLocalPortForwarding, enableSSHRemotePortForwarding *bool,
|
||||
@@ -95,9 +93,6 @@ func (i *Info) SetFlags(
|
||||
if serverSSHAllowed != nil {
|
||||
i.ServerSSHAllowed = *serverSSHAllowed
|
||||
}
|
||||
if serverVNCAllowed != nil {
|
||||
i.ServerVNCAllowed = *serverVNCAllowed
|
||||
}
|
||||
|
||||
i.DisableClientRoutes = disableClientRoutes
|
||||
i.DisableServerRoutes = disableServerRoutes
|
||||
|
||||
@@ -1,259 +0,0 @@
|
||||
//go:build !(linux && 386)
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"fyne.io/fyne/v2"
|
||||
"fyne.io/fyne/v2/container"
|
||||
"fyne.io/fyne/v2/widget"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/approval"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
// Approval metadata that is remote-peer or dashboard controlled is passed to
|
||||
// the forked netbird-ui via environment variables rather than argv, so it is
|
||||
// not exposed to other local users through ps.
|
||||
const (
|
||||
envApprovalInitiator = "NB_APPROVAL_INITIATOR"
|
||||
envApprovalPeerName = "NB_APPROVAL_PEER_NAME"
|
||||
envApprovalSourceIP = "NB_APPROVAL_SOURCE_IP"
|
||||
envApprovalUsername = "NB_APPROVAL_USERNAME"
|
||||
envApprovalKeyFingerprint = "NB_APPROVAL_KEY_FINGERPRINT"
|
||||
envApprovalSubject = "NB_APPROVAL_SUBJECT"
|
||||
)
|
||||
|
||||
// handleApprovalEvent forks a netbird-ui child process to render the
|
||||
// dialog on its own fyne main loop. Top-level windows opened from a
|
||||
// background goroutine of the tray process don't render reliably on
|
||||
// Linux/GTK, so the rest of the UI (settings, login URL, update) uses
|
||||
// the same fork pattern.
|
||||
func (s *serviceClient) handleApprovalEvent(ev *proto.SystemEvent) {
|
||||
if ev == nil || ev.Category != proto.SystemEvent_APPROVAL {
|
||||
return
|
||||
}
|
||||
requestID := ev.Metadata["request_id"]
|
||||
if requestID == "" {
|
||||
log.Warnf("approval event missing request_id: %v", ev.Metadata)
|
||||
return
|
||||
}
|
||||
|
||||
// Only the request id, kind, and deadline stay on argv: they are
|
||||
// daemon-issued and non-sensitive. The remote-influenced fields go
|
||||
// through the child's environment.
|
||||
args := []string{
|
||||
"--approval-request-id=" + requestID,
|
||||
"--approval-kind=" + ev.Metadata["kind"],
|
||||
"--approval-expires-at=" + ev.Metadata["expires_at"],
|
||||
}
|
||||
env := append(os.Environ(),
|
||||
envApprovalInitiator+"="+ev.Metadata["initiator"],
|
||||
envApprovalPeerName+"="+ev.Metadata["peer_name"],
|
||||
envApprovalSourceIP+"="+ev.Metadata["source_ip"],
|
||||
envApprovalUsername+"="+ev.Metadata["username"],
|
||||
envApprovalKeyFingerprint+"="+ev.Metadata["peer_pubkey"],
|
||||
envApprovalSubject+"="+ev.UserMessage,
|
||||
)
|
||||
go s.runApprovalCommand(s.ctx, env, args)
|
||||
}
|
||||
|
||||
// runApprovalCommand forks netbird-ui to render the approval dialog,
|
||||
// inheriting the parent environment plus the approval-specific variables. It
|
||||
// mirrors runSelfCommand but sets cmd.Env so the sensitive metadata never
|
||||
// appears on the child's argv.
|
||||
func (s *serviceClient) runApprovalCommand(ctx context.Context, env, args []string) {
|
||||
proc, err := os.Executable()
|
||||
if err != nil {
|
||||
log.Errorf("get executable path: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
cmdArgs := append([]string{"--approval=true", "--daemon-addr=" + s.addr}, args...)
|
||||
cmd := exec.CommandContext(ctx, proc, cmdArgs...)
|
||||
cmd.Env = env
|
||||
|
||||
if out := s.attachOutput(cmd); out != nil {
|
||||
defer func() {
|
||||
if err := out.Close(); err != nil {
|
||||
log.Errorf("close log file %s: %v", s.logFile, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
log.Printf("running approval command: %s", cmd.String())
|
||||
if err := cmd.Run(); err != nil {
|
||||
var exitErr *exec.ExitError
|
||||
if errors.As(err, &exitErr) {
|
||||
log.Printf("approval command failed with exit code %d", exitErr.ExitCode())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// showApprovalUI runs the dialog on the forked process's fyne main loop
|
||||
// and forwards the user's decision to the daemon via RespondApproval.
|
||||
func (s *serviceClient) showApprovalUI(req approvalRequest) {
|
||||
w := s.app.NewWindow(approvalTitle(req.kind))
|
||||
w.Resize(fyne.NewSize(480, 260))
|
||||
w.CenterOnScreen()
|
||||
w.RequestFocus()
|
||||
|
||||
var rows []string
|
||||
if req.initiator != "" {
|
||||
// The display name comes from the management dashboard and is
|
||||
// not cryptographically asserted by the connecting client. The
|
||||
// key fingerprint that follows IS: it's the Noise_IK static
|
||||
// public key the client just proved possession of. Show both
|
||||
// so the user can sanity-check that "Alice" is really the
|
||||
// Alice they trust.
|
||||
rows = append(rows, "From user: "+req.initiator)
|
||||
}
|
||||
if fp := approval.ShortKeyFingerprint(req.keyFingerprint); fp != "" {
|
||||
rows = append(rows, "Key fp: "+fp)
|
||||
}
|
||||
if req.peerName != "" {
|
||||
rows = append(rows, "Via peer: "+req.peerName)
|
||||
}
|
||||
if req.sourceIP != "" && req.sourceIP != req.peerName {
|
||||
rows = append(rows, "Source IP: "+req.sourceIP)
|
||||
}
|
||||
if req.username != "" {
|
||||
rows = append(rows, "OS user: "+req.username)
|
||||
}
|
||||
if len(rows) == 0 {
|
||||
rows = []string{"Remote: " + req.displayPeer()}
|
||||
}
|
||||
body := strings.Join(rows, "\n")
|
||||
bodyLabel := widget.NewLabel(body)
|
||||
bodyLabel.Wrapping = fyne.TextWrapWord
|
||||
|
||||
countdown := widget.NewLabel("")
|
||||
deadline := req.deadline()
|
||||
updateCountdown := func() {
|
||||
remaining := time.Until(deadline).Round(time.Second)
|
||||
if remaining < 0 {
|
||||
remaining = 0
|
||||
}
|
||||
countdown.SetText(fmt.Sprintf("Auto-deny in %s", remaining))
|
||||
}
|
||||
updateCountdown()
|
||||
|
||||
type outcome struct {
|
||||
accept bool
|
||||
viewOnly bool
|
||||
}
|
||||
decided := make(chan outcome, 1)
|
||||
decide := func(o outcome) {
|
||||
select {
|
||||
case decided <- o:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
allow := widget.NewButton("Allow", func() { decide(outcome{accept: true}) })
|
||||
allow.Importance = widget.HighImportance
|
||||
allowView := widget.NewButton("Allow (view only)", func() { decide(outcome{accept: true, viewOnly: true}) })
|
||||
deny := widget.NewButton("Deny", func() { decide(outcome{accept: false}) })
|
||||
|
||||
header := widget.NewLabelWithStyle(req.subject, fyne.TextAlignLeading, fyne.TextStyle{Bold: true})
|
||||
buttonRow := container.NewGridWithColumns(3, allow, allowView, deny)
|
||||
info := container.NewVBox(header, widget.NewSeparator(), bodyLabel, widget.NewSeparator(), countdown)
|
||||
w.SetContent(container.NewPadded(container.NewBorder(nil, buttonRow, nil, nil, info)))
|
||||
w.SetCloseIntercept(func() { decide(outcome{}) })
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
if time.Until(deadline) <= 0 {
|
||||
decide(outcome{})
|
||||
return
|
||||
}
|
||||
fyne.Do(updateCountdown)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
o := <-decided
|
||||
s.sendApprovalResponse(req.requestID, o.accept, o.viewOnly)
|
||||
fyne.Do(func() {
|
||||
w.Close()
|
||||
s.app.Quit()
|
||||
})
|
||||
}()
|
||||
|
||||
w.Show()
|
||||
}
|
||||
|
||||
func (s *serviceClient) sendApprovalResponse(requestID string, accept, viewOnly bool) {
|
||||
conn, err := s.getSrvClient(defaultFailTimeout)
|
||||
if err != nil {
|
||||
log.Warnf("approval response: get daemon client: %v", err)
|
||||
return
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(s.ctx, defaultFailTimeout)
|
||||
defer cancel()
|
||||
if _, err := conn.RespondApproval(ctx, &proto.RespondApprovalRequest{
|
||||
RequestId: requestID,
|
||||
Accept: accept,
|
||||
ViewOnly: viewOnly,
|
||||
}); err != nil {
|
||||
log.Warnf("approval response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// approvalRequest is the parsed --approval-* CLI args that the forked
|
||||
// dialog process consumes.
|
||||
type approvalRequest struct {
|
||||
requestID string
|
||||
kind string
|
||||
initiator string
|
||||
peerName string
|
||||
sourceIP string
|
||||
username string
|
||||
subject string
|
||||
expiresAt string
|
||||
keyFingerprint string
|
||||
}
|
||||
|
||||
func (r approvalRequest) displayPeer() string {
|
||||
switch {
|
||||
case r.initiator != "":
|
||||
return r.initiator
|
||||
case r.peerName != "":
|
||||
return r.peerName
|
||||
case r.sourceIP != "":
|
||||
return r.sourceIP
|
||||
default:
|
||||
return "unknown peer"
|
||||
}
|
||||
}
|
||||
|
||||
// deadline returns the wall-clock auto-deny moment. Falls back to a short
|
||||
// local window when the daemon's expires_at is missing/unparsable, so a
|
||||
// stale value never leaves the dialog open indefinitely.
|
||||
func (r approvalRequest) deadline() time.Time {
|
||||
if t, err := time.Parse(time.RFC3339, r.expiresAt); err == nil {
|
||||
return t
|
||||
}
|
||||
return time.Now().Add(13 * time.Second)
|
||||
}
|
||||
|
||||
func approvalTitle(kind string) string {
|
||||
switch kind {
|
||||
case "vnc":
|
||||
return "Allow VNC Connection?"
|
||||
case "ssh":
|
||||
return "Allow SSH Connection?"
|
||||
default:
|
||||
return "Allow Incoming Connection?"
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user