mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-12 19:09:54 +00:00
Compare commits
3 Commits
feature/af
...
socket-grp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d65927275d | ||
|
|
064f7bf0fd | ||
|
|
644615fed6 |
@@ -3,14 +3,12 @@ package cmd
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os/user"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
@@ -87,73 +85,6 @@ var persistenceCmd = &cobra.Command{
|
||||
RunE: setSyncResponsePersistence,
|
||||
}
|
||||
|
||||
var debugConfigCmd = &cobra.Command{
|
||||
Use: "config",
|
||||
Example: " netbird debug config",
|
||||
Short: "Dump the effective configuration",
|
||||
Long: "Prints the daemon's resolved configuration (after applying defaults, file, env, CLI input, and MDM policy overrides) as JSON. Includes the list of MDM-managed fields.",
|
||||
RunE: debugConfigDump,
|
||||
}
|
||||
|
||||
// debugConfigDump implements `netbird debug config`. It resolves the
|
||||
// active profile, queries the daemon for the effective configuration
|
||||
// via GetConfig, and prints the resulting GetConfigResponse as JSON
|
||||
// (via protojson with EmitUnpopulated=true so the output is stable
|
||||
// across runs and includes zero-valued fields).
|
||||
//
|
||||
// Useful for verifying MDM enforcement end-to-end: the response's
|
||||
// mDMManagedFields array is the single source of truth for "which
|
||||
// fields is the daemon currently enforcing from the MDM source", and
|
||||
// every config field side-by-side with that list confirms the merge
|
||||
// result. Secrets in the response (e.g. PreSharedKey) are already
|
||||
// redacted by the daemon-side handler.
|
||||
func debugConfigDump(cmd *cobra.Command, _ []string) error {
|
||||
pm := profilemanager.NewProfileManager()
|
||||
activeProf, err := pm.GetActiveProfile()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get active profile: %v", err)
|
||||
}
|
||||
currUser, err := user.Current()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get current user: %v", err)
|
||||
}
|
||||
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf(errCloseConnection, err)
|
||||
}
|
||||
}()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
resp, err := client.GetConfig(cmd.Context(), &proto.GetConfigRequest{
|
||||
ProfileName: activeProf.Name,
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get config: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
// Use protojson so well-known fields render correctly; emit defaults so
|
||||
// the operator sees every field even when zero/empty.
|
||||
m := protojson.MarshalOptions{Multiline: true, Indent: " ", EmitUnpopulated: true}
|
||||
out, err := m.Marshal(resp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal config: %w", err)
|
||||
}
|
||||
cmd.Println(string(out))
|
||||
return nil
|
||||
}
|
||||
|
||||
// debugBundle requests the daemon to create a debug bundle and prints
|
||||
// the resulting local file path and, if uploaded, the uploaded file
|
||||
// key. It uses the package flags (anonymize, system info, log file
|
||||
// count, CLI version, optional upload URL) to configure the bundle
|
||||
// request. Returns an error if the RPC fails or if the daemon reports
|
||||
// an upload failure reason.
|
||||
func debugBundle(cmd *cobra.Command, _ []string) error {
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,301 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
KubernetesDNSSuffix = "netbird-kubeapi-proxy"
|
||||
)
|
||||
|
||||
var kubernetesCmd = &cobra.Command{
|
||||
Use: "kubernetes",
|
||||
Short: "Kubernetes cluster commands.",
|
||||
Long: "Kubernetes cluster commands.",
|
||||
}
|
||||
|
||||
var kubernetesListCmd = &cobra.Command{
|
||||
Use: "list",
|
||||
RunE: kubernetesList,
|
||||
Short: "List Kubernetes clusters.",
|
||||
Long: "List Kubernetes clusters by discovering NetBird peers running netbird-kubeapi-proxy.",
|
||||
}
|
||||
|
||||
var kubernetesWriteKubeconfigCmd = &cobra.Command{
|
||||
Use: "write-kubeconfig",
|
||||
RunE: kubernetesWriteKubeconfig,
|
||||
Args: cobra.ExactArgs(1),
|
||||
Short: "Write kubeconfig for a Kubernetes cluster.",
|
||||
Long: "Updates kubeconfig in place to allow token-less access to the Kubernetes cluster through NetBird.",
|
||||
}
|
||||
|
||||
func init() {
|
||||
kubernetesWriteKubeconfigCmd.Flags().String("kubeconfig", "", "path to kubeconfig file")
|
||||
}
|
||||
|
||||
func kubernetesList(cmd *cobra.Command, _ []string) error {
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
statusResp, err := client.Status(cmd.Context(), &proto.StatusRequest{GetFullPeerStatus: true})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
kcs, err := getKubernetesClusters(cmd.Context(), statusResp.FullStatus.Peers, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(kcs) == 0 {
|
||||
cmd.Println("No Kubernetes clusters available.")
|
||||
return nil
|
||||
}
|
||||
cmd.Println("Available Kubernetes clusters:")
|
||||
for _, k := range kcs {
|
||||
cmd.Printf("\n - Name: %s\n FQDN: %s\n Version: %s\n", k.name, k.url.Host, k.version)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func kubernetesWriteKubeconfig(cmd *cobra.Command, args []string) error {
|
||||
kubeconfigPath, err := resolveKubeconfigPath(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
statusResp, err := client.Status(cmd.Context(), &proto.StatusRequest{GetFullPeerStatus: true})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
clusterName := args[0]
|
||||
kcs, err := getKubernetesClusters(cmd.Context(), statusResp.FullStatus.Peers, clusterName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(kcs) == 0 {
|
||||
return fmt.Errorf("kubernetes cluster named %s not found", clusterName)
|
||||
}
|
||||
if len(kcs) > 1 {
|
||||
return fmt.Errorf("too many Kubernetes clusters returned")
|
||||
}
|
||||
err = writeKubeconfig(kubeconfigPath, kcs[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type kubernetesCluster struct {
|
||||
name string
|
||||
url *url.URL
|
||||
version string
|
||||
}
|
||||
|
||||
func getKubernetesClusters(ctx context.Context, peers []*proto.PeerState, nameFilter string) ([]kubernetesCluster, error) {
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
transport.TLSClientConfig = &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
httpClient := &http.Client{
|
||||
Transport: transport,
|
||||
}
|
||||
resolver := net.Resolver{
|
||||
// Required so both DNS records are returned.
|
||||
// https://github.com/golang/go/issues/17093
|
||||
PreferGo: true,
|
||||
}
|
||||
|
||||
kcs := []kubernetesCluster{}
|
||||
attempted := map[string]struct{}{}
|
||||
for _, peer := range peers {
|
||||
fqdns, err := resolver.LookupAddr(ctx, peer.IP)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, fqdn := range fqdns {
|
||||
if _, ok := attempted[fqdn]; ok {
|
||||
continue
|
||||
}
|
||||
attempted[fqdn] = struct{}{}
|
||||
comps := strings.Split(fqdn, ".")
|
||||
if len(comps) < 2 {
|
||||
continue
|
||||
}
|
||||
if comps[1] != KubernetesDNSSuffix {
|
||||
continue
|
||||
}
|
||||
if nameFilter != "" && nameFilter != comps[0] {
|
||||
continue
|
||||
}
|
||||
clusterURL, clusterVersion, err := fingerprintClusters(ctx, httpClient, fqdn)
|
||||
if err != nil {
|
||||
log.Debugf("could not fingerprint Kubernetes cluster %s %q", fqdn, err)
|
||||
continue
|
||||
}
|
||||
kc := kubernetesCluster{
|
||||
name: comps[0],
|
||||
url: clusterURL,
|
||||
version: clusterVersion,
|
||||
}
|
||||
if nameFilter != "" {
|
||||
return []kubernetesCluster{kc}, nil
|
||||
}
|
||||
kcs = append(kcs, kc)
|
||||
}
|
||||
}
|
||||
return kcs, nil
|
||||
}
|
||||
|
||||
func fingerprintClusters(ctx context.Context, httpClient *http.Client, fqdn string) (*url.URL, string, error) {
|
||||
clusterURL, err := url.Parse("https://" + fqdn)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
versionURL, err := clusterURL.Parse("/version")
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, versionURL.String(), nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, "", fmt.Errorf("expected %d response but got %s", http.StatusOK, resp.Status)
|
||||
}
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
versionData := map[string]string{}
|
||||
err = json.Unmarshal(b, &versionData)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
version, ok := versionData["gitVersion"]
|
||||
if !ok {
|
||||
return nil, "", errors.New("no version found in response")
|
||||
}
|
||||
return clusterURL, version, nil
|
||||
}
|
||||
|
||||
func resolveKubeconfigPath(cmd *cobra.Command) (string, error) {
|
||||
if cmd.Flags().Changed("kubeconfig") {
|
||||
path, err := cmd.Flags().GetString("kubeconfig")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return path, nil
|
||||
}
|
||||
if env := os.Getenv("KUBECONFIG"); env != "" {
|
||||
return env, nil
|
||||
}
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("could not determine home directory: %w", err)
|
||||
}
|
||||
return filepath.Join(home, ".kube", "config"), nil
|
||||
}
|
||||
|
||||
func writeKubeconfig(kubeconfigPath string, kc kubernetesCluster) error {
|
||||
b, err := os.ReadFile(kubeconfigPath)
|
||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return err
|
||||
}
|
||||
var cfg map[string]any
|
||||
if err := yaml.Unmarshal(b, &cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
if cfg == nil {
|
||||
cfg = map[string]any{
|
||||
"apiVersion": "v1",
|
||||
"kind": "Config",
|
||||
}
|
||||
}
|
||||
|
||||
cfg["clusters"] = appendWithName(cfg["clusters"], map[string]any{
|
||||
"name": kc.name,
|
||||
"cluster": map[string]any{
|
||||
"server": kc.url.String(),
|
||||
"insecure-skip-tls-verify": true,
|
||||
},
|
||||
})
|
||||
cfg["users"] = appendWithName(cfg["users"], map[string]any{
|
||||
"name": "netbird",
|
||||
"user": map[string]any{
|
||||
"token": "none",
|
||||
},
|
||||
})
|
||||
cfg["contexts"] = appendWithName(cfg["contexts"], map[string]any{
|
||||
"name": kc.name,
|
||||
"context": map[string]any{
|
||||
"cluster": kc.name,
|
||||
"user": "netbird",
|
||||
"namespace": "default",
|
||||
},
|
||||
})
|
||||
cfg["current-context"] = kc.name
|
||||
|
||||
out, err := yaml.Marshal(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.WriteFile(kubeconfigPath, out, 0o600); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func appendWithName(data any, add map[string]any) any {
|
||||
if data == nil {
|
||||
return []any{add}
|
||||
}
|
||||
v, ok := data.([]any)
|
||||
if !ok {
|
||||
return []any{add}
|
||||
}
|
||||
i := slices.IndexFunc(v, func(item any) bool {
|
||||
m, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return m["name"] == add["name"]
|
||||
})
|
||||
if i == -1 {
|
||||
return append(v, add)
|
||||
}
|
||||
v[i] = add
|
||||
return v
|
||||
}
|
||||
@@ -1,120 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFingerprintClusters(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
//nolint: errcheck
|
||||
w.Write([]byte(`{"gitVersion": "foobar"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
clusterURL, clusterVersion, err := fingerprintClusters(t.Context(), srv.Client(), srv.Listener.Addr().String())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, srv.URL, clusterURL.String())
|
||||
require.Equal(t, "foobar", clusterVersion)
|
||||
}
|
||||
|
||||
func TestResolveKubeconfigPath(t *testing.T) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
t.Fatalf("could not determine home directory: %v", err)
|
||||
}
|
||||
defaultPath := filepath.Join(home, ".kube", "config")
|
||||
path, err := resolveKubeconfigPath(&cobra.Command{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, defaultPath, path)
|
||||
|
||||
flagPath := "flag-path"
|
||||
cmd := &cobra.Command{}
|
||||
cmd.Flags().String("kubeconfig", "", "")
|
||||
err = cmd.Flags().Set("kubeconfig", flagPath)
|
||||
require.NoError(t, err)
|
||||
path, err = resolveKubeconfigPath(cmd)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, flagPath, path)
|
||||
|
||||
envPath := "env-path"
|
||||
t.Setenv("KUBECONFIG", envPath)
|
||||
path, err = resolveKubeconfigPath(&cobra.Command{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, envPath, path)
|
||||
}
|
||||
|
||||
func TestWriteKubeconfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
existing string
|
||||
}{
|
||||
{
|
||||
name: "empty file",
|
||||
},
|
||||
{
|
||||
name: "existing content",
|
||||
existing: `apiVersion: v1
|
||||
clusters:
|
||||
- cluster:
|
||||
insecure-skip-tls-verify: true
|
||||
server: https://foobar.com
|
||||
name: foo
|
||||
current-context: test
|
||||
kind: Config
|
||||
users: []
|
||||
`,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
kubeconfigPath := filepath.Join(t.TempDir(), "config")
|
||||
err := os.WriteFile(kubeconfigPath, []byte(tt.existing), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
kc := kubernetesCluster{
|
||||
name: "foo",
|
||||
url: &url.URL{Scheme: "https", Host: "example.com"},
|
||||
}
|
||||
err = writeKubeconfig(kubeconfigPath, kc)
|
||||
require.NoError(t, err)
|
||||
|
||||
b, err := os.ReadFile(kubeconfigPath)
|
||||
require.NoError(t, err)
|
||||
expected := `apiVersion: v1
|
||||
clusters:
|
||||
- cluster:
|
||||
insecure-skip-tls-verify: true
|
||||
server: https://example.com
|
||||
name: foo
|
||||
contexts:
|
||||
- context:
|
||||
cluster: foo
|
||||
namespace: default
|
||||
user: netbird
|
||||
name: foo
|
||||
current-context: foo
|
||||
kind: Config
|
||||
users:
|
||||
- name: netbird
|
||||
user:
|
||||
token: none
|
||||
`
|
||||
require.Equal(t, expected, string(b))
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
36
client/cmd/peercred_bsd.go
Normal file
36
client/cmd/peercred_bsd.go
Normal file
@@ -0,0 +1,36 @@
|
||||
//go:build darwin || freebsd
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// peerUID returns the uid of the process on the other end of a unix socket
|
||||
// connection, read via LOCAL_PEERCRED (xucred). Note: xucred carries the uid
|
||||
// and group list but no pid, so audit on these platforms is uid-based.
|
||||
func peerUID(c net.Conn) (int, error) {
|
||||
uc, ok := c.(*net.UnixConn)
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("connection is not a unix socket: %T", c)
|
||||
}
|
||||
raw, err := uc.SyscallConn()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("raw conn: %w", err)
|
||||
}
|
||||
|
||||
var cred *unix.Xucred
|
||||
var credErr error
|
||||
if err := raw.Control(func(fd uintptr) {
|
||||
cred, credErr = unix.GetsockoptXucred(int(fd), unix.SOL_LOCAL, unix.LOCAL_PEERCRED)
|
||||
}); err != nil {
|
||||
return 0, fmt.Errorf("getsockopt control: %w", err)
|
||||
}
|
||||
if credErr != nil {
|
||||
return 0, fmt.Errorf("LOCAL_PEERCRED: %w", credErr)
|
||||
}
|
||||
return int(cred.Uid), nil
|
||||
}
|
||||
35
client/cmd/peercred_linux.go
Normal file
35
client/cmd/peercred_linux.go
Normal file
@@ -0,0 +1,35 @@
|
||||
//go:build linux
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// peerUID returns the uid of the process on the other end of a unix socket
|
||||
// connection, read from the kernel via SO_PEERCRED.
|
||||
func peerUID(c net.Conn) (int, error) {
|
||||
uc, ok := c.(*net.UnixConn)
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("connection is not a unix socket: %T", c)
|
||||
}
|
||||
raw, err := uc.SyscallConn()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("raw conn: %w", err)
|
||||
}
|
||||
|
||||
var cred *unix.Ucred
|
||||
var credErr error
|
||||
if err := raw.Control(func(fd uintptr) {
|
||||
cred, credErr = unix.GetsockoptUcred(int(fd), unix.SOL_SOCKET, unix.SO_PEERCRED)
|
||||
}); err != nil {
|
||||
return 0, fmt.Errorf("getsockopt control: %w", err)
|
||||
}
|
||||
if credErr != nil {
|
||||
return 0, fmt.Errorf("SO_PEERCRED: %w", credErr)
|
||||
}
|
||||
return int(cred.Uid), nil
|
||||
}
|
||||
16
client/cmd/peercred_unsupported.go
Normal file
16
client/cmd/peercred_unsupported.go
Normal file
@@ -0,0 +1,16 @@
|
||||
//go:build !linux && !darwin && !freebsd
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// peerUID is unimplemented on this platform, so the trust-on-first-use socket
|
||||
// migration cannot run here. Configure --socket-owner explicitly, or use
|
||||
// --disable-strict-socket. (Windows uses a TCP socket and never reaches this.)
|
||||
func peerUID(net.Conn) (int, error) {
|
||||
return 0, fmt.Errorf("peer credential check not supported on %s", runtime.GOOS)
|
||||
}
|
||||
@@ -77,6 +77,8 @@ var (
|
||||
updateSettingsDisabled bool
|
||||
captureEnabled bool
|
||||
networksDisabled bool
|
||||
socketOwner string
|
||||
strictSocketDisabled bool
|
||||
|
||||
rootCmd = &cobra.Command{
|
||||
Use: "netbird",
|
||||
@@ -95,9 +97,7 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
// Execute runs the appropriate Cobra command for the CLI.
|
||||
// If the process is the update binary it delegates to updateCmd; otherwise it runs the root command.
|
||||
// It returns any error produced during command execution.
|
||||
// Execute executes the root command.
|
||||
func Execute() error {
|
||||
if isUpdateBinary() {
|
||||
return updateCmd.Execute()
|
||||
@@ -105,16 +105,6 @@ func Execute() error {
|
||||
return rootCmd.Execute()
|
||||
}
|
||||
|
||||
// init initialises package-level defaults and configures the root
|
||||
// Cobra command tree. Sets platform-specific config / log directory
|
||||
// paths (including legacy Wiretrustee fallbacks) and a default daemon
|
||||
// address; registers persistent CLI flags (daemon address,
|
||||
// management / admin URLs, logging, setup key (file and inline,
|
||||
// mutually exclusive), preshared key, hostname, anonymise, config
|
||||
// path); attaches top-level and nested subcommands to the root
|
||||
// command; and registers `up`-specific persistent flags (external IP
|
||||
// maps, custom DNS resolver address, Rosenpass options, auto-connect
|
||||
// disabling, lazy connection).
|
||||
func init() {
|
||||
defaultConfigPathDir = "/etc/netbird/"
|
||||
defaultLogFileDir = "/var/log/netbird/"
|
||||
@@ -180,12 +170,6 @@ func init() {
|
||||
logCmd.AddCommand(logLevelCmd)
|
||||
debugCmd.AddCommand(forCmd)
|
||||
debugCmd.AddCommand(persistenceCmd)
|
||||
debugCmd.AddCommand(debugConfigCmd)
|
||||
|
||||
// kubernetes commands
|
||||
rootCmd.AddCommand(kubernetesCmd)
|
||||
kubernetesCmd.AddCommand(kubernetesListCmd)
|
||||
kubernetesCmd.AddCommand(kubernetesWriteKubeconfigCmd)
|
||||
|
||||
// profile commands
|
||||
profileCmd.AddCommand(profileListCmd)
|
||||
|
||||
@@ -57,6 +57,9 @@ func init() {
|
||||
installCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)
|
||||
reconfigureCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)
|
||||
|
||||
serviceCmd.PersistentFlags().StringVar(&socketOwner, "socket-owner", "", "user to own the daemon control socket; restricts it to that user plus the netbird group (0660). If unset, the first client to connect claims ownership (trust-on-first-use)")
|
||||
serviceCmd.PersistentFlags().BoolVar(&strictSocketDisabled, "disable-strict-socket", false, "leave the daemon control socket world-writable (0666) instead of restricting it; set via the (root-only) service command")
|
||||
|
||||
rootCmd.AddCommand(serviceCmd)
|
||||
}
|
||||
|
||||
|
||||
@@ -4,10 +4,15 @@ package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
@@ -16,6 +21,7 @@ import (
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/shell"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/server"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
@@ -54,10 +60,36 @@ func (p *program) Start(svc service.Service) error {
|
||||
go func() {
|
||||
defer listen.Close()
|
||||
|
||||
srvListener := listen
|
||||
if split[0] == "unix" {
|
||||
if err := os.Chmod(split[1], 0666); err != nil {
|
||||
log.Errorf("failed setting daemon permissions: %v", split[1])
|
||||
return
|
||||
owner := effectiveSocketOwner()
|
||||
switch {
|
||||
case strictSocketDisabled:
|
||||
// Opt-out (root-only, via service.json): leave it world-writable.
|
||||
if err := os.Chmod(split[1], 0666); err != nil {
|
||||
log.Errorf("failed setting daemon permissions: %v", split[1])
|
||||
return
|
||||
}
|
||||
case owner != "":
|
||||
// Seeded owner (flag, MDM, or persisted TOFU result): restrict
|
||||
// before serving so there is no open window.
|
||||
uid, err := lookupUser(owner)
|
||||
if err != nil {
|
||||
log.Errorf("lookup socket owner %q: %v", owner, err)
|
||||
return
|
||||
}
|
||||
if err := restrictSocket(split[1], uid); err != nil {
|
||||
log.Errorf("restrict socket to %q: %v", owner, err)
|
||||
return
|
||||
}
|
||||
default:
|
||||
// Trust-on-first-use: open the socket now; tofuListener locks it
|
||||
// to the first caller's uid on the first connection.
|
||||
if err := os.Chmod(split[1], 0666); err != nil {
|
||||
log.Errorf("failed setting daemon permissions: %v", split[1])
|
||||
return
|
||||
}
|
||||
srvListener = &tofuListener{Listener: listen, path: split[1], owner: -1}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,13 +104,180 @@ func (p *program) Start(svc service.Service) error {
|
||||
p.serverInstanceMu.Unlock()
|
||||
|
||||
log.Printf("started daemon server: %v", split[1])
|
||||
if err := p.serv.Serve(listen); err != nil {
|
||||
if err := p.serv.Serve(srvListener); err != nil {
|
||||
log.Errorf("failed to serve daemon requests: %v", err)
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func lookupUser(username string) (int, error) {
|
||||
u, err := shell.LookupWithGetent(username)
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("lookup user %s: %w", username, err)
|
||||
}
|
||||
uid, err := strconv.Atoi(u.Uid)
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("parse uid %s: %w", u.Uid, err)
|
||||
}
|
||||
return uid, nil
|
||||
}
|
||||
|
||||
// addGroup creates a system group if it doesn't already exist and returns the gid.
|
||||
// Must run as root.
|
||||
func addGroup(name string) (int, error) {
|
||||
group, err := shell.LookupGroupWithGetent(name)
|
||||
if err == nil {
|
||||
gid, err := strconv.ParseInt(group.Gid, 10, 64)
|
||||
return int(gid), err
|
||||
}
|
||||
|
||||
// looup failed, create the group
|
||||
groupadd, err := exec.LookPath("groupadd")
|
||||
if err != nil {
|
||||
// Fallback for Alpine/BusyBox systems.
|
||||
if groupadd, err = exec.LookPath("addgroup"); err != nil {
|
||||
return -1, errors.New("neither groupadd nor addgroup found")
|
||||
}
|
||||
}
|
||||
|
||||
// Use --system for a service/daemon group (no login, low GID).
|
||||
out, err := exec.Command(groupadd, "--system", name).CombinedOutput()
|
||||
if err != nil {
|
||||
return -1, fmt.Errorf("create group %q: %w: %s", name, err, out)
|
||||
}
|
||||
if group, err := shell.LookupWithGetent(name); err == nil {
|
||||
gid, err := strconv.ParseInt(group.Gid, 10, 64)
|
||||
return int(gid), err
|
||||
}
|
||||
return -1, fmt.Errorf("lookup group %q: %w", name, err)
|
||||
}
|
||||
|
||||
// restrictSocket locks the unix socket down to the owner uid plus the netbird
|
||||
// group (0660). If the group cannot be created or applied, it fails closed to
|
||||
// owner-only 0600 — it never leaves the socket world-writable.
|
||||
func restrictSocket(path string, uid int) error {
|
||||
// TODO: introduce flag to configure this (LDAP/AD usecase)
|
||||
gid, err := addGroup("netbird")
|
||||
if err != nil {
|
||||
log.Errorf("create netbird group, failing closed to owner-only 0600: %v", err)
|
||||
return chownChmod(path, uid, -1, 0600)
|
||||
}
|
||||
if err := chownChmod(path, uid, gid, 0660); err != nil {
|
||||
log.Errorf("apply netbird group to socket, failing closed to owner-only 0600: %v", err)
|
||||
return chownChmod(path, uid, -1, 0600)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// chownChmod sets ownership and mode on the socket. A gid of -1 leaves the
|
||||
// group unchanged.
|
||||
func chownChmod(path string, uid, gid int, mode os.FileMode) error {
|
||||
if err := os.Chown(path, uid, gid); err != nil {
|
||||
return fmt.Errorf("chown socket %s: %w", path, err)
|
||||
}
|
||||
if err := os.Chmod(path, mode); err != nil {
|
||||
return fmt.Errorf("chmod socket %s: %w", path, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// tofuListener implements trust-on-first-use for the daemon control socket.
|
||||
// The socket starts world-writable; the first caller's uid (read via the
|
||||
// platform peer-credential mechanism) becomes the owner. On that first
|
||||
// connection the socket is restricted (see restrictSocket) and the owner is
|
||||
// persisted so the open window never reopens on later starts. Connections that
|
||||
// raced in during the open window and are neither the owner nor root are
|
||||
// dropped. Changing the socket mode does not disturb the already-open
|
||||
// connection, so the first caller's request is served normally.
|
||||
type tofuListener struct {
|
||||
net.Listener
|
||||
path string
|
||||
mu sync.Mutex
|
||||
owner int // -1 until claimed
|
||||
}
|
||||
|
||||
func (l *tofuListener) Accept() (net.Conn, error) {
|
||||
for {
|
||||
c, err := l.Listener.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
uid, err := peerUID(c)
|
||||
if err != nil {
|
||||
log.Errorf("read peer credentials, dropping connection: %v", err)
|
||||
_ = c.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
l.mu.Lock()
|
||||
if l.owner == -1 {
|
||||
if err := restrictSocket(l.path, uid); err != nil {
|
||||
l.mu.Unlock()
|
||||
_ = c.Close()
|
||||
// Refuse to serve on a socket we could not lock down.
|
||||
return nil, fmt.Errorf("restrict socket on first connection: %w", err)
|
||||
}
|
||||
l.owner = uid
|
||||
persistSocketOwner(uid)
|
||||
log.Infof("control socket restricted to first caller (uid %d)", uid)
|
||||
l.mu.Unlock()
|
||||
return c, nil
|
||||
}
|
||||
owner := l.owner
|
||||
l.mu.Unlock()
|
||||
|
||||
// New connects are already gated by the 0660 perms set above; this only
|
||||
// drops anything that slipped in during the brief open window.
|
||||
if uid != owner && uid != 0 {
|
||||
log.Warnf("dropping non-owner connection (uid %d) during socket bootstrap", uid)
|
||||
_ = c.Close()
|
||||
continue
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
}
|
||||
|
||||
// effectiveSocketOwner returns the configured socket owner: the --socket-owner
|
||||
// flag when set, otherwise the owner persisted by a previous TOFU migration.
|
||||
func effectiveSocketOwner() string {
|
||||
if socketOwner != "" {
|
||||
return socketOwner
|
||||
}
|
||||
params, err := loadServiceParams()
|
||||
if err != nil {
|
||||
log.Errorf("load service params for socket owner: %v", err)
|
||||
return ""
|
||||
}
|
||||
if params != nil {
|
||||
return params.SocketOwner
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// persistSocketOwner records the TOFU-selected owner (by username) so the next
|
||||
// daemon start restricts the socket immediately, with no open window.
|
||||
func persistSocketOwner(uid int) {
|
||||
u, err := user.LookupId(strconv.Itoa(uid))
|
||||
if err != nil {
|
||||
log.Errorf("resolve uid %d to username for persistence: %v", uid, err)
|
||||
return
|
||||
}
|
||||
params, err := loadServiceParams()
|
||||
if err != nil {
|
||||
log.Errorf("load service params to persist socket owner: %v", err)
|
||||
return
|
||||
}
|
||||
if params == nil {
|
||||
params = currentServiceParams()
|
||||
}
|
||||
params.SocketOwner = u.Username
|
||||
if err := saveServiceParams(params); err != nil {
|
||||
log.Errorf("persist socket owner: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *program) Stop(srv service.Service) error {
|
||||
p.serverInstanceMu.Lock()
|
||||
if p.serverInstance != nil {
|
||||
|
||||
@@ -67,6 +67,14 @@ func buildServiceArguments() []string {
|
||||
args = append(args, "--disable-networks")
|
||||
}
|
||||
|
||||
if socketOwner != "" {
|
||||
args = append(args, "--socket-owner", socketOwner)
|
||||
}
|
||||
|
||||
if strictSocketDisabled {
|
||||
args = append(args, "--disable-strict-socket")
|
||||
}
|
||||
|
||||
return args
|
||||
}
|
||||
|
||||
@@ -127,6 +135,8 @@ var installCmd = &cobra.Command{
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Printf("SUDO_UID: %s\n", os.Getenv("SUDO_UID"))
|
||||
|
||||
if err := loadAndApplyServiceParams(cmd); err != nil {
|
||||
cmd.PrintErrf("Warning: failed to load saved service params: %v\n", err)
|
||||
}
|
||||
|
||||
@@ -30,6 +30,8 @@ type serviceParams struct {
|
||||
DisableUpdateSettings bool `json:"disable_update_settings,omitempty"`
|
||||
EnableCapture bool `json:"enable_capture,omitempty"`
|
||||
DisableNetworks bool `json:"disable_networks,omitempty"`
|
||||
SocketOwner string `json:"socket_owner,omitempty"`
|
||||
DisableStrictSocket bool `json:"disable_strict_socket,omitempty"`
|
||||
ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"`
|
||||
}
|
||||
|
||||
@@ -82,6 +84,8 @@ func currentServiceParams() *serviceParams {
|
||||
DisableUpdateSettings: updateSettingsDisabled,
|
||||
EnableCapture: captureEnabled,
|
||||
DisableNetworks: networksDisabled,
|
||||
SocketOwner: socketOwner,
|
||||
DisableStrictSocket: strictSocketDisabled,
|
||||
}
|
||||
|
||||
if len(serviceEnvVars) > 0 {
|
||||
@@ -154,6 +158,14 @@ func applyServiceParams(cmd *cobra.Command, params *serviceParams) {
|
||||
networksDisabled = params.DisableNetworks
|
||||
}
|
||||
|
||||
if !serviceCmd.PersistentFlags().Changed("socket-owner") {
|
||||
socketOwner = params.SocketOwner
|
||||
}
|
||||
|
||||
if !serviceCmd.PersistentFlags().Changed("disable-strict-socket") {
|
||||
strictSocketDisabled = params.DisableStrictSocket
|
||||
}
|
||||
|
||||
applyServiceEnvParams(cmd, params)
|
||||
}
|
||||
|
||||
|
||||
@@ -279,10 +279,6 @@ func (c *Client) Start(startCtx context.Context) error {
|
||||
|
||||
select {
|
||||
case <-startCtx.Done():
|
||||
// Cancel the client context before stopping: Engine.Start blocks on the
|
||||
// signal stream while holding the engine mutex and only unblocks on
|
||||
// cancellation. Stopping first would deadlock on that mutex.
|
||||
cancel()
|
||||
if stopErr := client.Stop(); stopErr != nil {
|
||||
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
|
||||
}
|
||||
@@ -446,8 +442,8 @@ func (c *Client) Expose(ctx context.Context, req ExposeRequest) (*ExposeSession,
|
||||
|
||||
// IdentityForIP looks up a remote peer by its tunnel IP using the
|
||||
// embedded client's status recorder. Returns the peer's WireGuard public
|
||||
// key and FQDN. ok=false means the IP doesn't belong to an active peer
|
||||
// — offline roster peers are treated as unknown, same as foreign IPs.
|
||||
// key and FQDN. ok=false means the IP isn't in this client's peer
|
||||
// roster — callers should treat that as "unknown peer".
|
||||
func (c *Client) IdentityForIP(ip netip.Addr) (pubKey, fqdn string, ok bool) {
|
||||
if !ip.IsValid() || c.recorder == nil {
|
||||
return "", "", false
|
||||
|
||||
@@ -1,168 +0,0 @@
|
||||
package embed
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
mgmt "github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
const testSetupKey = "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
|
||||
|
||||
// TestClientStartTimeoutRollback reproduces a deadlock between Engine.Start and
|
||||
// Engine.Stop. The signal endpoint accepts gRPC connections but never serves the
|
||||
// SignalExchange service, so Engine.Start parks in WaitStreamConnected while
|
||||
// holding the engine mutex. When the Start context expires, the rollback path
|
||||
// calls ConnectClient.Stop, which must not block forever acquiring that mutex.
|
||||
func TestClientStartTimeoutRollback(t *testing.T) {
|
||||
signalAddr := startBlackholeSignal(t)
|
||||
mgmAddr := startManagement(t, signalAddr)
|
||||
|
||||
wgPort := 0
|
||||
client, err := New(Options{
|
||||
DeviceName: "embed-rollback-test",
|
||||
SetupKey: testSetupKey,
|
||||
ManagementURL: "http://" + mgmAddr,
|
||||
WireguardPort: &wgPort,
|
||||
})
|
||||
require.NoError(t, err, "embed client creation must succeed")
|
||||
|
||||
startCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
startErr := make(chan error, 1)
|
||||
go func() {
|
||||
startErr <- client.Start(startCtx)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-startErr:
|
||||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
case <-time.After(60 * time.Second):
|
||||
t.Fatal("client.Start did not return after its context expired: Engine.Stop deadlocked against Engine.Start waiting for the signal stream")
|
||||
}
|
||||
}
|
||||
|
||||
// startBlackholeSignal starts a gRPC server without the SignalExchange service
|
||||
// registered. Connections succeed, but the signal stream can never be
|
||||
// established, which keeps Engine.Start parked in WaitStreamConnected.
|
||||
func startBlackholeSignal(t *testing.T) string {
|
||||
t.Helper()
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
s := grpc.NewServer()
|
||||
go func() {
|
||||
if err := s.Serve(lis); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
t.Cleanup(s.Stop)
|
||||
|
||||
return lis.Addr().String()
|
||||
}
|
||||
|
||||
func startManagement(t *testing.T, signalAddr string) string {
|
||||
t.Helper()
|
||||
|
||||
cfg := &config.Config{
|
||||
Stuns: []*config.Host{},
|
||||
TURNConfig: &config.TURNConfig{},
|
||||
Relay: &config.Relay{
|
||||
Addresses: []string{"127.0.0.1:1234"},
|
||||
CredentialsTTL: util.Duration{Duration: time.Hour},
|
||||
Secret: "222222222222222222",
|
||||
},
|
||||
Signal: &config.Host{
|
||||
Proto: "http",
|
||||
URI: signalAddr,
|
||||
},
|
||||
Datadir: t.TempDir(),
|
||||
HttpConfig: nil,
|
||||
}
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
s := grpc.NewServer()
|
||||
|
||||
testStore, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", cfg.Datadir)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
|
||||
permissionsManager := permissions.NewManager(testStore)
|
||||
peersManager := peers.NewManager(testStore, permissionsManager)
|
||||
jobManager := job.NewJobManager(nil, testStore, peersManager)
|
||||
|
||||
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
iv, err := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
|
||||
require.NoError(t, err)
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
settingsMockManager.EXPECT().
|
||||
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(&types.Settings{}, nil).
|
||||
AnyTimes()
|
||||
settingsMockManager.EXPECT().
|
||||
GetExtraSettings(gomock.Any(), gomock.Any()).
|
||||
Return(&types.ExtraSettings{}, nil).
|
||||
AnyTimes()
|
||||
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := mgmt.NewAccountRequestBuffer(context.Background(), testStore)
|
||||
networkMapController := controller.NewController(context.Background(), testStore, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(testStore, peersManager), cfg)
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), cfg, testStore, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
|
||||
require.NoError(t, err)
|
||||
|
||||
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, cfg.TURNConfig, cfg.Relay, settingsMockManager, groupsManager)
|
||||
require.NoError(t, err)
|
||||
|
||||
mgmtServer, err := nbgrpc.NewServer(cfg, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil, nil)
|
||||
require.NoError(t, err)
|
||||
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
|
||||
|
||||
go func() {
|
||||
if err := s.Serve(lis); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
t.Cleanup(s.Stop)
|
||||
|
||||
return lis.Addr().String()
|
||||
}
|
||||
@@ -3,7 +3,6 @@ package iptables
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"net"
|
||||
"slices"
|
||||
|
||||
@@ -422,17 +421,12 @@ func (m *aclManager) updateState() {
|
||||
currentState.Lock()
|
||||
defer currentState.Unlock()
|
||||
|
||||
// Clone the maps so the persisted state holds a private snapshot. The
|
||||
// live maps keep being mutated by subsequent rule operations while the
|
||||
// state manager marshals the state from its periodic-save goroutine.
|
||||
// Sharing them by reference races the two and aborts the process with a
|
||||
// concurrent map iteration and write.
|
||||
if m.v6 {
|
||||
currentState.ACLEntries6 = maps.Clone(m.entries)
|
||||
currentState.ACLIPsetStore6 = m.ipsetStore.clone()
|
||||
currentState.ACLEntries6 = m.entries
|
||||
currentState.ACLIPsetStore6 = m.ipsetStore
|
||||
} else {
|
||||
currentState.ACLEntries = maps.Clone(m.entries)
|
||||
currentState.ACLIPsetStore = m.ipsetStore.clone()
|
||||
currentState.ACLEntries = m.entries
|
||||
currentState.ACLIPsetStore = m.ipsetStore
|
||||
}
|
||||
|
||||
if err := m.stateManager.UpdateState(currentState); err != nil {
|
||||
|
||||
@@ -4,7 +4,6 @@ package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"maps"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -750,17 +749,11 @@ func (r *router) updateState() {
|
||||
currentState.Lock()
|
||||
defer currentState.Unlock()
|
||||
|
||||
// Clone the rule map so the persisted state holds a private snapshot. The
|
||||
// live map keeps being mutated by subsequent rule operations while the
|
||||
// state manager marshals the state from its periodic-save goroutine.
|
||||
// Sharing it by reference races the two and aborts the process with a
|
||||
// concurrent map iteration and write. The ipset counter guards itself
|
||||
// during marshaling, so it can be shared directly.
|
||||
if r.v6 {
|
||||
currentState.RouteRules6 = maps.Clone(r.rules)
|
||||
currentState.RouteRules6 = r.rules
|
||||
currentState.RouteIPsetCounter6 = r.ipsetCounter
|
||||
} else {
|
||||
currentState.RouteRules = maps.Clone(r.rules)
|
||||
currentState.RouteRules = r.rules
|
||||
currentState.RouteIPsetCounter = r.ipsetCounter
|
||||
}
|
||||
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"maps"
|
||||
)
|
||||
import "encoding/json"
|
||||
|
||||
type ipList struct {
|
||||
ips map[string]struct{}
|
||||
@@ -22,14 +19,6 @@ func (s *ipList) addIP(ip string) {
|
||||
s.ips[ip] = struct{}{}
|
||||
}
|
||||
|
||||
// clone returns a deep copy of the ipList with its own ips map.
|
||||
func (s *ipList) clone() *ipList {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
return &ipList{ips: maps.Clone(s.ips)}
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler
|
||||
func (s *ipList) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
@@ -66,19 +55,6 @@ func newIpsetStore() *ipsetStore {
|
||||
}
|
||||
}
|
||||
|
||||
// clone returns a deep copy of the ipsetStore with its own ipsets map and
|
||||
// independent ipList entries.
|
||||
func (s *ipsetStore) clone() *ipsetStore {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := &ipsetStore{ipsets: make(map[string]*ipList, len(s.ipsets))}
|
||||
for name, list := range s.ipsets {
|
||||
cloned.ipsets[name] = list.clone()
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func (s *ipsetStore) ipset(ipsetName string) (*ipList, bool) {
|
||||
r, ok := s.ipsets[ipsetName]
|
||||
return r, ok
|
||||
|
||||
@@ -516,14 +516,6 @@ func (g *BundleGenerator) addConfig() error {
|
||||
}
|
||||
}
|
||||
|
||||
// Surface the set of MDM-enforced keys so a support engineer reading
|
||||
// the bundle can tell which field values are user-set vs MDM-overridden.
|
||||
// Same semantics as the mDMManagedFields list returned by the
|
||||
// GetConfig RPC consumed by `netbird debug config`.
|
||||
if managed := g.internalConfig.Policy().ManagedKeys(); len(managed) > 0 {
|
||||
configContent.WriteString(fmt.Sprintf("MDMManagedFields: %v\n", managed))
|
||||
}
|
||||
|
||||
configReader := strings.NewReader(configContent.String())
|
||||
if err := g.addFileToZip(configReader, "config.txt"); err != nil {
|
||||
return fmt.Errorf("add config file to zip: %w", err)
|
||||
|
||||
@@ -843,7 +843,6 @@ func TestAddConfig_AllFieldsCovered(t *testing.T) {
|
||||
"PreSharedKey": "sensitive: WireGuard pre-shared key",
|
||||
"SSHKey": "sensitive: SSH private key",
|
||||
"ClientCertKeyPair": "non-config: parsed cert pair, not serialized",
|
||||
"policy": "non-config: in-memory MDM policy snapshot, surfaced via Config.Policy() / GetConfigResponse.MDMManagedFields",
|
||||
}
|
||||
|
||||
mURL, _ := url.Parse("https://api.example.com:443")
|
||||
|
||||
@@ -482,7 +482,7 @@ func (d *Resolver) logDNSError(logger *log.Entry, hostname string, qtype uint16,
|
||||
// completely when every proxy peer is offline (the upstream may still
|
||||
// be reachable some other way, or the peerstore may be stale).
|
||||
func (d *Resolver) filterDisconnectedPeerAnswers(logger *log.Entry, question dns.Question, records []dns.RR) []dns.RR {
|
||||
if len(records) < 2 {
|
||||
if len(records) == 0 {
|
||||
return records
|
||||
}
|
||||
d.mu.RLock()
|
||||
|
||||
@@ -2738,17 +2738,6 @@ func TestLocalResolver_FilterDisconnectedPeerAnswers(t *testing.T) {
|
||||
connByIP: nil,
|
||||
wantInOrder: []string{"100.64.0.10", "100.64.0.11"},
|
||||
},
|
||||
{
|
||||
// A single answer is never filtered: dropping it would only
|
||||
// trigger the empty-answer escape hatch, so the fast path
|
||||
// returns it untouched.
|
||||
name: "single disconnected answer passes through",
|
||||
records: []nbdns.SimpleRecord{disconnectedRec},
|
||||
connByIP: map[string]ipState{
|
||||
"100.64.0.11": {known: true, connected: false},
|
||||
},
|
||||
wantInOrder: []string{"100.64.0.11"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
|
||||
@@ -14,10 +14,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// errNoSuitableAddress mirrors the unexported error string the net package
|
||||
// uses when a resolved host has no addresses of the requested family.
|
||||
const errNoSuitableAddress = "no suitable address found"
|
||||
|
||||
// GenerateRequestID creates a random 8-character hex string for request tracing.
|
||||
func GenerateRequestID() string {
|
||||
bytes := make([]byte, 4)
|
||||
@@ -130,14 +126,6 @@ func LookupIP(ctx context.Context, r resolver, network, host string, qtype uint1
|
||||
}
|
||||
|
||||
func getRcodeForError(ctx context.Context, r resolver, host string, qtype uint16, err error) int {
|
||||
// The net package returns this AddrError when the host resolves but has
|
||||
// no addresses of the requested family. The domain exists, so answer
|
||||
// NODATA instead of SERVFAIL.
|
||||
var addrErr *net.AddrError
|
||||
if errors.As(err, &addrErr) && addrErr.Err == errNoSuitableAddress {
|
||||
return dns.RcodeSuccess
|
||||
}
|
||||
|
||||
var dnsErr *net.DNSError
|
||||
if !errors.As(err, &dnsErr) {
|
||||
return dns.RcodeServerFailure
|
||||
|
||||
@@ -1,122 +0,0 @@
|
||||
package resutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type mockResolver struct {
|
||||
// results maps network ("ip4"/"ip6") to the lookup outcome.
|
||||
results map[string]mockLookup
|
||||
}
|
||||
|
||||
type mockLookup struct {
|
||||
ips []netip.Addr
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockResolver) LookupNetIP(_ context.Context, network, _ string) ([]netip.Addr, error) {
|
||||
res, ok := m.results[network]
|
||||
if !ok {
|
||||
return nil, errors.New("unexpected network: " + network)
|
||||
}
|
||||
return res.ips, res.err
|
||||
}
|
||||
|
||||
func TestLookupIP_Success(t *testing.T) {
|
||||
r := &mockResolver{results: map[string]mockLookup{
|
||||
"ip4": {ips: []netip.Addr{netip.MustParseAddr("::ffff:192.0.2.1")}},
|
||||
}}
|
||||
|
||||
result := LookupIP(context.Background(), r, "ip4", "example.com.", dns.TypeA)
|
||||
|
||||
assert.Equal(t, dns.RcodeSuccess, result.Rcode, "successful lookup should return NOERROR")
|
||||
require.Len(t, result.IPs, 1, "should return the resolved address")
|
||||
assert.Equal(t, netip.MustParseAddr("192.0.2.1"), result.IPs[0], "v4-mapped address should be unmapped")
|
||||
}
|
||||
|
||||
func TestLookupIP_NoSuitableAddress(t *testing.T) {
|
||||
// The net package returns this AddrError when the host resolves but has
|
||||
// no addresses of the requested family (e.g. AAAA query for a v4-only
|
||||
// hosts file entry). The domain exists, so this is NODATA, not SERVFAIL.
|
||||
r := &mockResolver{results: map[string]mockLookup{
|
||||
"ip6": {err: &net.AddrError{Err: "no suitable address found", Addr: "example.com."}},
|
||||
}}
|
||||
|
||||
result := LookupIP(context.Background(), r, "ip6", "example.com.", dns.TypeAAAA)
|
||||
|
||||
assert.Equal(t, dns.RcodeSuccess, result.Rcode, "no suitable address should map to NODATA")
|
||||
assert.Empty(t, result.IPs, "NODATA response should carry no addresses")
|
||||
}
|
||||
|
||||
// TestErrNoSuitableAddressMatchesNetPackage pins our copy of the error string
|
||||
// to what the net package actually emits. A literal IP of the wrong family
|
||||
// takes the same filterAddrList path as a resolved hostname, without network
|
||||
// access.
|
||||
func TestErrNoSuitableAddressMatchesNetPackage(t *testing.T) {
|
||||
_, err := (&net.Resolver{}).LookupNetIP(context.Background(), "ip6", "192.0.2.1")
|
||||
require.Error(t, err)
|
||||
|
||||
var addrErr *net.AddrError
|
||||
require.ErrorAs(t, err, &addrErr, "wrong-family lookup should return AddrError")
|
||||
assert.Equal(t, errNoSuitableAddress, addrErr.Err, "net package error string should match our constant")
|
||||
}
|
||||
|
||||
func TestLookupIP_OtherAddrError(t *testing.T) {
|
||||
r := &mockResolver{results: map[string]mockLookup{
|
||||
"ip4": {err: &net.AddrError{Err: "some other address problem", Addr: "example.com."}},
|
||||
}}
|
||||
|
||||
result := LookupIP(context.Background(), r, "ip4", "example.com.", dns.TypeA)
|
||||
|
||||
assert.Equal(t, dns.RcodeServerFailure, result.Rcode, "unrecognized AddrError should map to SERVFAIL")
|
||||
}
|
||||
|
||||
func TestLookupIP_NotFoundNXDomain(t *testing.T) {
|
||||
r := &mockResolver{results: map[string]mockLookup{
|
||||
"ip4": {err: &net.DNSError{Err: "no such host", Name: "example.com.", IsNotFound: true}},
|
||||
"ip6": {err: &net.DNSError{Err: "no such host", Name: "example.com.", IsNotFound: true}},
|
||||
}}
|
||||
|
||||
result := LookupIP(context.Background(), r, "ip4", "example.com.", dns.TypeA)
|
||||
|
||||
assert.Equal(t, dns.RcodeNameError, result.Rcode, "not found for both families should map to NXDOMAIN")
|
||||
}
|
||||
|
||||
func TestLookupIP_NotFoundNoData(t *testing.T) {
|
||||
r := &mockResolver{results: map[string]mockLookup{
|
||||
"ip6": {err: &net.DNSError{Err: "no such host", Name: "example.com.", IsNotFound: true}},
|
||||
"ip4": {ips: []netip.Addr{netip.MustParseAddr("192.0.2.1")}},
|
||||
}}
|
||||
|
||||
result := LookupIP(context.Background(), r, "ip6", "example.com.", dns.TypeAAAA)
|
||||
|
||||
assert.Equal(t, dns.RcodeSuccess, result.Rcode, "not found with the other family present should map to NODATA")
|
||||
}
|
||||
|
||||
func TestLookupIP_GenericError(t *testing.T) {
|
||||
r := &mockResolver{results: map[string]mockLookup{
|
||||
"ip4": {err: errors.New("connection refused")},
|
||||
}}
|
||||
|
||||
result := LookupIP(context.Background(), r, "ip4", "example.com.", dns.TypeA)
|
||||
|
||||
assert.Equal(t, dns.RcodeServerFailure, result.Rcode, "generic error should map to SERVFAIL")
|
||||
}
|
||||
|
||||
func TestLookupIP_DNSErrorNotIsNotFound(t *testing.T) {
|
||||
r := &mockResolver{results: map[string]mockLookup{
|
||||
"ip4": {err: &net.DNSError{Err: "server misbehaving", Name: "example.com.", IsTemporary: true}},
|
||||
}}
|
||||
|
||||
result := LookupIP(context.Background(), r, "ip4", "example.com.", dns.TypeA)
|
||||
|
||||
assert.Equal(t, dns.RcodeServerFailure, result.Rcode, "upstream failure should map to SERVFAIL")
|
||||
}
|
||||
@@ -777,24 +777,13 @@ func (s *DefaultServer) applyHostConfig() {
|
||||
// context is released rather than leaked until GC.
|
||||
func (s *DefaultServer) registerFallback() {
|
||||
originalNameservers := s.hostManager.getOriginalNameservers()
|
||||
|
||||
serverIP := s.service.RuntimeIP()
|
||||
var servers []netip.AddrPort
|
||||
for _, ns := range originalNameservers {
|
||||
if ns == serverIP {
|
||||
log.Debugf("skipping original nameserver %s as it is the same as the server IP %s", ns, serverIP)
|
||||
continue
|
||||
}
|
||||
servers = append(servers, netip.AddrPortFrom(ns, DefaultPort))
|
||||
}
|
||||
|
||||
if len(servers) == 0 {
|
||||
if len(originalNameservers) == 0 {
|
||||
log.Debugf("no fallback upstreams to register; clearing PriorityFallback handler")
|
||||
s.clearFallback()
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("registering original nameservers %v as upstream handlers with priority %d", servers, PriorityFallback)
|
||||
log.Infof("registering original nameservers %v as upstream handlers with priority %d", originalNameservers, PriorityFallback)
|
||||
|
||||
handler, err := newUpstreamResolver(
|
||||
s.ctx,
|
||||
@@ -808,6 +797,11 @@ func (s *DefaultServer) registerFallback() {
|
||||
return
|
||||
}
|
||||
handler.selectedRoutes = s.selectedRoutes
|
||||
|
||||
var servers []netip.AddrPort
|
||||
for _, ns := range originalNameservers {
|
||||
servers = append(servers, netip.AddrPortFrom(ns, DefaultPort))
|
||||
}
|
||||
handler.addRace(servers)
|
||||
|
||||
prev := s.fallbackHandler
|
||||
|
||||
@@ -880,25 +880,62 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate)
|
||||
}
|
||||
|
||||
if err := e.updateNetbirdConfig(update.GetNetbirdConfig()); err != nil {
|
||||
return err
|
||||
}
|
||||
if update.GetNetbirdConfig() != nil {
|
||||
wCfg := update.GetNetbirdConfig()
|
||||
err := e.updateTURNs(wCfg.GetTurns())
|
||||
if err != nil {
|
||||
return fmt.Errorf("update TURNs: %w", err)
|
||||
}
|
||||
|
||||
// Posture checks are bound to the network map presence:
|
||||
// NetworkMap != nil, checks present -> apply the received checks
|
||||
// NetworkMap != nil, checks nil -> posture checks were removed, clear them
|
||||
// NetworkMap == nil -> config-only update (e.g. relay token rotation),
|
||||
// leave the previously applied checks untouched
|
||||
nm := update.GetNetworkMap()
|
||||
if nm == nil {
|
||||
return nil
|
||||
err = e.updateSTUNs(wCfg.GetStuns())
|
||||
if err != nil {
|
||||
return fmt.Errorf("update STUNs: %w", err)
|
||||
}
|
||||
|
||||
var stunTurn []*stun.URI
|
||||
stunTurn = append(stunTurn, e.STUNs...)
|
||||
stunTurn = append(stunTurn, e.TURNs...)
|
||||
e.stunTurn.Store(stunTurn)
|
||||
|
||||
err = e.handleRelayUpdate(wCfg.GetRelay())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = e.handleFlowUpdate(wCfg.GetFlow())
|
||||
if err != nil {
|
||||
return fmt.Errorf("handle the flow configuration: %w", err)
|
||||
}
|
||||
|
||||
if err := e.PopulateNetbirdConfig(wCfg, nil); err != nil {
|
||||
log.Warnf("Failed to update DNS server config: %v", err)
|
||||
}
|
||||
|
||||
// todo update signal
|
||||
}
|
||||
|
||||
if err := e.updateChecksIfNew(update.Checks); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
e.persistSyncResponse(update)
|
||||
nm := update.GetNetworkMap()
|
||||
if nm == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Persist sync response under the dedicated lock (syncRespMux), not under syncMsgMux.
|
||||
// A non-nil syncStore is what marks persistence as enabled. Hold the lock for
|
||||
// the whole Set so the store cannot be cleared (disabled / engine close)
|
||||
// mid-call and have this write resurrect a file that was just removed.
|
||||
e.syncRespMux.RLock()
|
||||
if e.syncStore != nil {
|
||||
if err := e.syncStore.Set(update); err != nil {
|
||||
log.Errorf("failed to persist sync response: %v", err)
|
||||
} else {
|
||||
log.Debugf("sync response persisted with serial %d", nm.GetSerial())
|
||||
}
|
||||
}
|
||||
e.syncRespMux.RUnlock()
|
||||
|
||||
// only apply new changes and ignore old ones
|
||||
if err := e.updateNetworkMap(nm); err != nil {
|
||||
@@ -910,64 +947,6 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateNetbirdConfig applies the management-provided NetBird configuration:
|
||||
// STUN/TURN and relay servers, flow logging and DNS settings. A nil config is a no-op,
|
||||
// which is the case for sync updates carrying only a network map.
|
||||
func (e *Engine) updateNetbirdConfig(wCfg *mgmProto.NetbirdConfig) error {
|
||||
if wCfg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := e.updateTURNs(wCfg.GetTurns()); err != nil {
|
||||
return fmt.Errorf("update TURNs: %w", err)
|
||||
}
|
||||
|
||||
if err := e.updateSTUNs(wCfg.GetStuns()); err != nil {
|
||||
return fmt.Errorf("update STUNs: %w", err)
|
||||
}
|
||||
|
||||
var stunTurn []*stun.URI
|
||||
stunTurn = append(stunTurn, e.STUNs...)
|
||||
stunTurn = append(stunTurn, e.TURNs...)
|
||||
e.stunTurn.Store(stunTurn)
|
||||
|
||||
if err := e.handleRelayUpdate(wCfg.GetRelay()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := e.handleFlowUpdate(wCfg.GetFlow()); err != nil {
|
||||
return fmt.Errorf("handle the flow configuration: %w", err)
|
||||
}
|
||||
|
||||
if err := e.PopulateNetbirdConfig(wCfg, nil); err != nil {
|
||||
log.Warnf("Failed to update DNS server config: %v", err)
|
||||
}
|
||||
|
||||
// todo update signal
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// persistSyncResponse stores the full sync response so it can be restored on the next
|
||||
// startup. Persistence is enabled only when syncStore is set. The dedicated syncRespMux
|
||||
// (not syncMsgMux) is held for the whole Set so the store cannot be cleared (disabled /
|
||||
// engine close) mid-call and have this write resurrect a file that was just removed.
|
||||
func (e *Engine) persistSyncResponse(update *mgmProto.SyncResponse) {
|
||||
e.syncRespMux.RLock()
|
||||
defer e.syncRespMux.RUnlock()
|
||||
|
||||
if e.syncStore == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := e.syncStore.Set(update); err != nil {
|
||||
log.Errorf("failed to persist sync response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("sync response persisted with serial %d", update.GetNetworkMap().GetSerial())
|
||||
}
|
||||
|
||||
func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error {
|
||||
if update != nil {
|
||||
// when we receive token we expect valid address list too
|
||||
|
||||
@@ -26,6 +26,7 @@ type connStatusInputs struct {
|
||||
iceInProgress bool // a negotiation is currently in flight
|
||||
}
|
||||
|
||||
|
||||
// ConnStatus describe the status of a peer's connection
|
||||
type ConnStatus int32
|
||||
|
||||
|
||||
@@ -193,7 +193,6 @@ func (s *StatusChangeSubscription) Events() chan map[string]RouterState {
|
||||
type Status struct {
|
||||
mux sync.RWMutex
|
||||
peers map[string]State
|
||||
ipToKey map[string]string
|
||||
changeNotify map[string]map[string]*StatusChangeSubscription // map[peerID]map[subscriptionID]*StatusChangeSubscription
|
||||
signalState bool
|
||||
signalError error
|
||||
@@ -232,7 +231,6 @@ type Status struct {
|
||||
func NewRecorder(mgmAddress string) *Status {
|
||||
return &Status{
|
||||
peers: make(map[string]State),
|
||||
ipToKey: make(map[string]string),
|
||||
changeNotify: make(map[string]map[string]*StatusChangeSubscription),
|
||||
eventStreams: make(map[string]chan *proto.SystemEvent),
|
||||
eventQueue: NewEventQueue(eventQueueSize),
|
||||
@@ -284,12 +282,6 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string, ip string, ipv6 string)
|
||||
Mux: new(sync.RWMutex),
|
||||
}
|
||||
d.peerListChangedForNotification = true
|
||||
if ipv6 != "" {
|
||||
d.ipToKey[ipv6] = peerPubKey
|
||||
}
|
||||
if ip != "" {
|
||||
d.ipToKey[ip] = peerPubKey
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -319,22 +311,28 @@ func (d *Status) PeerByIP(ip string) (string, bool) {
|
||||
|
||||
// PeerStateByIP returns the full peer State for the given tunnel IP.
|
||||
// Matches against either the IPv4 (State.IP) or IPv6 (State.IPv6) tunnel
|
||||
// address so dual-stack peers are reachable on either family. Only
|
||||
// active peers are matched; peers moved into the offline slice by
|
||||
// ReplaceOfflinePeers are intentionally treated as unknown.
|
||||
// address so dual-stack peers are reachable on either family. Searches
|
||||
// both d.peers and d.offlinePeers — peers that have been moved into
|
||||
// the offline slice by ReplaceOfflinePeers are still part of the
|
||||
// account's roster and callers (DNS filter, embed.Client.IdentityForIP)
|
||||
// need to recognise them rather than treating them as unknown. Returns
|
||||
// the zero State and false when no peer matches or the input is empty.
|
||||
func (d *Status) PeerStateByIP(ip string) (State, bool) {
|
||||
if ip == "" {
|
||||
return State{}, false
|
||||
}
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
key, ok := d.ipToKey[ip]
|
||||
if !ok {
|
||||
return State{}, false
|
||||
|
||||
for _, state := range d.peers {
|
||||
if (state.IP != "" && state.IP == ip) || (state.IPv6 != "" && state.IPv6 == ip) {
|
||||
return state, true
|
||||
}
|
||||
}
|
||||
state, ok := d.peers[key]
|
||||
if ok {
|
||||
return state, true
|
||||
for _, state := range d.offlinePeers {
|
||||
if (state.IP != "" && state.IP == ip) || (state.IPv6 != "" && state.IPv6 == ip) {
|
||||
return state, true
|
||||
}
|
||||
}
|
||||
return State{}, false
|
||||
}
|
||||
@@ -344,18 +342,12 @@ func (d *Status) RemovePeer(peerPubKey string) error {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
p, ok := d.peers[peerPubKey]
|
||||
_, ok := d.peers[peerPubKey]
|
||||
if !ok {
|
||||
return errors.New("no peer with to remove")
|
||||
}
|
||||
|
||||
delete(d.peers, peerPubKey)
|
||||
if mappedKey, exists := d.ipToKey[p.IP]; exists && mappedKey == peerPubKey {
|
||||
delete(d.ipToKey, p.IP)
|
||||
}
|
||||
if mappedKey, exists := d.ipToKey[p.IPv6]; exists && mappedKey == peerPubKey {
|
||||
delete(d.ipToKey, p.IPv6)
|
||||
}
|
||||
d.peerListChangedForNotification = true
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -90,11 +90,12 @@ func TestStatus_PeerStateByIP_MatchesIPv6(t *testing.T) {
|
||||
req.Equal("pk-1", state.PubKey, "matching state must carry the right pub key")
|
||||
}
|
||||
|
||||
// TestStatus_PeerStateByIP_IgnoresOfflinePeers documents that peers
|
||||
// moved into the offline slice via ReplaceOfflinePeers are intentionally
|
||||
// not resolvable by IP: only active peers can carry traffic, so callers
|
||||
// (DNS filter, embed.Client.IdentityForIP) treat them as unknown.
|
||||
func TestStatus_PeerStateByIP_IgnoresOfflinePeers(t *testing.T) {
|
||||
// TestStatus_PeerStateByIP_MatchesOfflinePeers covers peers that have
|
||||
// been moved into the offline slice via ReplaceOfflinePeers. Callers
|
||||
// (DNS filter, embed.Client.IdentityForIP) need to treat them as known
|
||||
// rather than unknown — otherwise authentication / DNS filtering treats
|
||||
// known-but-offline peers as foreign IPs.
|
||||
func TestStatus_PeerStateByIP_MatchesOfflinePeers(t *testing.T) {
|
||||
status := NewRecorder("https://mgm")
|
||||
req := require.New(t)
|
||||
|
||||
@@ -102,31 +103,13 @@ func TestStatus_PeerStateByIP_IgnoresOfflinePeers(t *testing.T) {
|
||||
{PubKey: "pk-offline", FQDN: "offline.netbird", IP: "100.64.0.20", IPv6: "fd00::20"},
|
||||
})
|
||||
|
||||
_, ok := status.PeerStateByIP("100.64.0.20")
|
||||
req.False(ok, "offline peer must not resolve by IPv4 tunnel address")
|
||||
state, ok := status.PeerStateByIP("100.64.0.20")
|
||||
req.True(ok, "offline peer must resolve by IPv4 tunnel address")
|
||||
req.Equal("pk-offline", state.PubKey, "matching state must carry the offline peer's pub key")
|
||||
|
||||
_, ok = status.PeerStateByIP("fd00::20")
|
||||
req.False(ok, "offline peer must not resolve by IPv6 tunnel address")
|
||||
}
|
||||
|
||||
// TestStatus_PeerStateByIP_RemovedPeer verifies RemovePeer drops the
|
||||
// IP index entries for both address families.
|
||||
func TestStatus_PeerStateByIP_RemovedPeer(t *testing.T) {
|
||||
status := NewRecorder("https://mgm")
|
||||
req := require.New(t)
|
||||
|
||||
req.NoError(status.AddPeer("pk-1", "peer-1.netbird", "100.64.0.10", "fd00::1"))
|
||||
|
||||
_, ok := status.PeerStateByIP("100.64.0.10")
|
||||
req.True(ok, "active peer must resolve before removal")
|
||||
|
||||
req.NoError(status.RemovePeer("pk-1"))
|
||||
|
||||
_, ok = status.PeerStateByIP("100.64.0.10")
|
||||
req.False(ok, "removed peer must not resolve by IPv4 tunnel address")
|
||||
|
||||
_, ok = status.PeerStateByIP("fd00::1")
|
||||
req.False(ok, "removed peer must not resolve by IPv6 tunnel address")
|
||||
state, ok = status.PeerStateByIP("fd00::20")
|
||||
req.True(ok, "offline peer must resolve by IPv6 tunnel address")
|
||||
req.Equal("pk-offline", state.PubKey, "IPv6 match must carry the offline peer's pub key")
|
||||
}
|
||||
|
||||
func TestStatus_UpdatePeerFQDN(t *testing.T) {
|
||||
|
||||
@@ -22,7 +22,6 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||
"github.com/netbirdio/netbird/client/mdm"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
@@ -58,10 +57,6 @@ var DefaultInterfaceBlacklist = []string{
|
||||
"Tailscale", "tailscale", "docker", "veth", "br-", "lo",
|
||||
}
|
||||
|
||||
// loadMDMPolicy is the package-level indirection used by apply() to read the
|
||||
// active MDM policy. Tests override this to inject a fake policy.
|
||||
var loadMDMPolicy = mdm.LoadPolicy
|
||||
|
||||
// ConfigInput carries configuration changes to the client
|
||||
type ConfigInput struct {
|
||||
ManagementURL string
|
||||
@@ -179,23 +174,6 @@ type Config struct {
|
||||
LazyConnectionEnabled bool
|
||||
|
||||
MTU uint16
|
||||
|
||||
// policy is the MDM policy that produced the currently-set values for
|
||||
// any MDM-enforced fields. Set by applyMDMPolicy at the tail of apply()
|
||||
// and reset on every apply() invocation. Never persisted to disk.
|
||||
// Callers query enforcement state via Policy() and the mdm.Policy API
|
||||
// (HasKey, ManagedKeys, IsEmpty).
|
||||
policy *mdm.Policy `json:"-"`
|
||||
}
|
||||
|
||||
// Policy returns the MDM policy applied to this Config. Returns a non-nil
|
||||
// empty Policy when MDM enforcement is inactive; callers can always invoke
|
||||
// HasKey / ManagedKeys / IsEmpty without a nil check.
|
||||
func (config *Config) Policy() *mdm.Policy {
|
||||
if config == nil || config.policy == nil {
|
||||
return mdm.NewPolicy(nil)
|
||||
}
|
||||
return config.policy
|
||||
}
|
||||
|
||||
var ConfigDirOverride string
|
||||
@@ -634,93 +612,10 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
// MDM is the last override layer: any key present in the policy
|
||||
// supersedes defaults, on-disk config, env vars and CLI input.
|
||||
config.applyMDMPolicy(loadMDMPolicy())
|
||||
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
// applyMDMPolicy overlays MDM-supplied values on top of the resolved Config.
|
||||
// The provided Policy is also stored on the Config so callers can later query
|
||||
// which fields are enforced. Invalid values (e.g. malformed URLs) are logged
|
||||
// and skipped to avoid bricking the client; the field keeps its previous
|
||||
// resolved value but is still marked as managed (Policy.HasKey returns true
|
||||
// for the key, so per-field rejection of user writes still applies).
|
||||
func (config *Config) applyMDMPolicy(policy *mdm.Policy) {
|
||||
config.policy = policy
|
||||
if policy.IsEmpty() {
|
||||
return
|
||||
}
|
||||
|
||||
// Helper: log the application of a single MDM-managed key. Values for
|
||||
// keys in mdm.SecretKeys are redacted.
|
||||
logApplied := func(key string, displayValue any) {
|
||||
if _, secret := mdm.SecretKeys[key]; secret {
|
||||
log.Infof("MDM override %s = ********** (secret)", key)
|
||||
return
|
||||
}
|
||||
log.Infof("MDM override %s = %v", key, displayValue)
|
||||
}
|
||||
|
||||
if v, ok := policy.GetString(mdm.KeyManagementURL); ok {
|
||||
if u, err := parseURL("Management URL", v); err != nil {
|
||||
log.Warnf("MDM management URL %q invalid: %v; keeping previous value", v, err)
|
||||
} else {
|
||||
config.ManagementURL = u
|
||||
logApplied(mdm.KeyManagementURL, u.String())
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := policy.GetString(mdm.KeyPreSharedKey); ok {
|
||||
// Defensive: refuse the redaction mask in case it round-tripped
|
||||
// through a manifest by mistake.
|
||||
if !isPreSharedKeyHidden(&v) {
|
||||
config.PreSharedKey = v
|
||||
logApplied(mdm.KeyPreSharedKey, "")
|
||||
}
|
||||
}
|
||||
|
||||
// applyBool collapses the per-key "read + set + log" boilerplate
|
||||
// for every plain bool MDM key into a single helper. Keeps the
|
||||
// outer function's cognitive complexity below SonarCube's
|
||||
// threshold; functional behaviour is identical to the inlined
|
||||
// branches it replaces.
|
||||
applyBool := func(key string, setter func(bool)) {
|
||||
v, ok := policy.GetBool(key)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
setter(v)
|
||||
logApplied(key, v)
|
||||
}
|
||||
|
||||
applyBool(mdm.KeyAllowServerSSH, func(v bool) { bv := v; config.ServerSSHAllowed = &bv })
|
||||
applyBool(mdm.KeyDisableClientRoutes, func(v bool) { config.DisableClientRoutes = v })
|
||||
applyBool(mdm.KeyDisableServerRoutes, func(v bool) { config.DisableServerRoutes = v })
|
||||
applyBool(mdm.KeyBlockInbound, func(v bool) { config.BlockInbound = v })
|
||||
applyBool(mdm.KeyDisableAutoConnect, func(v bool) { config.DisableAutoConnect = v })
|
||||
applyBool(mdm.KeyRosenpassEnabled, func(v bool) { config.RosenpassEnabled = v })
|
||||
applyBool(mdm.KeyRosenpassPermissive, func(v bool) { config.RosenpassPermissive = v })
|
||||
|
||||
if v, ok := policy.GetInt(mdm.KeyWireguardPort); ok {
|
||||
// REG_DWORD is 32-bit; UDP port range is 1-65535. Clamp at the
|
||||
// upper bound and reject obviously-invalid values to avoid the
|
||||
// engine binding to an unusable port if the admin pushes garbage.
|
||||
if v >= 1 && v <= 65535 {
|
||||
config.WgPort = int(v)
|
||||
logApplied(mdm.KeyWireguardPort, v)
|
||||
} else {
|
||||
log.Warnf("MDM wireguard port %d out of range [1,65535]; keeping previous value", v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// parseURL parses and validates the URL for the named service. The URL
|
||||
// must use the http or https scheme; if no port is present, ":443" is
|
||||
// appended for https or ":80" for http. The serviceName parameter is
|
||||
// used to contextualise error messages. On success returns the parsed
|
||||
// *url.URL; on failure returns a non-nil error.
|
||||
// parseURL parses and validates a service URL
|
||||
func parseURL(serviceName, serviceURL string) (*url.URL, error) {
|
||||
parsedMgmtURL, err := url.ParseRequestURI(serviceURL)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,152 +0,0 @@
|
||||
package profilemanager
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/mdm"
|
||||
)
|
||||
|
||||
// withMDMPolicy temporarily overrides the package-level loadMDMPolicy hook so
|
||||
// apply() observes the supplied Policy. The original loader is restored at
|
||||
// test cleanup.
|
||||
func withMDMPolicy(t *testing.T, policy *mdm.Policy) {
|
||||
t.Helper()
|
||||
prev := loadMDMPolicy
|
||||
loadMDMPolicy = func() *mdm.Policy { return policy }
|
||||
t.Cleanup(func() { loadMDMPolicy = prev })
|
||||
}
|
||||
|
||||
func TestApply_MDMEmpty_NoEnforcement(t *testing.T) {
|
||||
withMDMPolicy(t, mdm.NewPolicy(nil))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
assert.True(t, cfg.Policy().IsEmpty(), "no MDM source ⇒ empty Policy")
|
||||
assert.False(t, cfg.Policy().HasKey(mdm.KeyManagementURL))
|
||||
assert.Empty(t, cfg.Policy().ManagedKeys())
|
||||
|
||||
// Default management URL still resolves.
|
||||
assert.Equal(t, DefaultManagementURL, cfg.ManagementURL.String())
|
||||
}
|
||||
|
||||
func TestApply_MDMOnly_OverridesDefaults(t *testing.T) {
|
||||
const mdmURL = "https://corp.mdm.example.com:443"
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: mdmURL,
|
||||
mdm.KeyDisableClientRoutes: true,
|
||||
mdm.KeyBlockInbound: true,
|
||||
}))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
assert.Equal(t, mdmURL, cfg.ManagementURL.String())
|
||||
assert.True(t, cfg.DisableClientRoutes)
|
||||
assert.True(t, cfg.BlockInbound)
|
||||
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyManagementURL))
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyDisableClientRoutes))
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyBlockInbound))
|
||||
assert.False(t, cfg.Policy().HasKey(mdm.KeyAllowServerSSH))
|
||||
}
|
||||
|
||||
func TestApply_MDMBeatsCLIInput(t *testing.T) {
|
||||
const mdmURL = "https://mdm.example.com:443"
|
||||
const cliURL = "https://cli.example.com:443"
|
||||
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: mdmURL,
|
||||
}))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
ManagementURL: cliURL,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
// MDM wins over CLI-supplied management URL.
|
||||
assert.Equal(t, mdmURL, cfg.ManagementURL.String())
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyManagementURL))
|
||||
}
|
||||
|
||||
func TestApply_MDMInvalidURL_KeepsPreviousValue(t *testing.T) {
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: "not-a-url",
|
||||
}))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
// Invalid MDM URL is logged and skipped: default URL stays in place
|
||||
// to keep the client functional.
|
||||
assert.Equal(t, DefaultManagementURL, cfg.ManagementURL.String())
|
||||
|
||||
// But the key is still considered MDM-managed (admin intent is to
|
||||
// enforce, daemon rejects user writes to this field — phase-1 scaffolding
|
||||
// reflects this by keeping Policy.HasKey true even on parse failure).
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyManagementURL))
|
||||
}
|
||||
|
||||
func TestApply_MDMBoolKeysOverrideOnDiskValue(t *testing.T) {
|
||||
tmp := filepath.Join(t.TempDir(), "config.json")
|
||||
|
||||
// Seed without MDM.
|
||||
withMDMPolicy(t, mdm.NewPolicy(nil))
|
||||
_, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: tmp,
|
||||
DisableClientRoutes: boolPtr(false),
|
||||
RosenpassEnabled: boolPtr(false),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now enable MDM enforcement for these keys.
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyDisableClientRoutes: true,
|
||||
mdm.KeyRosenpassEnabled: true,
|
||||
}))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{ConfigPath: tmp})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
assert.True(t, cfg.DisableClientRoutes, "MDM override should flip on-disk false to true")
|
||||
assert.True(t, cfg.RosenpassEnabled)
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyDisableClientRoutes))
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyRosenpassEnabled))
|
||||
}
|
||||
|
||||
func TestApply_MDMPreSharedKeyRedactionSentinelRejected(t *testing.T) {
|
||||
const maskSentinel = "**********"
|
||||
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyPreSharedKey: maskSentinel,
|
||||
}))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
// Mask sentinel must not be persisted as the actual PSK.
|
||||
assert.NotEqual(t, maskSentinel, cfg.PreSharedKey)
|
||||
// Key still marked managed so user writes are still rejected.
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyPreSharedKey))
|
||||
}
|
||||
|
||||
func boolPtr(b bool) *bool { return &b }
|
||||
@@ -22,14 +22,14 @@ type removePeerCall struct {
|
||||
}
|
||||
|
||||
type mockServer struct {
|
||||
mu sync.Mutex
|
||||
addCalls []addPeerCall
|
||||
removed []removePeerCall
|
||||
nextID rp.PeerID
|
||||
addErr error
|
||||
removeErr error
|
||||
closed bool
|
||||
ran bool
|
||||
mu sync.Mutex
|
||||
addCalls []addPeerCall
|
||||
removed []removePeerCall
|
||||
nextID rp.PeerID
|
||||
addErr error
|
||||
removeErr error
|
||||
closed bool
|
||||
ran bool
|
||||
}
|
||||
|
||||
func (m *mockServer) AddPeer(cfg rp.PeerConfig) (rp.PeerID, error) {
|
||||
@@ -51,7 +51,7 @@ func (m *mockServer) RemovePeer(id rp.PeerID) error {
|
||||
return m.removeErr
|
||||
}
|
||||
|
||||
func (m *mockServer) Run() error { m.ran = true; return nil }
|
||||
func (m *mockServer) Run() error { m.ran = true; return nil }
|
||||
func (m *mockServer) Close() error { m.closed = true; return nil }
|
||||
|
||||
type setPSKCall struct {
|
||||
|
||||
@@ -41,3 +41,4 @@ func TestDeterministicSeedKey_TooShortKey_ReturnsError(t *testing.T) {
|
||||
_, err = DeterministicSeedKey(long, short)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
|
||||
@@ -700,13 +700,6 @@ func resolveURLsToIPs(urls []string) []net.IP {
|
||||
|
||||
// updateRouteSelectorFromManagement updates the route selector based on the isSelected status from the management server
|
||||
func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HAMap) {
|
||||
// An explicit user "deselect all" must not be overridden by management auto-apply.
|
||||
// Auto-applying an exit node here would call SelectRoutes, which clears the
|
||||
// deselect-all flag and re-enables every route the user turned off.
|
||||
if m.routeSelector.IsDeselectAll() {
|
||||
return
|
||||
}
|
||||
|
||||
exitNodeInfo := m.collectExitNodeInfo(clientRoutes)
|
||||
if len(exitNodeInfo.allIDs) == 0 {
|
||||
return
|
||||
|
||||
@@ -1,71 +0,0 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
func exitNodeRoutes(netID route.NetID, skipAutoApply bool) route.HAMap {
|
||||
haID := route.HAUniqueID(string(netID) + "|0.0.0.0/0")
|
||||
return route.HAMap{
|
||||
haID: []*route.Route{
|
||||
{
|
||||
ID: "r-" + route.ID(netID),
|
||||
NetID: netID,
|
||||
Network: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Enabled: true,
|
||||
SkipAutoApply: skipAutoApply,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateRouteSelectorFromManagement(t *testing.T) {
|
||||
t.Run("management auto-apply selects exit node without user selection", func(t *testing.T) {
|
||||
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
|
||||
routes := exitNodeRoutes("exit1", false)
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
require.True(t, m.routeSelector.IsSelected("exit1"), "auto-apply exit node should be selected")
|
||||
require.Len(t, m.routeSelector.FilterSelectedExitNodes(routes), 1, "selected exit node should pass the filter")
|
||||
})
|
||||
|
||||
t.Run("management SkipAutoApply leaves exit node deselected", func(t *testing.T) {
|
||||
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
|
||||
routes := exitNodeRoutes("exit1", true)
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
require.False(t, m.routeSelector.IsSelected("exit1"), "SkipAutoApply exit node should not be selected")
|
||||
require.Empty(t, m.routeSelector.FilterSelectedExitNodes(routes), "deselected exit node should be filtered out")
|
||||
})
|
||||
|
||||
t.Run("user selection is not overridden by management", func(t *testing.T) {
|
||||
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
|
||||
require.NoError(t, m.routeSelector.SelectRoutes([]route.NetID{"exit1"}, true, []route.NetID{"exit1"}))
|
||||
routes := exitNodeRoutes("exit1", true)
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
require.True(t, m.routeSelector.IsSelected("exit1"), "explicit user selection must survive a management sync that wants to skip auto-apply")
|
||||
require.Len(t, m.routeSelector.FilterSelectedExitNodes(routes), 1, "user-selected exit node should pass the filter")
|
||||
})
|
||||
|
||||
t.Run("deselect-all is preserved across a management sync", func(t *testing.T) {
|
||||
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
|
||||
m.routeSelector.DeselectAllRoutes()
|
||||
routes := exitNodeRoutes("exit1", false)
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
require.True(t, m.routeSelector.IsDeselectAll(), "an explicit deselect-all must not be cleared by management auto-apply")
|
||||
require.Empty(t, m.routeSelector.FilterSelectedExitNodes(routes), "no routes should be selected while deselect-all is set")
|
||||
})
|
||||
}
|
||||
@@ -116,14 +116,6 @@ func (rs *RouteSelector) DeselectAllRoutes() {
|
||||
clear(rs.selectedRoutes)
|
||||
}
|
||||
|
||||
// IsDeselectAll reports whether the user has explicitly deselected all routes.
|
||||
func (rs *RouteSelector) IsDeselectAll() bool {
|
||||
rs.mu.RLock()
|
||||
defer rs.mu.RUnlock()
|
||||
|
||||
return rs.deselectAll
|
||||
}
|
||||
|
||||
// IsSelected checks if a specific route is selected.
|
||||
func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
|
||||
rs.mu.RLock()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build cgo && !osusergo && !windows
|
||||
|
||||
package server
|
||||
package shell
|
||||
|
||||
import "os/user"
|
||||
|
||||
@@ -8,17 +8,22 @@ import "os/user"
|
||||
// When CGO is enabled, os/user uses libc (getpwnam_r) which goes through
|
||||
// the NSS stack natively. If it fails, the user truly doesn't exist and
|
||||
// getent would also fail.
|
||||
func lookupWithGetent(username string) (*user.User, error) {
|
||||
func LookupWithGetent(username string) (*user.User, error) {
|
||||
return user.Lookup(username)
|
||||
}
|
||||
|
||||
// currentUserWithGetent with CGO delegates directly to os/user.Current.
|
||||
func currentUserWithGetent() (*user.User, error) {
|
||||
func CurrentUserWithGetent() (*user.User, error) {
|
||||
return user.Current()
|
||||
}
|
||||
|
||||
// LookupGroupWithGetent returns the resolved group from either a gid or groupname
|
||||
func LookupGroupWithGetent(name string) (*user.Group, error) {
|
||||
return user.LookupGroup(name)
|
||||
}
|
||||
|
||||
// groupIdsWithFallback with CGO delegates directly to user.GroupIds.
|
||||
// libc's getgrouplist handles NSS groups natively.
|
||||
func groupIdsWithFallback(u *user.User) ([]string, error) {
|
||||
func GroupIdsWithFallback(u *user.User) ([]string, error) {
|
||||
return u.GroupIds()
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build (!cgo || osusergo) && !windows
|
||||
|
||||
package server
|
||||
package shell
|
||||
|
||||
import (
|
||||
"os"
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
// lookupWithGetent looks up a user by name, falling back to getent if os/user fails.
|
||||
// Without CGO, os/user only reads /etc/passwd and misses NSS-provided users.
|
||||
// getent goes through the host's NSS stack.
|
||||
func lookupWithGetent(username string) (*user.User, error) {
|
||||
func LookupWithGetent(username string) (*user.User, error) {
|
||||
u, err := user.Lookup(username)
|
||||
if err == nil {
|
||||
return u, nil
|
||||
@@ -22,7 +22,7 @@ func lookupWithGetent(username string) (*user.User, error) {
|
||||
stdErr := err
|
||||
log.Debugf("os/user.Lookup(%q) failed, trying getent: %v", username, err)
|
||||
|
||||
u, _, getentErr := runGetent(username)
|
||||
u, _, getentErr := runGetentPasswd(username)
|
||||
if getentErr != nil {
|
||||
log.Debugf("getent fallback for %q also failed: %v", username, getentErr)
|
||||
return nil, stdErr
|
||||
@@ -31,8 +31,25 @@ func lookupWithGetent(username string) (*user.User, error) {
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// LookupGroupWithGetent returns the resolved group from either a gid or groupname
|
||||
func LookupGroupWithGetent(name string) (*user.Group, error) {
|
||||
g, err := user.LookupGroup(name)
|
||||
if err == nil {
|
||||
return g, nil
|
||||
}
|
||||
|
||||
stdErr := err
|
||||
log.Debugf("os/user.LookupGroup(%q) failed, trying getent: %v", name, err)
|
||||
g, getentErr := runGetentGroup(name)
|
||||
if getentErr != nil {
|
||||
log.Debugf("getent fallback for %q also failed: %v", name, getentErr)
|
||||
return nil, stdErr
|
||||
}
|
||||
return g, nil
|
||||
}
|
||||
|
||||
// currentUserWithGetent gets the current user, falling back to getent if os/user fails.
|
||||
func currentUserWithGetent() (*user.User, error) {
|
||||
func CurrentUserWithGetent() (*user.User, error) {
|
||||
u, err := user.Current()
|
||||
if err == nil {
|
||||
return u, nil
|
||||
@@ -42,7 +59,7 @@ func currentUserWithGetent() (*user.User, error) {
|
||||
uid := strconv.Itoa(os.Getuid())
|
||||
log.Debugf("os/user.Current() failed, trying getent with UID %s: %v", uid, err)
|
||||
|
||||
u, _, getentErr := runGetent(uid)
|
||||
u, _, getentErr := runGetentPasswd(uid)
|
||||
if getentErr != nil {
|
||||
return nil, stdErr
|
||||
}
|
||||
@@ -57,7 +74,7 @@ func currentUserWithGetent() (*user.User, error) {
|
||||
// only reads /etc/group and silently returns incomplete results for NSS users
|
||||
// (no error, just missing groups). The id command goes through NSS and returns
|
||||
// the full set.
|
||||
func groupIdsWithFallback(u *user.User) ([]string, error) {
|
||||
func GroupIdsWithFallback(u *user.User) ([]string, error) {
|
||||
ids, err := runIdGroups(u.Username)
|
||||
if err == nil {
|
||||
return ids, nil
|
||||
@@ -1,4 +1,4 @@
|
||||
package server
|
||||
package shell
|
||||
|
||||
import (
|
||||
"os/user"
|
||||
@@ -15,7 +15,7 @@ func TestLookupWithGetent_CurrentUser(t *testing.T) {
|
||||
current, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
u, err := lookupWithGetent(current.Username)
|
||||
u, err := LookupWithGetent(current.Username)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, current.Username, u.Username)
|
||||
assert.Equal(t, current.Uid, u.Uid)
|
||||
@@ -23,7 +23,7 @@ func TestLookupWithGetent_CurrentUser(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestLookupWithGetent_NonexistentUser(t *testing.T) {
|
||||
_, err := lookupWithGetent("nonexistent_user_xyzzy_12345")
|
||||
_, err := LookupWithGetent("nonexistent_user_xyzzy_12345")
|
||||
require.Error(t, err, "should fail for nonexistent user")
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@ func TestCurrentUserWithGetent(t *testing.T) {
|
||||
stdUser, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
u, err := currentUserWithGetent()
|
||||
u, err := CurrentUserWithGetent()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, stdUser.Uid, u.Uid)
|
||||
assert.Equal(t, stdUser.Username, u.Username)
|
||||
@@ -41,7 +41,7 @@ func TestGroupIdsWithFallback_CurrentUser(t *testing.T) {
|
||||
current, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
groups, err := groupIdsWithFallback(current)
|
||||
groups, err := GroupIdsWithFallback(current)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, groups, "current user should have at least one group")
|
||||
|
||||
@@ -56,7 +56,7 @@ func TestGroupIdsWithFallback_CurrentUser(t *testing.T) {
|
||||
func TestGetShellFromGetent_CurrentUser(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
// Windows stub always returns empty, which is correct
|
||||
shell := getShellFromGetent("1000")
|
||||
shell := GetShellFromGetent("1000")
|
||||
assert.Empty(t, shell, "Windows stub should return empty")
|
||||
return
|
||||
}
|
||||
@@ -65,7 +65,7 @@ func TestGetShellFromGetent_CurrentUser(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// getent may not be available on all systems (e.g., macOS without Homebrew getent)
|
||||
shell := getShellFromGetent(current.Uid)
|
||||
shell := GetShellFromGetent(current.Uid)
|
||||
if shell == "" {
|
||||
t.Log("getShellFromGetent returned empty, getent may not be available")
|
||||
return
|
||||
@@ -78,7 +78,7 @@ func TestLookupWithGetent_RootUser(t *testing.T) {
|
||||
t.Skip("no root user on Windows")
|
||||
}
|
||||
|
||||
u, err := lookupWithGetent("root")
|
||||
u, err := LookupWithGetent("root")
|
||||
if err != nil {
|
||||
t.Skip("root user not available on this system")
|
||||
}
|
||||
@@ -91,20 +91,20 @@ func TestLookupWithGetent_RootUser(t *testing.T) {
|
||||
// consistent and correct results when composed together.
|
||||
func TestIntegration_FullLookupChain(t *testing.T) {
|
||||
// Step 1: currentUserWithGetent must resolve the running user.
|
||||
current, err := currentUserWithGetent()
|
||||
current, err := CurrentUserWithGetent()
|
||||
require.NoError(t, err, "currentUserWithGetent must resolve the running user")
|
||||
require.NotEmpty(t, current.Uid)
|
||||
require.NotEmpty(t, current.Username)
|
||||
|
||||
// Step 2: lookupWithGetent by the same username must return matching identity.
|
||||
byName, err := lookupWithGetent(current.Username)
|
||||
byName, err := LookupWithGetent(current.Username)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, current.Uid, byName.Uid, "lookup by name should return same UID")
|
||||
assert.Equal(t, current.Gid, byName.Gid, "lookup by name should return same GID")
|
||||
assert.Equal(t, current.HomeDir, byName.HomeDir, "lookup by name should return same home")
|
||||
|
||||
// Step 3: groupIdsWithFallback must return at least the primary GID.
|
||||
groups, err := groupIdsWithFallback(current)
|
||||
groups, err := GroupIdsWithFallback(current)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, groups, "user must have at least one group")
|
||||
|
||||
@@ -123,7 +123,7 @@ func TestIntegration_FullLookupChain(t *testing.T) {
|
||||
// Step 4: getShellFromGetent should either return a valid shell path or empty
|
||||
// (empty is OK when getent is not available, e.g. macOS without Homebrew getent).
|
||||
if runtime.GOOS != "windows" {
|
||||
shell := getShellFromGetent(current.Uid)
|
||||
shell := GetShellFromGetent(current.Uid)
|
||||
if shell != "" {
|
||||
assert.True(t, shell[0] == '/', "shell should be an absolute path, got %q", shell)
|
||||
}
|
||||
@@ -138,10 +138,10 @@ func TestIntegration_LookupAndGroupsConsistency(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Simulate the SSH server flow: lookup user, then get their groups.
|
||||
resolved, err := lookupWithGetent(current.Username)
|
||||
resolved, err := LookupWithGetent(current.Username)
|
||||
require.NoError(t, err)
|
||||
|
||||
groups, err := groupIdsWithFallback(resolved)
|
||||
groups, err := GroupIdsWithFallback(resolved)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, groups, "resolved user must have groups")
|
||||
|
||||
@@ -154,19 +154,3 @@ func TestIntegration_LookupAndGroupsConsistency(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntegration_ShellLookupChain tests the full shell resolution chain
|
||||
// (getShellFromPasswd -> getShellFromGetent -> $SHELL -> default) on Unix.
|
||||
func TestIntegration_ShellLookupChain(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Unix shell lookup not applicable on Windows")
|
||||
}
|
||||
|
||||
current, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
// getUserShell is the top-level function used by the SSH server.
|
||||
shell := getUserShell(current.Uid)
|
||||
require.NotEmpty(t, shell, "getUserShell must always return a shell")
|
||||
assert.True(t, shell[0] == '/', "shell should be an absolute path, got %q", shell)
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build !windows
|
||||
|
||||
package server
|
||||
package shell
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -14,19 +14,25 @@ import (
|
||||
|
||||
const getentTimeout = 5 * time.Second
|
||||
|
||||
// getShellFromGetent gets a user's login shell via getent by UID.
|
||||
// GetShellFromGetent gets a user's login shell via getent by UID.
|
||||
// This is needed even with CGO because getShellFromPasswd reads /etc/passwd
|
||||
// directly and won't find NSS-provided users there.
|
||||
func getShellFromGetent(userID string) string {
|
||||
_, shell, err := runGetent(userID)
|
||||
func GetShellFromGetent(userID string) string {
|
||||
_, shell, err := runGetentPasswd(userID)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return shell
|
||||
}
|
||||
|
||||
// runGetent executes `getent passwd <query>` and returns the user and login shell.
|
||||
func runGetent(query string) (*user.User, string, error) {
|
||||
// GetUserFromGetent returns the resolved group from either a uid or username
|
||||
func GetUserFromGetent(user string) (*user.User, error) {
|
||||
u, _, err := runGetentPasswd(user)
|
||||
return u, err
|
||||
}
|
||||
|
||||
// runGetentPasswd executes `getent passwd <query>` and returns the user and login shell.
|
||||
func runGetentPasswd(query string) (*user.User, string, error) {
|
||||
if !validateGetentInput(query) {
|
||||
return nil, "", fmt.Errorf("invalid getent input: %q", query)
|
||||
}
|
||||
@@ -42,6 +48,23 @@ func runGetent(query string) (*user.User, string, error) {
|
||||
return parseGetentPasswd(string(out))
|
||||
}
|
||||
|
||||
// runGetentGroup executes `getent group <query>` and returns the group
|
||||
func runGetentGroup(query string) (*user.Group, error) {
|
||||
if !validateGetentInput(query) {
|
||||
return nil, fmt.Errorf("invalid getent input: %q", query)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), getentTimeout)
|
||||
defer cancel()
|
||||
|
||||
out, err := exec.CommandContext(ctx, "getent", "group", query).Output()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getent passwd%s: %w", query, err)
|
||||
}
|
||||
|
||||
return parseGetentGroup(string(out))
|
||||
}
|
||||
|
||||
// parseGetentPasswd parses getent passwd output: "name:x:uid:gid:gecos:home:shell"
|
||||
func parseGetentPasswd(output string) (*user.User, string, error) {
|
||||
fields := strings.SplitN(strings.TrimSpace(output), ":", 8)
|
||||
@@ -67,6 +90,20 @@ func parseGetentPasswd(output string) (*user.User, string, error) {
|
||||
}, shell, nil
|
||||
}
|
||||
|
||||
// parseGetentGroup parses getent group output: "group:x:gid:user"
|
||||
func parseGetentGroup(output string) (*user.Group, error) {
|
||||
fields := strings.SplitN(strings.TrimSpace(output), ":", 8)
|
||||
if len(fields) < 4 {
|
||||
return nil, fmt.Errorf("unexpected getent output (need 4+ fields): %q", output)
|
||||
}
|
||||
|
||||
if fields[0] == "" || fields[2] == "" {
|
||||
return nil, fmt.Errorf("missing required fields in getent output: %q", output)
|
||||
}
|
||||
|
||||
return &user.Group{Gid: fields[2], Name: fields[0]}, nil
|
||||
}
|
||||
|
||||
// validateGetentInput checks that the input is safe to pass to getent or id.
|
||||
// Allows POSIX usernames, numeric UIDs, and common NSS extensions
|
||||
// (@ for Kerberos, $ for Samba, + for NIS compat).
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build !windows
|
||||
|
||||
package server
|
||||
package shell
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
@@ -195,7 +195,7 @@ func TestRunGetent_RootUser(t *testing.T) {
|
||||
t.Skip("getent not available on this system")
|
||||
}
|
||||
|
||||
u, shell, err := runGetent("root")
|
||||
u, shell, err := runGetentPasswd("root")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "root", u.Username)
|
||||
assert.Equal(t, "0", u.Uid)
|
||||
@@ -208,7 +208,7 @@ func TestRunGetent_ByUID(t *testing.T) {
|
||||
t.Skip("getent not available on this system")
|
||||
}
|
||||
|
||||
u, _, err := runGetent("0")
|
||||
u, _, err := runGetentPasswd("0")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "root", u.Username)
|
||||
assert.Equal(t, "0", u.Uid)
|
||||
@@ -219,15 +219,15 @@ func TestRunGetent_NonexistentUser(t *testing.T) {
|
||||
t.Skip("getent not available on this system")
|
||||
}
|
||||
|
||||
_, _, err := runGetent("nonexistent_user_xyzzy_12345")
|
||||
_, _, err := runGetentPasswd("nonexistent_user_xyzzy_12345")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestRunGetent_InvalidInput(t *testing.T) {
|
||||
_, _, err := runGetent("")
|
||||
_, _, err := runGetentPasswd("")
|
||||
assert.Error(t, err)
|
||||
|
||||
_, _, err = runGetent("user\x00name")
|
||||
_, _, err = runGetentPasswd("user\x00name")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
@@ -236,7 +236,7 @@ func TestRunGetent_NotAvailable(t *testing.T) {
|
||||
t.Skip("getent is available, can't test missing case")
|
||||
}
|
||||
|
||||
_, _, err := runGetent("root")
|
||||
_, _, err := runGetentPasswd("root")
|
||||
assert.Error(t, err, "should fail when getent is not installed")
|
||||
}
|
||||
|
||||
@@ -283,7 +283,7 @@ func TestGetentResultsMatchStdlib(t *testing.T) {
|
||||
current, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
getentUser, _, err := runGetent(current.Username)
|
||||
getentUser, _, err := runGetentPasswd(current.Username)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, current.Username, getentUser.Username, "username should match")
|
||||
@@ -300,7 +300,7 @@ func TestGetentResultsMatchStdlib_ByUID(t *testing.T) {
|
||||
current, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
getentUser, _, err := runGetent(current.Uid)
|
||||
getentUser, _, err := runGetentPasswd(current.Uid)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, current.Username, getentUser.Username, "username should match when looked up by UID")
|
||||
@@ -356,7 +356,7 @@ func TestGetShellFromPasswd_CurrentUser(t *testing.T) {
|
||||
assert.True(t, shell[0] == '/', "shell should be an absolute path, got %q", shell)
|
||||
|
||||
if _, err := exec.LookPath("getent"); err == nil {
|
||||
_, getentShell, getentErr := runGetent(current.Uid)
|
||||
_, getentShell, getentErr := runGetentPasswd(current.Uid)
|
||||
if getentErr == nil && getentShell != "" {
|
||||
assert.Equal(t, getentShell, shell, "shell from /etc/passwd should match getent")
|
||||
}
|
||||
@@ -400,7 +400,7 @@ func TestGetShellFromPasswd_MatchesGetentForKnownUsers(t *testing.T) {
|
||||
continue
|
||||
}
|
||||
|
||||
_, getentShell, err := runGetent(uid)
|
||||
_, getentShell, err := runGetentPasswd(uid)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
@@ -1,26 +1,26 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
package shell
|
||||
|
||||
import "os/user"
|
||||
|
||||
// lookupWithGetent on Windows just delegates to os/user.Lookup.
|
||||
// Windows does not use NSS/getent; its user lookup works without CGO.
|
||||
func lookupWithGetent(username string) (*user.User, error) {
|
||||
func LookupWithGetent(username string) (*user.User, error) {
|
||||
return user.Lookup(username)
|
||||
}
|
||||
|
||||
// currentUserWithGetent on Windows just delegates to os/user.Current.
|
||||
func currentUserWithGetent() (*user.User, error) {
|
||||
func CurrentUserWithGetent() (*user.User, error) {
|
||||
return user.Current()
|
||||
}
|
||||
|
||||
// getShellFromGetent is a no-op on Windows; shell resolution uses PowerShell detection.
|
||||
func getShellFromGetent(_ string) string {
|
||||
func GetShellFromGetent(_ string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// groupIdsWithFallback on Windows just delegates to u.GroupIds().
|
||||
func groupIdsWithFallback(u *user.User) ([]string, error) {
|
||||
func GroupIdsWithFallback(u *user.User) ([]string, error) {
|
||||
return u.GroupIds()
|
||||
}
|
||||
@@ -1,17 +1,14 @@
|
||||
package server
|
||||
package shell
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@@ -24,7 +21,7 @@ const (
|
||||
|
||||
// getUserShell returns the appropriate shell for the given user ID
|
||||
// Handles all platform-specific logic and fallbacks consistently
|
||||
func getUserShell(userID string) string {
|
||||
func GetUserShell(userID string) string {
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
return getWindowsUserShell()
|
||||
@@ -56,7 +53,7 @@ func getUnixUserShell(userID string) string {
|
||||
return shell
|
||||
}
|
||||
|
||||
if shell := getShellFromGetent(userID); shell != "" {
|
||||
if shell := GetShellFromGetent(userID); shell != "" {
|
||||
return shell
|
||||
}
|
||||
|
||||
@@ -101,8 +98,8 @@ func getShellFromPasswd(userID string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// prepareUserEnv prepares environment variables for user execution
|
||||
func prepareUserEnv(user *user.User, shell string) []string {
|
||||
// PrepareUserEnv prepares environment variables for user execution
|
||||
func PrepareUserEnv(user *user.User, shell string) []string {
|
||||
pathValue := "/usr/local/bin:/usr/bin:/bin:/usr/local/games:/usr/games"
|
||||
if runtime.GOOS == "windows" {
|
||||
pathValue = `C:\Windows\System32;C:\Windows;C:\Windows\System32\Wbem;C:\Windows\System32\WindowsPowerShell\v1.0`
|
||||
@@ -119,7 +116,7 @@ func prepareUserEnv(user *user.User, shell string) []string {
|
||||
|
||||
// acceptEnv checks if environment variable from SSH client should be accepted
|
||||
// This is a whitelist of variables that SSH clients can send to the server
|
||||
func acceptEnv(envVar string) bool {
|
||||
func AcceptEnv(envVar string) bool {
|
||||
varName := envVar
|
||||
if idx := strings.Index(envVar, "="); idx != -1 {
|
||||
varName = envVar[:idx]
|
||||
@@ -156,29 +153,3 @@ func acceptEnv(envVar string) bool {
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// prepareSSHEnv prepares SSH protocol-specific environment variables
|
||||
// These variables provide information about the SSH connection itself
|
||||
func prepareSSHEnv(session ssh.Session) []string {
|
||||
remoteAddr := session.RemoteAddr()
|
||||
localAddr := session.LocalAddr()
|
||||
|
||||
remoteHost, remotePort, err := net.SplitHostPort(remoteAddr.String())
|
||||
if err != nil {
|
||||
remoteHost = remoteAddr.String()
|
||||
remotePort = "0"
|
||||
}
|
||||
|
||||
localHost, localPort, err := net.SplitHostPort(localAddr.String())
|
||||
if err != nil {
|
||||
localHost = localAddr.String()
|
||||
localPort = strconv.Itoa(InternalSSHPort)
|
||||
}
|
||||
|
||||
return []string{
|
||||
// SSH_CLIENT format: "client_ip client_port server_port"
|
||||
fmt.Sprintf("SSH_CLIENT=%s %s %s", remoteHost, remotePort, localPort),
|
||||
// SSH_CONNECTION format: "client_ip client_port server_ip server_port"
|
||||
fmt.Sprintf("SSH_CONNECTION=%s %s %s %s", remoteHost, remotePort, localHost, localPort),
|
||||
}
|
||||
}
|
||||
26
client/internal/shell/shell_test.go
Normal file
26
client/internal/shell/shell_test.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package shell
|
||||
|
||||
import (
|
||||
"os/user"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestIntegration_ShellLookupChain tests the full shell resolution chain
|
||||
// (getShellFromPasswd -> getShellFromGetent -> $SHELL -> default) on Unix.
|
||||
func TestIntegration_ShellLookupChain(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Unix shell lookup not applicable on Windows")
|
||||
}
|
||||
|
||||
current, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
// getUserShell is the top-level function used by the SSH server.
|
||||
shell := GetUserShell(current.Uid)
|
||||
require.NotEmpty(t, shell, "getUserShell must always return a shell")
|
||||
assert.True(t, shell[0] == '/', "shell should be an absolute path, got %q", shell)
|
||||
}
|
||||
@@ -1,50 +0,0 @@
|
||||
//go:build windows || darwin
|
||||
|
||||
package mdm
|
||||
|
||||
import "strings"
|
||||
|
||||
// allKeys is the set of recognised MDM keys. Unknown keys in a managed
|
||||
// configuration are ignored but logged. Lives in this build-tagged file
|
||||
// (windows || darwin) because only desktop loaders need the
|
||||
// canonicalisation table that consumes it; including it unconditionally
|
||||
// would trigger the `unused` golangci-lint check on platforms that
|
||||
// don't import canonical_loaders.go.
|
||||
var allKeys = []string{
|
||||
KeyManagementURL,
|
||||
KeyDisableUpdateSettings,
|
||||
KeyDisableProfiles,
|
||||
KeyDisableNetworks,
|
||||
KeyDisableClientRoutes,
|
||||
KeyDisableServerRoutes,
|
||||
KeyBlockInbound,
|
||||
KeyDisableMetricsCollection,
|
||||
KeyAllowServerSSH,
|
||||
KeyDisableAutoConnect,
|
||||
KeyPreSharedKey,
|
||||
KeyRosenpassEnabled,
|
||||
KeyRosenpassPermissive,
|
||||
KeyWireguardPort,
|
||||
KeySplitTunnelMode,
|
||||
KeySplitTunnelApps,
|
||||
}
|
||||
|
||||
// canonicalKey maps the lowercase form of a managed-config value name to
|
||||
// its canonical mdm.Key* form. Admins commonly write PascalCase value
|
||||
// names in ADMX / Group Policy ("ManagementURL"); the iOS/AppConfig and
|
||||
// macOS plist conventions are camelCase ("managementURL"); both must
|
||||
// resolve to the same Policy lookup.
|
||||
//
|
||||
// Lives in a desktop-loader-only file (build tag `windows || darwin`)
|
||||
// because no other build path consumes it. Linux / FreeBSD / mobile
|
||||
// builds don't ship a platform loader that reads arbitrary-case key
|
||||
// names, so they don't need the canonicalisation table — and including
|
||||
// the var unconditionally would trigger the `unused` golangci-lint
|
||||
// check on those platforms.
|
||||
var canonicalKey = func() map[string]string {
|
||||
m := make(map[string]string, len(allKeys))
|
||||
for _, k := range allKeys {
|
||||
m[strings.ToLower(k)] = k
|
||||
}
|
||||
return m
|
||||
}()
|
||||
@@ -1,247 +0,0 @@
|
||||
// Package mdm reads MDM-managed configuration from platform-native sources
|
||||
// (plist on macOS, registry on Windows, UserDefaults on iOS,
|
||||
// RestrictionsManager on Android). The returned Policy is consumed by
|
||||
// profilemanager.Config.apply() as the highest-priority override layer.
|
||||
//
|
||||
// An empty Policy (no source present, or source present with zero keys)
|
||||
// means no MDM enforcement is active and the client behaves as if the
|
||||
// feature did not exist.
|
||||
package mdm
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strconv"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Well-known policy keys. Names mirror the corresponding ConfigInput Go field
|
||||
// names (lowerCamelCase) so the daemon can map a Policy key directly to a
|
||||
// configuration field.
|
||||
const (
|
||||
KeyManagementURL = "managementURL"
|
||||
KeyDisableUpdateSettings = "disableUpdateSettings"
|
||||
KeyDisableProfiles = "disableProfiles"
|
||||
KeyDisableNetworks = "disableNetworks"
|
||||
KeyDisableClientRoutes = "disableClientRoutes"
|
||||
KeyDisableServerRoutes = "disableServerRoutes"
|
||||
KeyBlockInbound = "blockInbound"
|
||||
KeyDisableMetricsCollection = "disableMetricsCollection"
|
||||
KeyAllowServerSSH = "allowServerSSH"
|
||||
KeyDisableAutoConnect = "disableAutoConnect"
|
||||
KeyPreSharedKey = "preSharedKey"
|
||||
KeyRosenpassEnabled = "rosenpassEnabled"
|
||||
KeyRosenpassPermissive = "rosenpassPermissive"
|
||||
KeyWireguardPort = "wireguardPort"
|
||||
|
||||
// Split tunnel is modeled as a single conceptual policy with two
|
||||
// registry/plist values. KeySplitTunnelMode is the discriminator
|
||||
// ("allow" or "disallow"); KeySplitTunnelApps is a comma-separated
|
||||
// list of package names. The values are mutually exclusive by
|
||||
// construction — only one mode can be set at a time.
|
||||
KeySplitTunnelMode = "splitTunnelMode"
|
||||
KeySplitTunnelApps = "splitTunnelApps"
|
||||
)
|
||||
|
||||
// Split-tunnel mode literals (KeySplitTunnelMode values).
|
||||
const (
|
||||
SplitTunnelModeAllow = "allow"
|
||||
SplitTunnelModeDisallow = "disallow"
|
||||
)
|
||||
|
||||
// SecretKeys lists keys whose values must be redacted in logs.
|
||||
var SecretKeys = map[string]struct{}{
|
||||
KeyPreSharedKey: {},
|
||||
}
|
||||
|
||||
// boolStringLiterals enumerates the textual boolean encodings the
|
||||
// platform loaders may produce (Windows REG_SZ "true", iOS / Android
|
||||
// managed-config booleans-as-strings, etc.). Lookup keeps GetBool flat
|
||||
// (no nested switch on the string case).
|
||||
var boolStringLiterals = map[string]bool{
|
||||
"true": true,
|
||||
"1": true,
|
||||
"yes": true,
|
||||
"false": false,
|
||||
"0": false,
|
||||
"no": false,
|
||||
}
|
||||
|
||||
|
||||
// Policy holds MDM-managed settings read from the platform source. A nil or
|
||||
// empty Policy means no enforcement is active.
|
||||
type Policy struct {
|
||||
values map[string]any
|
||||
}
|
||||
|
||||
// NewPolicy constructs a Policy from a key→value map. Pass nil or an
|
||||
// empty map to construct an empty (no-enforcement) Policy. The returned
|
||||
// *Policy is always non-nil.
|
||||
func NewPolicy(values map[string]any) *Policy {
|
||||
if values == nil {
|
||||
values = map[string]any{}
|
||||
}
|
||||
return &Policy{values: values}
|
||||
}
|
||||
|
||||
// LoadPolicy reads the platform-native MDM configuration. Returns an
|
||||
// empty (but non-nil) Policy when no source is present, the source is
|
||||
// empty, or the platform is unsupported.
|
||||
//
|
||||
// Diagnostic logging differentiates the three states:
|
||||
// - source absent / unsupported platform: trace log only
|
||||
// - source present, zero keys: info "MDM enrolled (no managed keys)"
|
||||
// - source present, N keys: info "MDM enrolled with N managed keys: [...]"
|
||||
func LoadPolicy() *Policy {
|
||||
values, err := loadPlatformPolicy()
|
||||
if err != nil {
|
||||
log.Tracef("MDM policy load: %v", err)
|
||||
return &Policy{values: map[string]any{}}
|
||||
}
|
||||
if values == nil {
|
||||
return &Policy{values: map[string]any{}}
|
||||
}
|
||||
if len(values) == 0 {
|
||||
log.Info("MDM enrolled (no managed keys)")
|
||||
} else {
|
||||
log.Infof("MDM enrolled with %d managed key(s): %v", len(values), sortedKeys(values))
|
||||
}
|
||||
return &Policy{values: values}
|
||||
}
|
||||
|
||||
// IsEmpty reports whether the Policy has no managed keys.
|
||||
func (p *Policy) IsEmpty() bool {
|
||||
return p == nil || len(p.values) == 0
|
||||
}
|
||||
|
||||
// HasKey reports whether the given key is MDM-managed.
|
||||
func (p *Policy) HasKey(key string) bool {
|
||||
if p == nil {
|
||||
return false
|
||||
}
|
||||
_, ok := p.values[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ManagedKeys returns the sorted list of managed key names. Returns an empty
|
||||
// slice (not nil) on an empty Policy.
|
||||
func (p *Policy) ManagedKeys() []string {
|
||||
if p == nil {
|
||||
return []string{}
|
||||
}
|
||||
return sortedKeys(p.values)
|
||||
}
|
||||
|
||||
// GetString returns the managed value for key coerced to string, and whether
|
||||
// the key was set. A non-string value returns ("", false).
|
||||
func (p *Policy) GetString(key string) (string, bool) {
|
||||
if p == nil {
|
||||
return "", false
|
||||
}
|
||||
v, ok := p.values[key]
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
s, ok := v.(string)
|
||||
if !ok || s == "" {
|
||||
return "", false
|
||||
}
|
||||
return s, true
|
||||
}
|
||||
|
||||
// GetBool returns the managed value for key coerced to bool, and whether the
|
||||
// key was set. Accepts native bool and string literals "true"/"false"/"1"/"0".
|
||||
func (p *Policy) GetBool(key string) (bool, bool) {
|
||||
if p == nil {
|
||||
return false, false
|
||||
}
|
||||
v, ok := p.values[key]
|
||||
if !ok {
|
||||
return false, false
|
||||
}
|
||||
switch t := v.(type) {
|
||||
case bool:
|
||||
return t, true
|
||||
case string:
|
||||
b, known := boolStringLiterals[t]
|
||||
return b, known
|
||||
case int:
|
||||
return t != 0, true
|
||||
case int64:
|
||||
return t != 0, true
|
||||
}
|
||||
return false, false
|
||||
}
|
||||
|
||||
// GetInt returns the managed value for key as int64, and whether the key
|
||||
// was set. Accepts native int / int64 (as produced by the Windows registry
|
||||
// loader for REG_DWORD/REG_QWORD) and numeric strings (decimal).
|
||||
func (p *Policy) GetInt(key string) (int64, bool) {
|
||||
if p == nil {
|
||||
return 0, false
|
||||
}
|
||||
v, ok := p.values[key]
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
switch t := v.(type) {
|
||||
case int64:
|
||||
return t, true
|
||||
case int:
|
||||
return int64(t), true
|
||||
case int32:
|
||||
return int64(t), true
|
||||
case uint64:
|
||||
return int64(t), true
|
||||
case float64:
|
||||
return int64(t), true
|
||||
case string:
|
||||
if n, err := strconv.ParseInt(t, 10, 64); err == nil {
|
||||
return n, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetStringSlice returns the managed value for key as []string, and whether
|
||||
// the key was set. Accepts []string, []any (of strings), and a single string
|
||||
// (treated as a one-element list).
|
||||
func (p *Policy) GetStringSlice(key string) ([]string, bool) {
|
||||
if p == nil {
|
||||
return nil, false
|
||||
}
|
||||
v, ok := p.values[key]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
switch t := v.(type) {
|
||||
case []string:
|
||||
return append([]string(nil), t...), true
|
||||
case []any:
|
||||
out := make([]string, 0, len(t))
|
||||
for _, item := range t {
|
||||
s, ok := item.(string)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
out = append(out, s)
|
||||
}
|
||||
return out, true
|
||||
case string:
|
||||
return []string{t}, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// sortedKeys returns the keys of m as a deterministic, lexicographically
|
||||
// sorted slice. Used internally by Policy.ManagedKeys and LoadPolicy's
|
||||
// diagnostic log line so callers see a stable key order across runs
|
||||
// regardless of Go's randomised map iteration.
|
||||
func sortedKeys(m map[string]any) []string {
|
||||
out := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
out = append(out, k)
|
||||
}
|
||||
sort.Strings(out)
|
||||
return out
|
||||
}
|
||||
@@ -1,90 +0,0 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package mdm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"howett.net/plist"
|
||||
)
|
||||
|
||||
// policyPlistPath is the well-known location where macOS writes the
|
||||
// device-level mandatory MDM payload for NetBird. The path is fixed by
|
||||
// Apple convention: when an MDM provider (Jamf / Kandji / Mosyle /
|
||||
// Intune for Mac / Workspace ONE) pushes a Configuration Profile that
|
||||
// contains a com.apple.ManagedClient.preferences payload targeting the
|
||||
// bundle id io.netbird.client, the OS materializes the payload here.
|
||||
//
|
||||
// Read-only — only the OS (root) is supposed to write this file. The
|
||||
// loader sanity-checks the file mode and refuses to honour a world-
|
||||
// writable plist, as a defense against tampered installs.
|
||||
const policyPlistPath = "/Library/Managed Preferences/io.netbird.client.plist"
|
||||
|
||||
// loadPlatformPolicy reads the MDM-managed configuration from the macOS
|
||||
// managed-preferences plist at policyPlistPath. Returns:
|
||||
// - (nil, nil) when the plist is absent (device not MDM-enrolled for
|
||||
// NetBird, or admin has not yet pushed a payload)
|
||||
// - (map, nil) with N entries when N managed values are present
|
||||
// (N may be 0 — empty plist still signals enrollment to the caller)
|
||||
// - (nil, err) on permission / parse / safety errors (including
|
||||
// refusal to read a world-writable plist)
|
||||
//
|
||||
// Top-level plist keys are canonicalised case-insensitively to the
|
||||
// package's internal mdm.Key* names; unknown keys are logged and
|
||||
// skipped so a stray entry in the payload does not block startup.
|
||||
// Native plist value types map naturally onto the Policy accessor
|
||||
// expectations (GetString / GetBool / GetInt / GetStringSlice).
|
||||
func loadPlatformPolicy() (map[string]any, error) {
|
||||
f, err := os.Open(policyPlistPath)
|
||||
if err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
// Not enrolled for NetBird. Caller treats nil as
|
||||
// "no MDM source present".
|
||||
//nolint:nilnil // (nil, nil) is the documented platform-absent sentinel; see LoadPolicy.
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("open %s: %w", policyPlistPath, err)
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := f.Close(); closeErr != nil {
|
||||
log.Warnf("MDM close plist %s: %v", policyPlistPath, closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
info, err := f.Stat()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("stat %s: %w", policyPlistPath, err)
|
||||
}
|
||||
// World-writable plist => tampered install. Refuse rather than
|
||||
// honour potentially attacker-controlled policy values.
|
||||
if info.Mode().Perm()&0o002 != 0 {
|
||||
return nil, fmt.Errorf("refusing to read world-writable MDM source %s (mode %o)",
|
||||
policyPlistPath, info.Mode().Perm())
|
||||
}
|
||||
|
||||
raw := make(map[string]any)
|
||||
if err := plist.NewDecoder(f).Decode(&raw); err != nil {
|
||||
return nil, fmt.Errorf("decode plist %s: %w", policyPlistPath, err)
|
||||
}
|
||||
|
||||
out := make(map[string]any, len(raw))
|
||||
for name, val := range raw {
|
||||
// macOS / AppConfig conventions both use camelCase for managed
|
||||
// preferences keys; canonicalize to the mdm.Key* form so a key
|
||||
// written as "ManagementURL" (PascalCase, rare on macOS but
|
||||
// possible if the admin reused an ADMX-style name) still
|
||||
// resolves.
|
||||
canonical, known := canonicalKey[strings.ToLower(name)]
|
||||
if !known {
|
||||
log.Warnf("MDM ignoring unknown plist key %s: %s", policyPlistPath, name)
|
||||
continue
|
||||
}
|
||||
out[canonical] = val
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
//go:build ios || android
|
||||
|
||||
package mdm
|
||||
|
||||
// loadPlatformPolicy is unused on mobile: the native layer (Swift on iOS,
|
||||
// Kotlin/Java on Android) reads the OS managed-config store and pushes the
|
||||
// resulting dictionary in-process via a gomobile entry point that lands in
|
||||
// Phase 5 / Phase 6. The stub keeps the package compilable for mobile
|
||||
// builds and returns (nil, nil) — the platform-absent sentinel that
|
||||
// LoadPolicy in policy.go treats as "no MDM source present".
|
||||
func loadPlatformPolicy() (map[string]any, error) {
|
||||
//nolint:nilnil // (nil, nil) is the documented platform-absent sentinel; see LoadPolicy.
|
||||
return nil, nil
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
//go:build !windows && !darwin && !ios && !android
|
||||
|
||||
package mdm
|
||||
|
||||
// loadPlatformPolicy returns no policy on platforms without an MDM channel
|
||||
// (Linux, FreeBSD). MDM enforcement is off and the client behaves as if
|
||||
// the feature did not exist. Returns (nil, nil) — the platform-absent
|
||||
// sentinel the caller (LoadPolicy in policy.go) treats as "no MDM
|
||||
// source present"; an error here would just translate to the same
|
||||
// outcome with an extra log line.
|
||||
func loadPlatformPolicy() (map[string]any, error) {
|
||||
//nolint:nilnil // (nil, nil) is the documented platform-absent sentinel; see LoadPolicy.
|
||||
return nil, nil
|
||||
}
|
||||
@@ -1,160 +0,0 @@
|
||||
package mdm
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPolicy_NilSafe(t *testing.T) {
|
||||
var p *Policy
|
||||
assert.True(t, p.IsEmpty())
|
||||
assert.False(t, p.HasKey(KeyManagementURL))
|
||||
assert.Empty(t, p.ManagedKeys())
|
||||
|
||||
_, ok := p.GetString(KeyManagementURL)
|
||||
assert.False(t, ok)
|
||||
_, ok = p.GetBool(KeyDisableProfiles)
|
||||
assert.False(t, ok)
|
||||
_, ok = p.GetStringSlice(KeySplitTunnelApps)
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestPolicy_Empty(t *testing.T) {
|
||||
p := NewPolicy(nil)
|
||||
require.NotNil(t, p)
|
||||
assert.True(t, p.IsEmpty())
|
||||
assert.False(t, p.HasKey(KeyManagementURL))
|
||||
assert.Empty(t, p.ManagedKeys())
|
||||
}
|
||||
|
||||
func TestPolicy_HasKey(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{
|
||||
KeyManagementURL: "https://corp.example.com",
|
||||
KeyDisableProfiles: true,
|
||||
})
|
||||
assert.False(t, p.IsEmpty())
|
||||
assert.True(t, p.HasKey(KeyManagementURL))
|
||||
assert.True(t, p.HasKey(KeyDisableProfiles))
|
||||
assert.False(t, p.HasKey(KeyPreSharedKey))
|
||||
}
|
||||
|
||||
func TestPolicy_ManagedKeysSorted(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{
|
||||
KeyDisableProfiles: true,
|
||||
KeyManagementURL: "https://x",
|
||||
KeyAllowServerSSH: false,
|
||||
})
|
||||
got := p.ManagedKeys()
|
||||
assert.Equal(t, []string{KeyAllowServerSSH, KeyDisableProfiles, KeyManagementURL}, got)
|
||||
}
|
||||
|
||||
func TestPolicy_GetString(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{
|
||||
KeyManagementURL: "https://corp.example.com",
|
||||
KeyDisableProfiles: true, // wrong type for GetString
|
||||
KeyPreSharedKey: "", // empty rejected
|
||||
})
|
||||
v, ok := p.GetString(KeyManagementURL)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "https://corp.example.com", v)
|
||||
|
||||
_, ok = p.GetString(KeyDisableProfiles)
|
||||
assert.False(t, ok, "non-string value must not be reported as string")
|
||||
|
||||
_, ok = p.GetString(KeyPreSharedKey)
|
||||
assert.False(t, ok, "empty string treated as unset")
|
||||
|
||||
_, ok = p.GetString("nonexistent")
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestPolicy_GetBool(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
raw any
|
||||
want bool
|
||||
ok bool
|
||||
}{
|
||||
{"native true", true, true, true},
|
||||
{"native false", false, false, true},
|
||||
{"string true", "true", true, true},
|
||||
{"string false", "false", false, true},
|
||||
{"string 1", "1", true, true},
|
||||
{"string 0", "0", false, true},
|
||||
{"string yes", "yes", true, true},
|
||||
{"string no", "no", false, true},
|
||||
{"int nonzero", 1, true, true},
|
||||
{"int zero", 0, false, true},
|
||||
{"int64 nonzero", int64(2), true, true},
|
||||
{"int64 zero", int64(0), false, true},
|
||||
{"string garbage", "maybe", false, false},
|
||||
{"float unsupported", 1.0, false, false},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{KeyDisableProfiles: c.raw})
|
||||
got, ok := p.GetBool(KeyDisableProfiles)
|
||||
assert.Equal(t, c.ok, ok)
|
||||
if c.ok {
|
||||
assert.Equal(t, c.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
_, ok := NewPolicy(nil).GetBool(KeyDisableProfiles)
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestPolicy_GetStringSlice(t *testing.T) {
|
||||
t.Run("native string slice", func(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{
|
||||
KeySplitTunnelApps: []string{"com.a", "com.b"},
|
||||
})
|
||||
got, ok := p.GetStringSlice(KeySplitTunnelApps)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, []string{"com.a", "com.b"}, got)
|
||||
})
|
||||
|
||||
t.Run("any slice of strings", func(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{
|
||||
KeySplitTunnelApps: []any{"com.a", "com.b"},
|
||||
})
|
||||
got, ok := p.GetStringSlice(KeySplitTunnelApps)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, []string{"com.a", "com.b"}, got)
|
||||
})
|
||||
|
||||
t.Run("single string lifts to one-element slice", func(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{
|
||||
KeySplitTunnelApps: "com.a",
|
||||
})
|
||||
got, ok := p.GetStringSlice(KeySplitTunnelApps)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, []string{"com.a"}, got)
|
||||
})
|
||||
|
||||
t.Run("mixed any slice rejected", func(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{
|
||||
KeySplitTunnelApps: []any{"com.a", 1},
|
||||
})
|
||||
_, ok := p.GetStringSlice(KeySplitTunnelApps)
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("missing key", func(t *testing.T) {
|
||||
p := NewPolicy(nil)
|
||||
_, ok := p.GetStringSlice(KeySplitTunnelApps)
|
||||
assert.False(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoadPolicy_PlatformStubReturnsEmpty(t *testing.T) {
|
||||
// loadPlatformPolicy is a stub on every OS for Phase 1. LoadPolicy must
|
||||
// degrade gracefully and never return nil.
|
||||
p := LoadPolicy()
|
||||
require.NotNil(t, p)
|
||||
assert.True(t, p.IsEmpty())
|
||||
assert.Empty(t, p.ManagedKeys())
|
||||
}
|
||||
@@ -1,108 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package mdm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
)
|
||||
|
||||
// policyRegistryPath is the well-known MDM policy registry key for NetBird.
|
||||
// Admins push values here through Group Policy, Intune ADMX ingestion, an
|
||||
// Intune custom Registry CSP profile, or `reg add` during MSI deployment.
|
||||
// Listed in the project's docs/mdm/netbird.admx schema.
|
||||
const policyRegistryPath = `Software\Policies\NetBird`
|
||||
|
||||
// readRegistryValue reads a single value under policyRegistryPath and,
|
||||
// on success, stores the type-coerced result in out[canonical]. Type
|
||||
// coercion mirrors loadPlatformPolicy's documented mapping:
|
||||
// - REG_SZ / REG_EXPAND_SZ -> string (REG_EXPAND_SZ is expanded by the API)
|
||||
// - REG_DWORD / REG_QWORD -> int64
|
||||
// - REG_MULTI_SZ -> []string
|
||||
//
|
||||
// Unsupported value types and per-value read failures are logged at
|
||||
// warn level and skipped — one malformed value must not block the
|
||||
// surrounding loop. Extracted from loadPlatformPolicy to keep that
|
||||
// function's cognitive complexity in check.
|
||||
func readRegistryValue(k registry.Key, name, canonical string, out map[string]any) {
|
||||
_, valType, err := k.GetValue(name, nil)
|
||||
if err != nil {
|
||||
log.Warnf("MDM stat %s\\%s: %v", policyRegistryPath, name, err)
|
||||
return
|
||||
}
|
||||
switch valType {
|
||||
case registry.SZ, registry.EXPAND_SZ:
|
||||
if v, _, err := k.GetStringValue(name); err == nil {
|
||||
out[canonical] = v
|
||||
} else {
|
||||
log.Warnf("MDM read string %s\\%s: %v", policyRegistryPath, name, err)
|
||||
}
|
||||
case registry.DWORD, registry.QWORD:
|
||||
if v, _, err := k.GetIntegerValue(name); err == nil {
|
||||
// uint64 from the registry API; Policy.GetBool / GetInt
|
||||
// helpers consume int64, so narrow safely.
|
||||
out[canonical] = int64(v)
|
||||
} else {
|
||||
log.Warnf("MDM read int %s\\%s: %v", policyRegistryPath, name, err)
|
||||
}
|
||||
case registry.MULTI_SZ:
|
||||
if v, _, err := k.GetStringsValue(name); err == nil {
|
||||
out[canonical] = v
|
||||
} else {
|
||||
log.Warnf("MDM read multi-string %s\\%s: %v", policyRegistryPath, name, err)
|
||||
}
|
||||
default:
|
||||
log.Warnf("MDM ignoring unsupported registry value type %d at %s\\%s",
|
||||
valType, policyRegistryPath, name)
|
||||
}
|
||||
}
|
||||
|
||||
// loadPlatformPolicy reads the MDM-managed configuration from the
|
||||
// Windows registry under HKLM\Software\Policies\NetBird. Returns:
|
||||
// - (nil, nil) when the key is absent (device not MDM-enrolled for NetBird)
|
||||
// - (map, nil) with N entries when N managed values are set (N may be 0)
|
||||
// - (nil, err) on open / enumerate registry errors
|
||||
//
|
||||
// Per-value type coercion + skip-on-error is delegated to
|
||||
// readRegistryValue. Unknown value names are logged and skipped so a
|
||||
// malformed deployment does not block startup.
|
||||
func loadPlatformPolicy() (map[string]any, error) {
|
||||
k, err := registry.OpenKey(registry.LOCAL_MACHINE, policyRegistryPath, registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
if errors.Is(err, registry.ErrNotExist) {
|
||||
// Not enrolled. Caller treats nil as "no MDM source present".
|
||||
//nolint:nilnil // (nil, nil) is the documented platform-absent sentinel; see LoadPolicy.
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("open %s: %w", policyRegistryPath, err)
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := k.Close(); closeErr != nil {
|
||||
log.Warnf("MDM close registry key %s: %v", policyRegistryPath, closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
names, err := k.ReadValueNames(-1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("enumerate values of %s: %w", policyRegistryPath, err)
|
||||
}
|
||||
|
||||
out := make(map[string]any, len(names))
|
||||
for _, name := range names {
|
||||
// Canonicalize the registry value name against the known MDM key
|
||||
// set so Policy.HasKey lookups (which use the canonical names)
|
||||
// succeed regardless of the casing used by the admin's ADMX or
|
||||
// `reg add` command.
|
||||
canonical, known := canonicalKey[strings.ToLower(name)]
|
||||
if !known {
|
||||
log.Warnf("MDM ignoring unknown registry value %s\\%s", policyRegistryPath, name)
|
||||
continue
|
||||
}
|
||||
readRegistryValue(k, name, canonical, out)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
@@ -1,129 +0,0 @@
|
||||
package mdm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// DefaultReloadInterval is the production cadence at which the desktop daemon
|
||||
// re-reads the OS-native MDM policy. Picked to balance responsiveness against
|
||||
// registry/plist I/O overhead. Mobile builds use OS-side notifications
|
||||
// instead, hence anticipating the ticker mechanism entirely.
|
||||
const DefaultReloadInterval = 1 * time.Minute
|
||||
|
||||
// policyLoader is the indirection through which the ticker reads the
|
||||
// OS-native policy, both for the initial observation and on every tick.
|
||||
// Production points it at LoadPolicy; tests in this package override it to
|
||||
// feed a scripted sequence of policies without touching the real OS store.
|
||||
var policyLoader = LoadPolicy
|
||||
|
||||
// Ticker periodically re-reads the OS-native MDM policy via LoadPolicy and
|
||||
// invokes the onChange callback (supplied to Run) whenever the observed
|
||||
// Policy diverges from the last observation (added / removed / changed
|
||||
// keys). Launch with Run from a goroutine; cancel the supplied context
|
||||
// to stop.
|
||||
type Ticker struct {
|
||||
interval time.Duration
|
||||
prev *Policy
|
||||
}
|
||||
|
||||
// NewTicker constructs a Ticker that will re-read the OS-native policy
|
||||
// every reloadInterval once Run is called.
|
||||
// The initial snapshot is populated by calling policyLoader at
|
||||
// construction time so the first tick only fires
|
||||
// onChange when the policy actually changed since boot — without
|
||||
// this baseline the first tick would report every currently-managed
|
||||
// key as "added" and trigger a spurious engine restart.
|
||||
func NewTicker(reloadInterval time.Duration) *Ticker {
|
||||
return &Ticker{
|
||||
interval: reloadInterval,
|
||||
prev: policyLoader(),
|
||||
}
|
||||
}
|
||||
|
||||
// Run blocks until ctx is cancelled, polling the OS-native policy store at
|
||||
// the configured cadence and emitting log lines + onChange callback on
|
||||
// every observed diff. onChange must be non-nil.
|
||||
func (t *Ticker) Run(ctx context.Context, onChange func(prev, curr *Policy) error) {
|
||||
tk := time.NewTicker(t.interval)
|
||||
defer tk.Stop()
|
||||
log.Infof("MDM policy reload ticker started (interval=%s)", t.interval)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Info("MDM policy reload ticker stopped")
|
||||
return
|
||||
case <-tk.C:
|
||||
curr := policyLoader()
|
||||
if policiesEqual(t.prev, curr) {
|
||||
continue
|
||||
}
|
||||
added, removed, changed := diffPolicies(t.prev, curr)
|
||||
log.Infof("MDM policy changed: added=%v removed=%v changed=%v",
|
||||
added, removed, changed)
|
||||
prev := t.prev
|
||||
if err := onChange(prev, curr); err != nil {
|
||||
log.Errorf("MDM policy change handler failed (retrying in 1 minute): %v", err)
|
||||
continue
|
||||
}
|
||||
t.prev = curr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// policiesEqual reports whether two Policy instances carry the same
|
||||
// managed key set with identical values. Nil and empty policies
|
||||
// compare equal; one-nil/one-non-empty compare not equal; otherwise
|
||||
// the underlying values maps are compared with reflect.DeepEqual.
|
||||
func policiesEqual(a, b *Policy) bool {
|
||||
if a.IsEmpty() && b.IsEmpty() {
|
||||
return true
|
||||
}
|
||||
if a == nil || b == nil {
|
||||
return false
|
||||
}
|
||||
return reflect.DeepEqual(a.values, b.values)
|
||||
}
|
||||
|
||||
// diffPolicies returns the keys added in curr, removed from prev, and
|
||||
// whose values changed between prev and curr. Each slice is sorted
|
||||
// lexicographically for stable log output; value differences are
|
||||
// determined with reflect.DeepEqual.
|
||||
func diffPolicies(prev, curr *Policy) (added, removed, changed []string) {
|
||||
prevKVs := mapOf(prev)
|
||||
currKVs := mapOf(curr)
|
||||
for k := range currKVs {
|
||||
if _, ok := prevKVs[k]; !ok {
|
||||
added = append(added, k)
|
||||
} else if !reflect.DeepEqual(prevKVs[k], currKVs[k]) {
|
||||
changed = append(changed, k)
|
||||
}
|
||||
}
|
||||
for k := range prevKVs {
|
||||
if _, ok := currKVs[k]; !ok {
|
||||
removed = append(removed, k)
|
||||
}
|
||||
}
|
||||
sort.Strings(added)
|
||||
sort.Strings(removed)
|
||||
sort.Strings(changed)
|
||||
return added, removed, changed
|
||||
}
|
||||
|
||||
// mapOf returns a (possibly empty, never nil) copy of the underlying
|
||||
// values map of a Policy so callers outside this package can compare
|
||||
// keys/values across the type boundary. Returns an empty map on nil p.
|
||||
func mapOf(p *Policy) map[string]any {
|
||||
if p == nil {
|
||||
return map[string]any{}
|
||||
}
|
||||
out := make(map[string]any, len(p.values))
|
||||
for k, v := range p.values {
|
||||
out[k] = v
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -1,100 +0,0 @@
|
||||
package mdm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// testReloadInterval for speeding up the ticker cadence under `go test`
|
||||
const testReloadInterval = 1 * time.Second
|
||||
|
||||
// withPolicyLoader overrides the package-level policyLoader for the duration
|
||||
// of the test so the ticker observes a scripted policy instead of the real
|
||||
// OS-native store. The original loader is restored on cleanup.
|
||||
func withPolicyLoader(t *testing.T, fn func() *Policy) {
|
||||
t.Helper()
|
||||
prev := policyLoader
|
||||
policyLoader = fn
|
||||
t.Cleanup(func() { policyLoader = prev })
|
||||
}
|
||||
|
||||
func TestTicker_FiresOnChangeWithDelta(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
current := NewPolicy(nil) // initial observation: empty (no enforcement)
|
||||
withPolicyLoader(t, func() *Policy {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return current
|
||||
})
|
||||
|
||||
type change struct{ prev, curr *Policy }
|
||||
changes := make(chan change, 1)
|
||||
tk := NewTicker(testReloadInterval)
|
||||
require.Equal(t, testReloadInterval, tk.interval)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
tk.Run(ctx, func(prev, curr *Policy) error {
|
||||
select {
|
||||
case changes <- change{prev, curr}:
|
||||
default:
|
||||
}
|
||||
return nil
|
||||
})
|
||||
close(done)
|
||||
}()
|
||||
// Stop Run and wait for it to exit before returning, so the policyLoader
|
||||
// restore in t.Cleanup can't race the ticker goroutine still reading it.
|
||||
defer func() { cancel(); <-done }()
|
||||
|
||||
// Flip the OS-observed policy from empty to one managed key. The next
|
||||
// tick must detect the diff and invoke onChange.
|
||||
mu.Lock()
|
||||
current = NewPolicy(map[string]any{KeyManagementURL: "https://mdm.example.com:443"})
|
||||
mu.Unlock()
|
||||
|
||||
select {
|
||||
case c := <-changes:
|
||||
assert.True(t, c.prev.IsEmpty(), "prev should be the initial empty policy")
|
||||
assert.True(t, c.curr.HasKey(KeyManagementURL), "curr should carry the newly-pushed managed key")
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("onChange not invoked within 5s; ticker should fire every 1s under test")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTicker_NoCallbackWhenPolicyUnchanged(t *testing.T) {
|
||||
withPolicyLoader(t, func() *Policy {
|
||||
return NewPolicy(map[string]any{KeyBlockInbound: true})
|
||||
})
|
||||
|
||||
fired := make(chan struct{}, 1)
|
||||
tk := NewTicker(testReloadInterval)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
tk.Run(ctx, func(_, _ *Policy) error {
|
||||
select {
|
||||
case fired <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
return nil
|
||||
})
|
||||
close(done)
|
||||
}()
|
||||
defer func() { cancel(); <-done }()
|
||||
|
||||
// Over ~2 ticks at the 1s test cadence the policy never changes, so the
|
||||
// diff guard must suppress the callback entirely.
|
||||
select {
|
||||
case <-fired:
|
||||
t.Fatal("onChange fired despite an unchanged policy")
|
||||
case <-time.After(2500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
@@ -1191,14 +1191,8 @@ type GetConfigResponse struct {
|
||||
DisableSSHAuth bool `protobuf:"varint,25,opt,name=disableSSHAuth,proto3" json:"disableSSHAuth,omitempty"`
|
||||
SshJWTCacheTTL int32 `protobuf:"varint,26,opt,name=sshJWTCacheTTL,proto3" json:"sshJWTCacheTTL,omitempty"`
|
||||
DisableIpv6 bool `protobuf:"varint,27,opt,name=disable_ipv6,json=disableIpv6,proto3" json:"disable_ipv6,omitempty"`
|
||||
// mDMManagedFields lists the names of configuration keys whose value is
|
||||
// currently enforced by an MDM policy. Names match mdm.Key* constants
|
||||
// (e.g. "managementURL", "disableClientRoutes"). UI/CLI clients should
|
||||
// render the corresponding inputs as read-only and display a "managed
|
||||
// by MDM" indicator.
|
||||
MDMManagedFields []string `protobuf:"bytes,28,rep,name=mDMManagedFields,proto3" json:"mDMManagedFields,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *GetConfigResponse) Reset() {
|
||||
@@ -1420,13 +1414,6 @@ func (x *GetConfigResponse) GetDisableIpv6() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (x *GetConfigResponse) GetMDMManagedFields() []string {
|
||||
if x != nil {
|
||||
return x.MDMManagedFields
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PeerState contains the latest state of a peer
|
||||
type PeerState struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
@@ -4974,55 +4961,6 @@ func (x *GetFeaturesResponse) GetDisableNetworks() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// MDMManagedFieldsViolation is attached as a gRPC error detail on a
|
||||
// FailedPrecondition status returned from SetConfig (and similar mutating
|
||||
// RPCs) when the caller tries to modify one or more MDM-enforced fields.
|
||||
// The fields list contains the offending key names; the entire request is
|
||||
// rejected (no partial apply).
|
||||
type MDMManagedFieldsViolation struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Fields []string `protobuf:"bytes,1,rep,name=fields,proto3" json:"fields,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *MDMManagedFieldsViolation) Reset() {
|
||||
*x = MDMManagedFieldsViolation{}
|
||||
mi := &file_daemon_proto_msgTypes[71]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *MDMManagedFieldsViolation) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*MDMManagedFieldsViolation) ProtoMessage() {}
|
||||
|
||||
func (x *MDMManagedFieldsViolation) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[71]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use MDMManagedFieldsViolation.ProtoReflect.Descriptor instead.
|
||||
func (*MDMManagedFieldsViolation) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{71}
|
||||
}
|
||||
|
||||
func (x *MDMManagedFieldsViolation) GetFields() []string {
|
||||
if x != nil {
|
||||
return x.Fields
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type TriggerUpdateRequest struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
@@ -5031,7 +4969,7 @@ type TriggerUpdateRequest struct {
|
||||
|
||||
func (x *TriggerUpdateRequest) Reset() {
|
||||
*x = TriggerUpdateRequest{}
|
||||
mi := &file_daemon_proto_msgTypes[72]
|
||||
mi := &file_daemon_proto_msgTypes[71]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -5043,7 +4981,7 @@ func (x *TriggerUpdateRequest) String() string {
|
||||
func (*TriggerUpdateRequest) ProtoMessage() {}
|
||||
|
||||
func (x *TriggerUpdateRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[72]
|
||||
mi := &file_daemon_proto_msgTypes[71]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -5056,7 +4994,7 @@ func (x *TriggerUpdateRequest) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use TriggerUpdateRequest.ProtoReflect.Descriptor instead.
|
||||
func (*TriggerUpdateRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{72}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{71}
|
||||
}
|
||||
|
||||
type TriggerUpdateResponse struct {
|
||||
@@ -5069,7 +5007,7 @@ type TriggerUpdateResponse struct {
|
||||
|
||||
func (x *TriggerUpdateResponse) Reset() {
|
||||
*x = TriggerUpdateResponse{}
|
||||
mi := &file_daemon_proto_msgTypes[73]
|
||||
mi := &file_daemon_proto_msgTypes[72]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -5081,7 +5019,7 @@ func (x *TriggerUpdateResponse) String() string {
|
||||
func (*TriggerUpdateResponse) ProtoMessage() {}
|
||||
|
||||
func (x *TriggerUpdateResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[73]
|
||||
mi := &file_daemon_proto_msgTypes[72]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -5094,7 +5032,7 @@ func (x *TriggerUpdateResponse) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use TriggerUpdateResponse.ProtoReflect.Descriptor instead.
|
||||
func (*TriggerUpdateResponse) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{73}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{72}
|
||||
}
|
||||
|
||||
func (x *TriggerUpdateResponse) GetSuccess() bool {
|
||||
@@ -5122,7 +5060,7 @@ type GetPeerSSHHostKeyRequest struct {
|
||||
|
||||
func (x *GetPeerSSHHostKeyRequest) Reset() {
|
||||
*x = GetPeerSSHHostKeyRequest{}
|
||||
mi := &file_daemon_proto_msgTypes[74]
|
||||
mi := &file_daemon_proto_msgTypes[73]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -5134,7 +5072,7 @@ func (x *GetPeerSSHHostKeyRequest) String() string {
|
||||
func (*GetPeerSSHHostKeyRequest) ProtoMessage() {}
|
||||
|
||||
func (x *GetPeerSSHHostKeyRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[74]
|
||||
mi := &file_daemon_proto_msgTypes[73]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -5147,7 +5085,7 @@ func (x *GetPeerSSHHostKeyRequest) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use GetPeerSSHHostKeyRequest.ProtoReflect.Descriptor instead.
|
||||
func (*GetPeerSSHHostKeyRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{74}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{73}
|
||||
}
|
||||
|
||||
func (x *GetPeerSSHHostKeyRequest) GetPeerAddress() string {
|
||||
@@ -5174,7 +5112,7 @@ type GetPeerSSHHostKeyResponse struct {
|
||||
|
||||
func (x *GetPeerSSHHostKeyResponse) Reset() {
|
||||
*x = GetPeerSSHHostKeyResponse{}
|
||||
mi := &file_daemon_proto_msgTypes[75]
|
||||
mi := &file_daemon_proto_msgTypes[74]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -5186,7 +5124,7 @@ func (x *GetPeerSSHHostKeyResponse) String() string {
|
||||
func (*GetPeerSSHHostKeyResponse) ProtoMessage() {}
|
||||
|
||||
func (x *GetPeerSSHHostKeyResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[75]
|
||||
mi := &file_daemon_proto_msgTypes[74]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -5199,7 +5137,7 @@ func (x *GetPeerSSHHostKeyResponse) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use GetPeerSSHHostKeyResponse.ProtoReflect.Descriptor instead.
|
||||
func (*GetPeerSSHHostKeyResponse) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{75}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{74}
|
||||
}
|
||||
|
||||
func (x *GetPeerSSHHostKeyResponse) GetSshHostKey() []byte {
|
||||
@@ -5241,7 +5179,7 @@ type RequestJWTAuthRequest struct {
|
||||
|
||||
func (x *RequestJWTAuthRequest) Reset() {
|
||||
*x = RequestJWTAuthRequest{}
|
||||
mi := &file_daemon_proto_msgTypes[76]
|
||||
mi := &file_daemon_proto_msgTypes[75]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -5253,7 +5191,7 @@ func (x *RequestJWTAuthRequest) String() string {
|
||||
func (*RequestJWTAuthRequest) ProtoMessage() {}
|
||||
|
||||
func (x *RequestJWTAuthRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[76]
|
||||
mi := &file_daemon_proto_msgTypes[75]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -5266,7 +5204,7 @@ func (x *RequestJWTAuthRequest) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use RequestJWTAuthRequest.ProtoReflect.Descriptor instead.
|
||||
func (*RequestJWTAuthRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{76}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{75}
|
||||
}
|
||||
|
||||
func (x *RequestJWTAuthRequest) GetHint() string {
|
||||
@@ -5299,7 +5237,7 @@ type RequestJWTAuthResponse struct {
|
||||
|
||||
func (x *RequestJWTAuthResponse) Reset() {
|
||||
*x = RequestJWTAuthResponse{}
|
||||
mi := &file_daemon_proto_msgTypes[77]
|
||||
mi := &file_daemon_proto_msgTypes[76]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -5311,7 +5249,7 @@ func (x *RequestJWTAuthResponse) String() string {
|
||||
func (*RequestJWTAuthResponse) ProtoMessage() {}
|
||||
|
||||
func (x *RequestJWTAuthResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[77]
|
||||
mi := &file_daemon_proto_msgTypes[76]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -5324,7 +5262,7 @@ func (x *RequestJWTAuthResponse) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use RequestJWTAuthResponse.ProtoReflect.Descriptor instead.
|
||||
func (*RequestJWTAuthResponse) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{77}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{76}
|
||||
}
|
||||
|
||||
func (x *RequestJWTAuthResponse) GetVerificationURI() string {
|
||||
@@ -5389,7 +5327,7 @@ type WaitJWTTokenRequest struct {
|
||||
|
||||
func (x *WaitJWTTokenRequest) Reset() {
|
||||
*x = WaitJWTTokenRequest{}
|
||||
mi := &file_daemon_proto_msgTypes[78]
|
||||
mi := &file_daemon_proto_msgTypes[77]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -5401,7 +5339,7 @@ func (x *WaitJWTTokenRequest) String() string {
|
||||
func (*WaitJWTTokenRequest) ProtoMessage() {}
|
||||
|
||||
func (x *WaitJWTTokenRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[78]
|
||||
mi := &file_daemon_proto_msgTypes[77]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -5414,7 +5352,7 @@ func (x *WaitJWTTokenRequest) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use WaitJWTTokenRequest.ProtoReflect.Descriptor instead.
|
||||
func (*WaitJWTTokenRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{78}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{77}
|
||||
}
|
||||
|
||||
func (x *WaitJWTTokenRequest) GetDeviceCode() string {
|
||||
@@ -5446,7 +5384,7 @@ type WaitJWTTokenResponse struct {
|
||||
|
||||
func (x *WaitJWTTokenResponse) Reset() {
|
||||
*x = WaitJWTTokenResponse{}
|
||||
mi := &file_daemon_proto_msgTypes[79]
|
||||
mi := &file_daemon_proto_msgTypes[78]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -5458,7 +5396,7 @@ func (x *WaitJWTTokenResponse) String() string {
|
||||
func (*WaitJWTTokenResponse) ProtoMessage() {}
|
||||
|
||||
func (x *WaitJWTTokenResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[79]
|
||||
mi := &file_daemon_proto_msgTypes[78]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -5471,7 +5409,7 @@ func (x *WaitJWTTokenResponse) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use WaitJWTTokenResponse.ProtoReflect.Descriptor instead.
|
||||
func (*WaitJWTTokenResponse) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{79}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{78}
|
||||
}
|
||||
|
||||
func (x *WaitJWTTokenResponse) GetToken() string {
|
||||
@@ -5504,7 +5442,7 @@ type StartCPUProfileRequest struct {
|
||||
|
||||
func (x *StartCPUProfileRequest) Reset() {
|
||||
*x = StartCPUProfileRequest{}
|
||||
mi := &file_daemon_proto_msgTypes[80]
|
||||
mi := &file_daemon_proto_msgTypes[79]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -5516,7 +5454,7 @@ func (x *StartCPUProfileRequest) String() string {
|
||||
func (*StartCPUProfileRequest) ProtoMessage() {}
|
||||
|
||||
func (x *StartCPUProfileRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[80]
|
||||
mi := &file_daemon_proto_msgTypes[79]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -5529,7 +5467,7 @@ func (x *StartCPUProfileRequest) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use StartCPUProfileRequest.ProtoReflect.Descriptor instead.
|
||||
func (*StartCPUProfileRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{80}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{79}
|
||||
}
|
||||
|
||||
// StartCPUProfileResponse confirms CPU profiling has started
|
||||
@@ -5541,7 +5479,7 @@ type StartCPUProfileResponse struct {
|
||||
|
||||
func (x *StartCPUProfileResponse) Reset() {
|
||||
*x = StartCPUProfileResponse{}
|
||||
mi := &file_daemon_proto_msgTypes[81]
|
||||
mi := &file_daemon_proto_msgTypes[80]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -5553,7 +5491,7 @@ func (x *StartCPUProfileResponse) String() string {
|
||||
func (*StartCPUProfileResponse) ProtoMessage() {}
|
||||
|
||||
func (x *StartCPUProfileResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[81]
|
||||
mi := &file_daemon_proto_msgTypes[80]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -5566,7 +5504,7 @@ func (x *StartCPUProfileResponse) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use StartCPUProfileResponse.ProtoReflect.Descriptor instead.
|
||||
func (*StartCPUProfileResponse) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{81}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{80}
|
||||
}
|
||||
|
||||
// StopCPUProfileRequest for stopping CPU profiling
|
||||
@@ -5578,7 +5516,7 @@ type StopCPUProfileRequest struct {
|
||||
|
||||
func (x *StopCPUProfileRequest) Reset() {
|
||||
*x = StopCPUProfileRequest{}
|
||||
mi := &file_daemon_proto_msgTypes[82]
|
||||
mi := &file_daemon_proto_msgTypes[81]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -5590,7 +5528,7 @@ func (x *StopCPUProfileRequest) String() string {
|
||||
func (*StopCPUProfileRequest) ProtoMessage() {}
|
||||
|
||||
func (x *StopCPUProfileRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[82]
|
||||
mi := &file_daemon_proto_msgTypes[81]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -5603,7 +5541,7 @@ func (x *StopCPUProfileRequest) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use StopCPUProfileRequest.ProtoReflect.Descriptor instead.
|
||||
func (*StopCPUProfileRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{82}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{81}
|
||||
}
|
||||
|
||||
// StopCPUProfileResponse confirms CPU profiling has stopped
|
||||
@@ -5615,7 +5553,7 @@ type StopCPUProfileResponse struct {
|
||||
|
||||
func (x *StopCPUProfileResponse) Reset() {
|
||||
*x = StopCPUProfileResponse{}
|
||||
mi := &file_daemon_proto_msgTypes[83]
|
||||
mi := &file_daemon_proto_msgTypes[82]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -5627,7 +5565,7 @@ func (x *StopCPUProfileResponse) String() string {
|
||||
func (*StopCPUProfileResponse) ProtoMessage() {}
|
||||
|
||||
func (x *StopCPUProfileResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[83]
|
||||
mi := &file_daemon_proto_msgTypes[82]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -5640,7 +5578,7 @@ func (x *StopCPUProfileResponse) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use StopCPUProfileResponse.ProtoReflect.Descriptor instead.
|
||||
func (*StopCPUProfileResponse) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{83}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{82}
|
||||
}
|
||||
|
||||
type InstallerResultRequest struct {
|
||||
@@ -5651,7 +5589,7 @@ type InstallerResultRequest struct {
|
||||
|
||||
func (x *InstallerResultRequest) Reset() {
|
||||
*x = InstallerResultRequest{}
|
||||
mi := &file_daemon_proto_msgTypes[84]
|
||||
mi := &file_daemon_proto_msgTypes[83]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -5663,7 +5601,7 @@ func (x *InstallerResultRequest) String() string {
|
||||
func (*InstallerResultRequest) ProtoMessage() {}
|
||||
|
||||
func (x *InstallerResultRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[84]
|
||||
mi := &file_daemon_proto_msgTypes[83]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -5676,7 +5614,7 @@ func (x *InstallerResultRequest) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use InstallerResultRequest.ProtoReflect.Descriptor instead.
|
||||
func (*InstallerResultRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{84}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{83}
|
||||
}
|
||||
|
||||
type InstallerResultResponse struct {
|
||||
@@ -5689,7 +5627,7 @@ type InstallerResultResponse struct {
|
||||
|
||||
func (x *InstallerResultResponse) Reset() {
|
||||
*x = InstallerResultResponse{}
|
||||
mi := &file_daemon_proto_msgTypes[85]
|
||||
mi := &file_daemon_proto_msgTypes[84]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -5701,7 +5639,7 @@ func (x *InstallerResultResponse) String() string {
|
||||
func (*InstallerResultResponse) ProtoMessage() {}
|
||||
|
||||
func (x *InstallerResultResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[85]
|
||||
mi := &file_daemon_proto_msgTypes[84]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -5714,7 +5652,7 @@ func (x *InstallerResultResponse) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use InstallerResultResponse.ProtoReflect.Descriptor instead.
|
||||
func (*InstallerResultResponse) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{85}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{84}
|
||||
}
|
||||
|
||||
func (x *InstallerResultResponse) GetSuccess() bool {
|
||||
@@ -5747,7 +5685,7 @@ type ExposeServiceRequest struct {
|
||||
|
||||
func (x *ExposeServiceRequest) Reset() {
|
||||
*x = ExposeServiceRequest{}
|
||||
mi := &file_daemon_proto_msgTypes[86]
|
||||
mi := &file_daemon_proto_msgTypes[85]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -5759,7 +5697,7 @@ func (x *ExposeServiceRequest) String() string {
|
||||
func (*ExposeServiceRequest) ProtoMessage() {}
|
||||
|
||||
func (x *ExposeServiceRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[86]
|
||||
mi := &file_daemon_proto_msgTypes[85]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -5772,7 +5710,7 @@ func (x *ExposeServiceRequest) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use ExposeServiceRequest.ProtoReflect.Descriptor instead.
|
||||
func (*ExposeServiceRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{86}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{85}
|
||||
}
|
||||
|
||||
func (x *ExposeServiceRequest) GetPort() uint32 {
|
||||
@@ -5843,7 +5781,7 @@ type ExposeServiceEvent struct {
|
||||
|
||||
func (x *ExposeServiceEvent) Reset() {
|
||||
*x = ExposeServiceEvent{}
|
||||
mi := &file_daemon_proto_msgTypes[87]
|
||||
mi := &file_daemon_proto_msgTypes[86]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -5855,7 +5793,7 @@ func (x *ExposeServiceEvent) String() string {
|
||||
func (*ExposeServiceEvent) ProtoMessage() {}
|
||||
|
||||
func (x *ExposeServiceEvent) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[87]
|
||||
mi := &file_daemon_proto_msgTypes[86]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -5868,7 +5806,7 @@ func (x *ExposeServiceEvent) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use ExposeServiceEvent.ProtoReflect.Descriptor instead.
|
||||
func (*ExposeServiceEvent) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{87}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{86}
|
||||
}
|
||||
|
||||
func (x *ExposeServiceEvent) GetEvent() isExposeServiceEvent_Event {
|
||||
@@ -5909,7 +5847,7 @@ type ExposeServiceReady struct {
|
||||
|
||||
func (x *ExposeServiceReady) Reset() {
|
||||
*x = ExposeServiceReady{}
|
||||
mi := &file_daemon_proto_msgTypes[88]
|
||||
mi := &file_daemon_proto_msgTypes[87]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -5921,7 +5859,7 @@ func (x *ExposeServiceReady) String() string {
|
||||
func (*ExposeServiceReady) ProtoMessage() {}
|
||||
|
||||
func (x *ExposeServiceReady) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[88]
|
||||
mi := &file_daemon_proto_msgTypes[87]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -5934,7 +5872,7 @@ func (x *ExposeServiceReady) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use ExposeServiceReady.ProtoReflect.Descriptor instead.
|
||||
func (*ExposeServiceReady) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{88}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{87}
|
||||
}
|
||||
|
||||
func (x *ExposeServiceReady) GetServiceName() string {
|
||||
@@ -5979,7 +5917,7 @@ type StartCaptureRequest struct {
|
||||
|
||||
func (x *StartCaptureRequest) Reset() {
|
||||
*x = StartCaptureRequest{}
|
||||
mi := &file_daemon_proto_msgTypes[89]
|
||||
mi := &file_daemon_proto_msgTypes[88]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -5991,7 +5929,7 @@ func (x *StartCaptureRequest) String() string {
|
||||
func (*StartCaptureRequest) ProtoMessage() {}
|
||||
|
||||
func (x *StartCaptureRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[89]
|
||||
mi := &file_daemon_proto_msgTypes[88]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -6004,7 +5942,7 @@ func (x *StartCaptureRequest) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use StartCaptureRequest.ProtoReflect.Descriptor instead.
|
||||
func (*StartCaptureRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{89}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{88}
|
||||
}
|
||||
|
||||
func (x *StartCaptureRequest) GetTextOutput() bool {
|
||||
@@ -6058,7 +5996,7 @@ type CapturePacket struct {
|
||||
|
||||
func (x *CapturePacket) Reset() {
|
||||
*x = CapturePacket{}
|
||||
mi := &file_daemon_proto_msgTypes[90]
|
||||
mi := &file_daemon_proto_msgTypes[89]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -6070,7 +6008,7 @@ func (x *CapturePacket) String() string {
|
||||
func (*CapturePacket) ProtoMessage() {}
|
||||
|
||||
func (x *CapturePacket) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[90]
|
||||
mi := &file_daemon_proto_msgTypes[89]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -6083,7 +6021,7 @@ func (x *CapturePacket) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use CapturePacket.ProtoReflect.Descriptor instead.
|
||||
func (*CapturePacket) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{90}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{89}
|
||||
}
|
||||
|
||||
func (x *CapturePacket) GetData() []byte {
|
||||
@@ -6104,7 +6042,7 @@ type StartBundleCaptureRequest struct {
|
||||
|
||||
func (x *StartBundleCaptureRequest) Reset() {
|
||||
*x = StartBundleCaptureRequest{}
|
||||
mi := &file_daemon_proto_msgTypes[91]
|
||||
mi := &file_daemon_proto_msgTypes[90]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -6116,7 +6054,7 @@ func (x *StartBundleCaptureRequest) String() string {
|
||||
func (*StartBundleCaptureRequest) ProtoMessage() {}
|
||||
|
||||
func (x *StartBundleCaptureRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[91]
|
||||
mi := &file_daemon_proto_msgTypes[90]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -6129,7 +6067,7 @@ func (x *StartBundleCaptureRequest) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use StartBundleCaptureRequest.ProtoReflect.Descriptor instead.
|
||||
func (*StartBundleCaptureRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{91}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{90}
|
||||
}
|
||||
|
||||
func (x *StartBundleCaptureRequest) GetTimeout() *durationpb.Duration {
|
||||
@@ -6147,7 +6085,7 @@ type StartBundleCaptureResponse struct {
|
||||
|
||||
func (x *StartBundleCaptureResponse) Reset() {
|
||||
*x = StartBundleCaptureResponse{}
|
||||
mi := &file_daemon_proto_msgTypes[92]
|
||||
mi := &file_daemon_proto_msgTypes[91]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -6159,7 +6097,7 @@ func (x *StartBundleCaptureResponse) String() string {
|
||||
func (*StartBundleCaptureResponse) ProtoMessage() {}
|
||||
|
||||
func (x *StartBundleCaptureResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[92]
|
||||
mi := &file_daemon_proto_msgTypes[91]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -6172,7 +6110,7 @@ func (x *StartBundleCaptureResponse) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use StartBundleCaptureResponse.ProtoReflect.Descriptor instead.
|
||||
func (*StartBundleCaptureResponse) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{92}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{91}
|
||||
}
|
||||
|
||||
type StopBundleCaptureRequest struct {
|
||||
@@ -6183,7 +6121,7 @@ type StopBundleCaptureRequest struct {
|
||||
|
||||
func (x *StopBundleCaptureRequest) Reset() {
|
||||
*x = StopBundleCaptureRequest{}
|
||||
mi := &file_daemon_proto_msgTypes[93]
|
||||
mi := &file_daemon_proto_msgTypes[92]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -6195,7 +6133,7 @@ func (x *StopBundleCaptureRequest) String() string {
|
||||
func (*StopBundleCaptureRequest) ProtoMessage() {}
|
||||
|
||||
func (x *StopBundleCaptureRequest) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[93]
|
||||
mi := &file_daemon_proto_msgTypes[92]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -6208,7 +6146,7 @@ func (x *StopBundleCaptureRequest) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use StopBundleCaptureRequest.ProtoReflect.Descriptor instead.
|
||||
func (*StopBundleCaptureRequest) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{93}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{92}
|
||||
}
|
||||
|
||||
type StopBundleCaptureResponse struct {
|
||||
@@ -6219,7 +6157,7 @@ type StopBundleCaptureResponse struct {
|
||||
|
||||
func (x *StopBundleCaptureResponse) Reset() {
|
||||
*x = StopBundleCaptureResponse{}
|
||||
mi := &file_daemon_proto_msgTypes[94]
|
||||
mi := &file_daemon_proto_msgTypes[93]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -6231,7 +6169,7 @@ func (x *StopBundleCaptureResponse) String() string {
|
||||
func (*StopBundleCaptureResponse) ProtoMessage() {}
|
||||
|
||||
func (x *StopBundleCaptureResponse) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[94]
|
||||
mi := &file_daemon_proto_msgTypes[93]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -6244,7 +6182,7 @@ func (x *StopBundleCaptureResponse) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use StopBundleCaptureResponse.ProtoReflect.Descriptor instead.
|
||||
func (*StopBundleCaptureResponse) Descriptor() ([]byte, []int) {
|
||||
return file_daemon_proto_rawDescGZIP(), []int{94}
|
||||
return file_daemon_proto_rawDescGZIP(), []int{93}
|
||||
}
|
||||
|
||||
type PortInfo_Range struct {
|
||||
@@ -6257,7 +6195,7 @@ type PortInfo_Range struct {
|
||||
|
||||
func (x *PortInfo_Range) Reset() {
|
||||
*x = PortInfo_Range{}
|
||||
mi := &file_daemon_proto_msgTypes[96]
|
||||
mi := &file_daemon_proto_msgTypes[95]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -6269,7 +6207,7 @@ func (x *PortInfo_Range) String() string {
|
||||
func (*PortInfo_Range) ProtoMessage() {}
|
||||
|
||||
func (x *PortInfo_Range) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_daemon_proto_msgTypes[96]
|
||||
mi := &file_daemon_proto_msgTypes[95]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -6410,7 +6348,7 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"\fDownResponse\"P\n" +
|
||||
"\x10GetConfigRequest\x12 \n" +
|
||||
"\vprofileName\x18\x01 \x01(\tR\vprofileName\x12\x1a\n" +
|
||||
"\busername\x18\x02 \x01(\tR\busername\"\xaa\t\n" +
|
||||
"\busername\x18\x02 \x01(\tR\busername\"\xfe\b\n" +
|
||||
"\x11GetConfigResponse\x12$\n" +
|
||||
"\rmanagementUrl\x18\x01 \x01(\tR\rmanagementUrl\x12\x1e\n" +
|
||||
"\n" +
|
||||
@@ -6442,8 +6380,7 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"\x1denableSSHRemotePortForwarding\x18\x17 \x01(\bR\x1denableSSHRemotePortForwarding\x12&\n" +
|
||||
"\x0edisableSSHAuth\x18\x19 \x01(\bR\x0edisableSSHAuth\x12&\n" +
|
||||
"\x0esshJWTCacheTTL\x18\x1a \x01(\x05R\x0esshJWTCacheTTL\x12!\n" +
|
||||
"\fdisable_ipv6\x18\x1b \x01(\bR\vdisableIpv6\x12*\n" +
|
||||
"\x10mDMManagedFields\x18\x1c \x03(\tR\x10mDMManagedFields\"\x92\x06\n" +
|
||||
"\fdisable_ipv6\x18\x1b \x01(\bR\vdisableIpv6\"\x92\x06\n" +
|
||||
"\tPeerState\x12\x0e\n" +
|
||||
"\x02IP\x18\x01 \x01(\tR\x02IP\x12\x16\n" +
|
||||
"\x06pubKey\x18\x02 \x01(\tR\x06pubKey\x12\x1e\n" +
|
||||
@@ -6758,9 +6695,7 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"\x13GetFeaturesResponse\x12)\n" +
|
||||
"\x10disable_profiles\x18\x01 \x01(\bR\x0fdisableProfiles\x126\n" +
|
||||
"\x17disable_update_settings\x18\x02 \x01(\bR\x15disableUpdateSettings\x12)\n" +
|
||||
"\x10disable_networks\x18\x03 \x01(\bR\x0fdisableNetworks\"3\n" +
|
||||
"\x19MDMManagedFieldsViolation\x12\x16\n" +
|
||||
"\x06fields\x18\x01 \x03(\tR\x06fields\"\x16\n" +
|
||||
"\x10disable_networks\x18\x03 \x01(\bR\x0fdisableNetworks\"\x16\n" +
|
||||
"\x14TriggerUpdateRequest\"M\n" +
|
||||
"\x15TriggerUpdateResponse\x12\x18\n" +
|
||||
"\asuccess\x18\x01 \x01(\bR\asuccess\x12\x1a\n" +
|
||||
@@ -6916,7 +6851,7 @@ func file_daemon_proto_rawDescGZIP() []byte {
|
||||
}
|
||||
|
||||
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 4)
|
||||
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 98)
|
||||
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 97)
|
||||
var file_daemon_proto_goTypes = []any{
|
||||
(LogLevel)(0), // 0: daemon.LogLevel
|
||||
(ExposeProtocol)(0), // 1: daemon.ExposeProtocol
|
||||
@@ -6993,42 +6928,41 @@ var file_daemon_proto_goTypes = []any{
|
||||
(*LogoutResponse)(nil), // 72: daemon.LogoutResponse
|
||||
(*GetFeaturesRequest)(nil), // 73: daemon.GetFeaturesRequest
|
||||
(*GetFeaturesResponse)(nil), // 74: daemon.GetFeaturesResponse
|
||||
(*MDMManagedFieldsViolation)(nil), // 75: daemon.MDMManagedFieldsViolation
|
||||
(*TriggerUpdateRequest)(nil), // 76: daemon.TriggerUpdateRequest
|
||||
(*TriggerUpdateResponse)(nil), // 77: daemon.TriggerUpdateResponse
|
||||
(*GetPeerSSHHostKeyRequest)(nil), // 78: daemon.GetPeerSSHHostKeyRequest
|
||||
(*GetPeerSSHHostKeyResponse)(nil), // 79: daemon.GetPeerSSHHostKeyResponse
|
||||
(*RequestJWTAuthRequest)(nil), // 80: daemon.RequestJWTAuthRequest
|
||||
(*RequestJWTAuthResponse)(nil), // 81: daemon.RequestJWTAuthResponse
|
||||
(*WaitJWTTokenRequest)(nil), // 82: daemon.WaitJWTTokenRequest
|
||||
(*WaitJWTTokenResponse)(nil), // 83: daemon.WaitJWTTokenResponse
|
||||
(*StartCPUProfileRequest)(nil), // 84: daemon.StartCPUProfileRequest
|
||||
(*StartCPUProfileResponse)(nil), // 85: daemon.StartCPUProfileResponse
|
||||
(*StopCPUProfileRequest)(nil), // 86: daemon.StopCPUProfileRequest
|
||||
(*StopCPUProfileResponse)(nil), // 87: daemon.StopCPUProfileResponse
|
||||
(*InstallerResultRequest)(nil), // 88: daemon.InstallerResultRequest
|
||||
(*InstallerResultResponse)(nil), // 89: daemon.InstallerResultResponse
|
||||
(*ExposeServiceRequest)(nil), // 90: daemon.ExposeServiceRequest
|
||||
(*ExposeServiceEvent)(nil), // 91: daemon.ExposeServiceEvent
|
||||
(*ExposeServiceReady)(nil), // 92: daemon.ExposeServiceReady
|
||||
(*StartCaptureRequest)(nil), // 93: daemon.StartCaptureRequest
|
||||
(*CapturePacket)(nil), // 94: daemon.CapturePacket
|
||||
(*StartBundleCaptureRequest)(nil), // 95: daemon.StartBundleCaptureRequest
|
||||
(*StartBundleCaptureResponse)(nil), // 96: daemon.StartBundleCaptureResponse
|
||||
(*StopBundleCaptureRequest)(nil), // 97: daemon.StopBundleCaptureRequest
|
||||
(*StopBundleCaptureResponse)(nil), // 98: daemon.StopBundleCaptureResponse
|
||||
nil, // 99: daemon.Network.ResolvedIPsEntry
|
||||
(*PortInfo_Range)(nil), // 100: daemon.PortInfo.Range
|
||||
nil, // 101: daemon.SystemEvent.MetadataEntry
|
||||
(*durationpb.Duration)(nil), // 102: google.protobuf.Duration
|
||||
(*timestamppb.Timestamp)(nil), // 103: google.protobuf.Timestamp
|
||||
(*TriggerUpdateRequest)(nil), // 75: daemon.TriggerUpdateRequest
|
||||
(*TriggerUpdateResponse)(nil), // 76: daemon.TriggerUpdateResponse
|
||||
(*GetPeerSSHHostKeyRequest)(nil), // 77: daemon.GetPeerSSHHostKeyRequest
|
||||
(*GetPeerSSHHostKeyResponse)(nil), // 78: daemon.GetPeerSSHHostKeyResponse
|
||||
(*RequestJWTAuthRequest)(nil), // 79: daemon.RequestJWTAuthRequest
|
||||
(*RequestJWTAuthResponse)(nil), // 80: daemon.RequestJWTAuthResponse
|
||||
(*WaitJWTTokenRequest)(nil), // 81: daemon.WaitJWTTokenRequest
|
||||
(*WaitJWTTokenResponse)(nil), // 82: daemon.WaitJWTTokenResponse
|
||||
(*StartCPUProfileRequest)(nil), // 83: daemon.StartCPUProfileRequest
|
||||
(*StartCPUProfileResponse)(nil), // 84: daemon.StartCPUProfileResponse
|
||||
(*StopCPUProfileRequest)(nil), // 85: daemon.StopCPUProfileRequest
|
||||
(*StopCPUProfileResponse)(nil), // 86: daemon.StopCPUProfileResponse
|
||||
(*InstallerResultRequest)(nil), // 87: daemon.InstallerResultRequest
|
||||
(*InstallerResultResponse)(nil), // 88: daemon.InstallerResultResponse
|
||||
(*ExposeServiceRequest)(nil), // 89: daemon.ExposeServiceRequest
|
||||
(*ExposeServiceEvent)(nil), // 90: daemon.ExposeServiceEvent
|
||||
(*ExposeServiceReady)(nil), // 91: daemon.ExposeServiceReady
|
||||
(*StartCaptureRequest)(nil), // 92: daemon.StartCaptureRequest
|
||||
(*CapturePacket)(nil), // 93: daemon.CapturePacket
|
||||
(*StartBundleCaptureRequest)(nil), // 94: daemon.StartBundleCaptureRequest
|
||||
(*StartBundleCaptureResponse)(nil), // 95: daemon.StartBundleCaptureResponse
|
||||
(*StopBundleCaptureRequest)(nil), // 96: daemon.StopBundleCaptureRequest
|
||||
(*StopBundleCaptureResponse)(nil), // 97: daemon.StopBundleCaptureResponse
|
||||
nil, // 98: daemon.Network.ResolvedIPsEntry
|
||||
(*PortInfo_Range)(nil), // 99: daemon.PortInfo.Range
|
||||
nil, // 100: daemon.SystemEvent.MetadataEntry
|
||||
(*durationpb.Duration)(nil), // 101: google.protobuf.Duration
|
||||
(*timestamppb.Timestamp)(nil), // 102: google.protobuf.Timestamp
|
||||
}
|
||||
var file_daemon_proto_depIdxs = []int32{
|
||||
102, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
||||
101, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
||||
25, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
|
||||
103, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
|
||||
103, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
|
||||
102, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration
|
||||
102, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
|
||||
102, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
|
||||
101, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration
|
||||
23, // 5: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo
|
||||
20, // 6: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
|
||||
19, // 7: daemon.FullStatus.signalState:type_name -> daemon.SignalState
|
||||
@@ -7039,8 +6973,8 @@ var file_daemon_proto_depIdxs = []int32{
|
||||
55, // 12: daemon.FullStatus.events:type_name -> daemon.SystemEvent
|
||||
24, // 13: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState
|
||||
31, // 14: daemon.ListNetworksResponse.routes:type_name -> daemon.Network
|
||||
99, // 15: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
|
||||
100, // 16: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
|
||||
98, // 15: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
|
||||
99, // 16: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
|
||||
32, // 17: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo
|
||||
32, // 18: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo
|
||||
33, // 19: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule
|
||||
@@ -7051,15 +6985,15 @@ var file_daemon_proto_depIdxs = []int32{
|
||||
52, // 24: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage
|
||||
2, // 25: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity
|
||||
3, // 26: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category
|
||||
103, // 27: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
|
||||
101, // 28: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
|
||||
102, // 27: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
|
||||
100, // 28: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
|
||||
55, // 29: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent
|
||||
102, // 30: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
||||
101, // 30: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
||||
68, // 31: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile
|
||||
1, // 32: daemon.ExposeServiceRequest.protocol:type_name -> daemon.ExposeProtocol
|
||||
92, // 33: daemon.ExposeServiceEvent.ready:type_name -> daemon.ExposeServiceReady
|
||||
102, // 34: daemon.StartCaptureRequest.duration:type_name -> google.protobuf.Duration
|
||||
102, // 35: daemon.StartBundleCaptureRequest.timeout:type_name -> google.protobuf.Duration
|
||||
91, // 33: daemon.ExposeServiceEvent.ready:type_name -> daemon.ExposeServiceReady
|
||||
101, // 34: daemon.StartCaptureRequest.duration:type_name -> google.protobuf.Duration
|
||||
101, // 35: daemon.StartBundleCaptureRequest.timeout:type_name -> google.protobuf.Duration
|
||||
30, // 36: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList
|
||||
5, // 37: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
|
||||
7, // 38: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest
|
||||
@@ -7079,9 +7013,9 @@ var file_daemon_proto_depIdxs = []int32{
|
||||
46, // 52: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest
|
||||
48, // 53: daemon.DaemonService.SetSyncResponsePersistence:input_type -> daemon.SetSyncResponsePersistenceRequest
|
||||
51, // 54: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest
|
||||
93, // 55: daemon.DaemonService.StartCapture:input_type -> daemon.StartCaptureRequest
|
||||
95, // 56: daemon.DaemonService.StartBundleCapture:input_type -> daemon.StartBundleCaptureRequest
|
||||
97, // 57: daemon.DaemonService.StopBundleCapture:input_type -> daemon.StopBundleCaptureRequest
|
||||
92, // 55: daemon.DaemonService.StartCapture:input_type -> daemon.StartCaptureRequest
|
||||
94, // 56: daemon.DaemonService.StartBundleCapture:input_type -> daemon.StartBundleCaptureRequest
|
||||
96, // 57: daemon.DaemonService.StopBundleCapture:input_type -> daemon.StopBundleCaptureRequest
|
||||
54, // 58: daemon.DaemonService.SubscribeEvents:input_type -> daemon.SubscribeRequest
|
||||
56, // 59: daemon.DaemonService.GetEvents:input_type -> daemon.GetEventsRequest
|
||||
58, // 60: daemon.DaemonService.SwitchProfile:input_type -> daemon.SwitchProfileRequest
|
||||
@@ -7092,14 +7026,14 @@ var file_daemon_proto_depIdxs = []int32{
|
||||
69, // 65: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest
|
||||
71, // 66: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest
|
||||
73, // 67: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest
|
||||
76, // 68: daemon.DaemonService.TriggerUpdate:input_type -> daemon.TriggerUpdateRequest
|
||||
78, // 69: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest
|
||||
80, // 70: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest
|
||||
82, // 71: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest
|
||||
84, // 72: daemon.DaemonService.StartCPUProfile:input_type -> daemon.StartCPUProfileRequest
|
||||
86, // 73: daemon.DaemonService.StopCPUProfile:input_type -> daemon.StopCPUProfileRequest
|
||||
88, // 74: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest
|
||||
90, // 75: daemon.DaemonService.ExposeService:input_type -> daemon.ExposeServiceRequest
|
||||
75, // 68: daemon.DaemonService.TriggerUpdate:input_type -> daemon.TriggerUpdateRequest
|
||||
77, // 69: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest
|
||||
79, // 70: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest
|
||||
81, // 71: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest
|
||||
83, // 72: daemon.DaemonService.StartCPUProfile:input_type -> daemon.StartCPUProfileRequest
|
||||
85, // 73: daemon.DaemonService.StopCPUProfile:input_type -> daemon.StopCPUProfileRequest
|
||||
87, // 74: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest
|
||||
89, // 75: daemon.DaemonService.ExposeService:input_type -> daemon.ExposeServiceRequest
|
||||
6, // 76: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
|
||||
8, // 77: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
|
||||
10, // 78: daemon.DaemonService.Up:output_type -> daemon.UpResponse
|
||||
@@ -7118,9 +7052,9 @@ var file_daemon_proto_depIdxs = []int32{
|
||||
47, // 91: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
|
||||
49, // 92: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
|
||||
53, // 93: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
|
||||
94, // 94: daemon.DaemonService.StartCapture:output_type -> daemon.CapturePacket
|
||||
96, // 95: daemon.DaemonService.StartBundleCapture:output_type -> daemon.StartBundleCaptureResponse
|
||||
98, // 96: daemon.DaemonService.StopBundleCapture:output_type -> daemon.StopBundleCaptureResponse
|
||||
93, // 94: daemon.DaemonService.StartCapture:output_type -> daemon.CapturePacket
|
||||
95, // 95: daemon.DaemonService.StartBundleCapture:output_type -> daemon.StartBundleCaptureResponse
|
||||
97, // 96: daemon.DaemonService.StopBundleCapture:output_type -> daemon.StopBundleCaptureResponse
|
||||
55, // 97: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
|
||||
57, // 98: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
|
||||
59, // 99: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
|
||||
@@ -7131,14 +7065,14 @@ var file_daemon_proto_depIdxs = []int32{
|
||||
70, // 104: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
|
||||
72, // 105: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
|
||||
74, // 106: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
|
||||
77, // 107: daemon.DaemonService.TriggerUpdate:output_type -> daemon.TriggerUpdateResponse
|
||||
79, // 108: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
|
||||
81, // 109: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
|
||||
83, // 110: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
|
||||
85, // 111: daemon.DaemonService.StartCPUProfile:output_type -> daemon.StartCPUProfileResponse
|
||||
87, // 112: daemon.DaemonService.StopCPUProfile:output_type -> daemon.StopCPUProfileResponse
|
||||
89, // 113: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse
|
||||
91, // 114: daemon.DaemonService.ExposeService:output_type -> daemon.ExposeServiceEvent
|
||||
76, // 107: daemon.DaemonService.TriggerUpdate:output_type -> daemon.TriggerUpdateResponse
|
||||
78, // 108: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
|
||||
80, // 109: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
|
||||
82, // 110: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
|
||||
84, // 111: daemon.DaemonService.StartCPUProfile:output_type -> daemon.StartCPUProfileResponse
|
||||
86, // 112: daemon.DaemonService.StopCPUProfile:output_type -> daemon.StopCPUProfileResponse
|
||||
88, // 113: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse
|
||||
90, // 114: daemon.DaemonService.ExposeService:output_type -> daemon.ExposeServiceEvent
|
||||
76, // [76:115] is the sub-list for method output_type
|
||||
37, // [37:76] is the sub-list for method input_type
|
||||
37, // [37:37] is the sub-list for extension type_name
|
||||
@@ -7163,8 +7097,8 @@ func file_daemon_proto_init() {
|
||||
file_daemon_proto_msgTypes[54].OneofWrappers = []any{}
|
||||
file_daemon_proto_msgTypes[56].OneofWrappers = []any{}
|
||||
file_daemon_proto_msgTypes[67].OneofWrappers = []any{}
|
||||
file_daemon_proto_msgTypes[76].OneofWrappers = []any{}
|
||||
file_daemon_proto_msgTypes[87].OneofWrappers = []any{
|
||||
file_daemon_proto_msgTypes[75].OneofWrappers = []any{}
|
||||
file_daemon_proto_msgTypes[86].OneofWrappers = []any{
|
||||
(*ExposeServiceEvent_Ready)(nil),
|
||||
}
|
||||
type x struct{}
|
||||
@@ -7173,7 +7107,7 @@ func file_daemon_proto_init() {
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)),
|
||||
NumEnums: 4,
|
||||
NumMessages: 98,
|
||||
NumMessages: 97,
|
||||
NumExtensions: 0,
|
||||
NumServices: 1,
|
||||
},
|
||||
|
||||
@@ -314,13 +314,6 @@ message GetConfigResponse {
|
||||
int32 sshJWTCacheTTL = 26;
|
||||
|
||||
bool disable_ipv6 = 27;
|
||||
|
||||
// mDMManagedFields lists the names of configuration keys whose value is
|
||||
// currently enforced by an MDM policy. Names match mdm.Key* constants
|
||||
// (e.g. "managementURL", "disableClientRoutes"). UI/CLI clients should
|
||||
// render the corresponding inputs as read-only and display a "managed
|
||||
// by MDM" indicator.
|
||||
repeated string mDMManagedFields = 28;
|
||||
}
|
||||
|
||||
// PeerState contains the latest state of a peer
|
||||
@@ -740,15 +733,6 @@ message GetFeaturesResponse{
|
||||
bool disable_networks = 3;
|
||||
}
|
||||
|
||||
// MDMManagedFieldsViolation is attached as a gRPC error detail on a
|
||||
// FailedPrecondition status returned from SetConfig (and similar mutating
|
||||
// RPCs) when the caller tries to modify one or more MDM-enforced fields.
|
||||
// The fields list contains the offending key names; the entire request is
|
||||
// rejected (no partial apply).
|
||||
message MDMManagedFieldsViolation {
|
||||
repeated string fields = 1;
|
||||
}
|
||||
|
||||
message TriggerUpdateRequest {}
|
||||
|
||||
message TriggerUpdateResponse {
|
||||
|
||||
@@ -1,419 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/mdm"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
// preSharedKeyRedactedSentinel is the value GetConfig returns in place
|
||||
// of an actual PSK, so a UI that round-trips the field back to the
|
||||
// daemon (via SetConfig / Login) can be distinguished from a deliberate
|
||||
// override. Any incoming PSK that equals this sentinel is treated as
|
||||
// a no-op echo, never as a conflict with the policy.
|
||||
const preSharedKeyRedactedSentinel = "**********"
|
||||
|
||||
// loadMDMPolicy is the indirection used by server handlers to read the
|
||||
// active MDM policy. Tests override this to inject a fake policy.
|
||||
var loadMDMPolicy = mdm.LoadPolicy
|
||||
|
||||
// conflictCheck is a value-aware comparison between a single field in
|
||||
// the incoming request and the corresponding MDM-enforced value. It
|
||||
// runs only when the field was actually set in the request (presence
|
||||
// already filtered upstream); ok=true reports the policy value, ok=false
|
||||
// means the policy is silent on the key — both are treated as conflicts
|
||||
// to be safe (an MDM key declared as managed must hold a value).
|
||||
type conflictCheck struct {
|
||||
key string
|
||||
check func(*mdm.Policy) (match bool)
|
||||
}
|
||||
|
||||
// onMDMPolicyChange is invoked by the MDM reload ticker every time the
|
||||
// OS-native managed-config store reports a diff vs the last observation.
|
||||
//
|
||||
// Restart sequence:
|
||||
// 1. Cancel the active engine context (terminates connectWithRetryRuns).
|
||||
// 2. Wait briefly for that goroutine to exit (giveUpChan is closed on exit).
|
||||
// 3. Re-resolve Config from disk + MDM policy (Config.apply re-runs
|
||||
// applyMDMPolicy with the freshly loaded Policy).
|
||||
// 4. Spawn a fresh connectWithRetryRuns with the new context and config.
|
||||
// 5. Broadcast a SystemEvent so any GUI / CLI subscriber (SubscribeEvents
|
||||
// RPC) can refresh its cached config view without polling.
|
||||
//
|
||||
// The callback runs in the ticker's own goroutine. Ticker has already
|
||||
// logged the per-key diff before invoking this hook.
|
||||
func (s *Server) onMDMPolicyChange(_, _ *mdm.Policy) error {
|
||||
log.Warn("MDM policy changed; restarting engine to apply new configuration")
|
||||
|
||||
// Hold s.mutex for the entire restart sequence (cancel + quiescence
|
||||
// wait + re-spawn). Any concurrent Up/Down/Status arriving while
|
||||
// MDM is restarting blocks on the Lock until we are done — they
|
||||
// then observe the post-restart state coherently. This is safe
|
||||
// because the connectWithRetryRuns goroutine no longer acquires
|
||||
// s.mutex in its defer (intent vs. goroutine-alive concerns are
|
||||
// fully separated; see the connectionGoroutineRunning helper).
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if !s.clientRunning {
|
||||
// The client is not running, so there's no engine to restart.
|
||||
return nil
|
||||
}
|
||||
if s.actCancel != nil {
|
||||
s.actCancel()
|
||||
}
|
||||
|
||||
// Wait for previous connectWithRetryRuns to exit so we don't end up
|
||||
// with two goroutines fighting over the same status recorder + engine.
|
||||
// The teardown engages a fan-out of engine goroutines (peer workers,
|
||||
// signal handler, route manager, ...). close(clientGiveUpChan)
|
||||
// happens in the function-scope defer of connectWithRetryRuns, on
|
||||
// every exit path (ctx cancel, backoff exhausted, panic) — see the
|
||||
// defer in server.go.
|
||||
if s.clientGiveUpChan != nil {
|
||||
select {
|
||||
case <-s.clientGiveUpChan:
|
||||
case <-time.After(10 * time.Second):
|
||||
return fmt.Errorf("failed to restart the engine due to timeout")
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.restartEngineForMDMLocked(); err != nil {
|
||||
log.Errorf("MDM restart failed: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// publishConfigChangedEvent has already fired inside
|
||||
// restartEngineForMDMLocked with source="mdm". Emit an MDM-specific
|
||||
// user-visible toast so the operator knows their IT policy was
|
||||
// applied (UserMessage != "" triggers the GUI notifier).
|
||||
s.statusRecorder.PublishEvent(
|
||||
proto.SystemEvent_INFO,
|
||||
proto.SystemEvent_SYSTEM,
|
||||
"MDM policy applied",
|
||||
"NetBird configuration was updated by your IT policy.",
|
||||
map[string]string{"source": "mdm", "type": "policy_applied"},
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
// publishConfigChangedEvent broadcasts a SystemEvent informing any active
|
||||
// SubscribeEvents subscriber (typically the GUI tray) that the daemon's
|
||||
// effective Config has been replaced and any cached client-side view
|
||||
// should be refreshed. Callers pass a stable `source` label so the GUI
|
||||
// can distinguish a startup spawn from a user-triggered Up or an
|
||||
// MDM-driven restart. Reusing the SYSTEM category keeps the proto enum
|
||||
// stable; metadata.type="config_changed" routes to the GUI's refresh
|
||||
// handler. UserMessage is left empty so the system tray does not toast
|
||||
// for every internal restart; the MDM path emits a separate
|
||||
// "policy_applied" event (with UserMessage) for that purpose.
|
||||
func (s *Server) publishConfigChangedEvent(source string) {
|
||||
if s.statusRecorder == nil {
|
||||
return
|
||||
}
|
||||
s.statusRecorder.PublishEvent(
|
||||
proto.SystemEvent_INFO,
|
||||
proto.SystemEvent_SYSTEM,
|
||||
fmt.Sprintf("daemon config changed (source=%s)", source),
|
||||
"",
|
||||
map[string]string{
|
||||
"source": source,
|
||||
"type": "config_changed",
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// restartEngineForMDMLocked re-resolves the active profile config
|
||||
// (re-running applyMDMPolicy via Config.apply) and re-spawns
|
||||
// connectWithRetryRuns. Mirrors the tail of Server.Start so a runtime
|
||||
// MDM change behaves identically to a fresh boot under the new policy.
|
||||
//
|
||||
// MUST be called with s.mutex held — onMDMPolicyChange holds the lock
|
||||
// for the entire restart sequence (cancel + quiescence wait + re-spawn)
|
||||
// so concurrent Up/Down/Status RPCs observe a coherent post-restart
|
||||
// state.
|
||||
func (s *Server) restartEngineForMDMLocked() error {
|
||||
activeProf, err := s.profileManager.GetActiveProfileState()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get active profile state: %w", err)
|
||||
}
|
||||
config, _, err := s.getConfig(activeProf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get active profile config: %w", err)
|
||||
}
|
||||
|
||||
s.config = config
|
||||
s.statusRecorder.UpdateManagementAddress(config.ManagementURL.String())
|
||||
s.statusRecorder.UpdateRosenpass(config.RosenpassEnabled, config.RosenpassPermissive)
|
||||
s.statusRecorder.UpdateLazyConnection(config.LazyConnectionEnabled)
|
||||
|
||||
ctx, cancel := context.WithCancel(s.rootCtx)
|
||||
s.actCancel = cancel
|
||||
s.clientRunning = true
|
||||
s.clientRunningChan = make(chan struct{})
|
||||
s.clientGiveUpChan = make(chan struct{})
|
||||
log.Info("MDM restart: spawning connectWithRetryRuns with re-resolved config")
|
||||
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
|
||||
s.publishConfigChangedEvent("mdm")
|
||||
return nil
|
||||
}
|
||||
|
||||
// conflictBool builds a conflictCheck for a boolean MDM key. If p is nil
|
||||
// the field is treated as matching (no override requested); otherwise the
|
||||
// check returns true only when the policy contains the key and its
|
||||
// boolean value equals *p.
|
||||
func conflictBool(key string, p *bool) conflictCheck {
|
||||
return conflictCheck{
|
||||
key: key,
|
||||
check: func(pol *mdm.Policy) bool {
|
||||
if p == nil {
|
||||
return true // absent → match by definition
|
||||
}
|
||||
want, ok := pol.GetBool(key)
|
||||
return ok && want == *p
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// conflictString builds a conflictCheck for a string MDM key. An empty
|
||||
// `got` is treated as "field not set" (no override requested); otherwise
|
||||
// the check returns true only when the policy contains the key and its
|
||||
// value equals got.
|
||||
func conflictString(key, got string) conflictCheck {
|
||||
return conflictCheck{
|
||||
key: key,
|
||||
check: func(pol *mdm.Policy) bool {
|
||||
if got == "" {
|
||||
return true
|
||||
}
|
||||
want, ok := pol.GetString(key)
|
||||
return ok && want == got
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// conflictInt64 builds a conflictCheck for an integer MDM key. If p is
|
||||
// nil the field is treated as matching; otherwise the check returns
|
||||
// true only when the policy contains the key and its int value equals *p.
|
||||
func conflictInt64(key string, p *int64) conflictCheck {
|
||||
return conflictCheck{
|
||||
key: key,
|
||||
check: func(pol *mdm.Policy) bool {
|
||||
if p == nil {
|
||||
return true
|
||||
}
|
||||
want, ok := pol.GetInt(key)
|
||||
return ok && want == *p
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// resolveConflicts walks the per-field checks against the active MDM
|
||||
// policy and returns the names of keys whose requested value diverges
|
||||
// from the policy-enforced value. Keys not present in the policy are
|
||||
// skipped silently (the gate fires only for keys the admin has
|
||||
// actually pushed). Returns nil for an empty policy.
|
||||
func resolveConflicts(policy *mdm.Policy, checks []conflictCheck) []string {
|
||||
if policy.IsEmpty() {
|
||||
return nil
|
||||
}
|
||||
var conflicts []string
|
||||
for _, c := range checks {
|
||||
if !policy.HasKey(c.key) {
|
||||
continue
|
||||
}
|
||||
if !c.check(policy) {
|
||||
conflicts = append(conflicts, c.key)
|
||||
}
|
||||
}
|
||||
return conflicts
|
||||
}
|
||||
|
||||
// mdmManagedFieldConflicts returns the names of MDM-managed keys whose
|
||||
// requested value in the SetConfigRequest differs from the MDM-enforced
|
||||
// value. A field set to the same value the policy already enforces is
|
||||
// treated as a no-op echo (the GUI tray sends a full Config snapshot on
|
||||
// every toggle, so most fields in a typical request match the policy
|
||||
// exactly and must NOT be flagged as conflicts). The redacted PSK
|
||||
// sentinel ("**********") returned by GetConfig is recognised and
|
||||
// treated as no-op so the UI can safely round-trip it.
|
||||
func mdmManagedFieldConflicts(msg *proto.SetConfigRequest, policy *mdm.Policy) []string {
|
||||
if msg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// PSK round-trip echo: collapse the sentinel to empty so the
|
||||
// shared check treats it as "field not set".
|
||||
pskGot := ""
|
||||
if msg.OptionalPreSharedKey != nil && *msg.OptionalPreSharedKey != preSharedKeyRedactedSentinel {
|
||||
pskGot = *msg.OptionalPreSharedKey
|
||||
}
|
||||
|
||||
return resolveConflicts(policy, []conflictCheck{
|
||||
conflictString(mdm.KeyManagementURL, msg.ManagementUrl),
|
||||
conflictString(mdm.KeyPreSharedKey, pskGot),
|
||||
conflictBool(mdm.KeyRosenpassEnabled, msg.RosenpassEnabled),
|
||||
conflictBool(mdm.KeyRosenpassPermissive, msg.RosenpassPermissive),
|
||||
conflictBool(mdm.KeyDisableAutoConnect, msg.DisableAutoConnect),
|
||||
conflictBool(mdm.KeyAllowServerSSH, msg.ServerSSHAllowed),
|
||||
conflictBool(mdm.KeyDisableClientRoutes, msg.DisableClientRoutes),
|
||||
conflictBool(mdm.KeyDisableServerRoutes, msg.DisableServerRoutes),
|
||||
conflictBool(mdm.KeyBlockInbound, msg.BlockInbound),
|
||||
conflictInt64(mdm.KeyWireguardPort, msg.WireguardPort),
|
||||
})
|
||||
}
|
||||
|
||||
// setConfigRequestHasConfigOverrides reports whether the SetConfigRequest
|
||||
// carries ANY field that would actually mutate the persisted config.
|
||||
// The CLI builds a SetConfigRequest unconditionally on every
|
||||
// `netbird up` (see setupSetConfigReq in cmd/up.go) — a plain
|
||||
// `netbird up` produces a request with every field at its zero value;
|
||||
// the gate must skip such no-op invocations or it would always fire
|
||||
// even when the user did not pass any --flag. Returns false on a nil
|
||||
// msg; true when any management/admin URL, PSK, DNS/NAT list+clean
|
||||
// flag, interface/port/MTU, or any optional bool/duration field is set.
|
||||
func setConfigRequestHasConfigOverrides(msg *proto.SetConfigRequest) bool {
|
||||
if msg == nil {
|
||||
return false
|
||||
}
|
||||
return msg.ManagementUrl != "" ||
|
||||
msg.AdminURL != "" ||
|
||||
msg.OptionalPreSharedKey != nil ||
|
||||
len(msg.CustomDNSAddress) > 0 ||
|
||||
len(msg.NatExternalIPs) > 0 || msg.CleanNATExternalIPs ||
|
||||
len(msg.ExtraIFaceBlacklist) > 0 ||
|
||||
len(msg.DnsLabels) > 0 || msg.CleanDNSLabels ||
|
||||
msg.DnsRouteInterval != nil ||
|
||||
msg.RosenpassEnabled != nil ||
|
||||
msg.RosenpassPermissive != nil ||
|
||||
msg.InterfaceName != nil ||
|
||||
msg.WireguardPort != nil ||
|
||||
msg.Mtu != nil ||
|
||||
msg.DisableAutoConnect != nil ||
|
||||
msg.ServerSSHAllowed != nil ||
|
||||
msg.NetworkMonitor != nil ||
|
||||
msg.DisableClientRoutes != nil ||
|
||||
msg.DisableServerRoutes != nil ||
|
||||
msg.DisableDns != nil ||
|
||||
msg.DisableFirewall != nil ||
|
||||
msg.BlockLanAccess != nil ||
|
||||
msg.DisableNotifications != nil ||
|
||||
msg.LazyConnectionEnabled != nil ||
|
||||
msg.BlockInbound != nil ||
|
||||
msg.DisableIpv6 != nil ||
|
||||
msg.EnableSSHRoot != nil ||
|
||||
msg.EnableSSHSFTP != nil ||
|
||||
msg.EnableSSHLocalPortForwarding != nil ||
|
||||
msg.EnableSSHRemotePortForwarding != nil ||
|
||||
msg.DisableSSHAuth != nil ||
|
||||
msg.SshJWTCacheTTL != nil
|
||||
}
|
||||
|
||||
// loginRequestHasConfigOverrides reports whether the LoginRequest
|
||||
// carries ANY field that would mutate persisted daemon configuration
|
||||
// (as opposed to pure-auth fields like setupKey, hostname, hint,
|
||||
// profileName, username). Used by the Login handler to decide whether
|
||||
// the `--disable-update-settings` / MDM gates must run: a re-auth that
|
||||
// changes nothing about the configuration is always allowed.
|
||||
func loginRequestHasConfigOverrides(msg *proto.LoginRequest) bool {
|
||||
if msg == nil {
|
||||
return false
|
||||
}
|
||||
return msg.ManagementUrl != "" ||
|
||||
msg.AdminURL != "" ||
|
||||
msg.PreSharedKey != "" || //nolint:staticcheck // SA1019: legacy proto field still accepted by Login
|
||||
msg.OptionalPreSharedKey != nil ||
|
||||
len(msg.CustomDNSAddress) > 0 ||
|
||||
len(msg.NatExternalIPs) > 0 || msg.CleanNATExternalIPs ||
|
||||
msg.RosenpassEnabled != nil ||
|
||||
msg.InterfaceName != nil ||
|
||||
msg.WireguardPort != nil ||
|
||||
msg.DisableAutoConnect != nil ||
|
||||
msg.ServerSSHAllowed != nil ||
|
||||
msg.RosenpassPermissive != nil ||
|
||||
len(msg.ExtraIFaceBlacklist) > 0 ||
|
||||
msg.NetworkMonitor != nil ||
|
||||
msg.DnsRouteInterval != nil ||
|
||||
msg.DisableClientRoutes != nil ||
|
||||
msg.DisableServerRoutes != nil ||
|
||||
msg.DisableDns != nil ||
|
||||
msg.DisableFirewall != nil ||
|
||||
msg.BlockLanAccess != nil ||
|
||||
msg.DisableNotifications != nil ||
|
||||
len(msg.DnsLabels) > 0 || msg.CleanDNSLabels ||
|
||||
msg.LazyConnectionEnabled != nil ||
|
||||
msg.BlockInbound != nil
|
||||
}
|
||||
|
||||
// loginRequestMDMConflicts mirrors mdmManagedFieldConflicts but for the
|
||||
// LoginRequest surface. Same value-aware semantics: a field set to the
|
||||
// MDM-enforced value is a no-op echo, not a conflict; only a divergent
|
||||
// value is flagged. PSK has two proto fields — PreSharedKey (deprecated)
|
||||
// and OptionalPreSharedKey (current); either route trips the gate if it
|
||||
// diverges from the MDM-enforced PSK. OptionalPreSharedKey wins when
|
||||
// both are set; the redaction sentinel ("**********") is accepted as
|
||||
// a no-op echo.
|
||||
func loginRequestMDMConflicts(msg *proto.LoginRequest, policy *mdm.Policy) []string {
|
||||
if msg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Collapse the two PSK fields + the redaction sentinel down to a
|
||||
// single "got" string the shared check can compare against the
|
||||
// policy: OptionalPreSharedKey wins if set; PreSharedKey (deprecated)
|
||||
// is the fallback; sentinel echo is treated as "field not set".
|
||||
pskGot := ""
|
||||
if msg.OptionalPreSharedKey != nil {
|
||||
pskGot = *msg.OptionalPreSharedKey
|
||||
} else if msg.PreSharedKey != "" { //nolint:staticcheck // SA1019: legacy proto field still accepted by Login
|
||||
pskGot = msg.PreSharedKey //nolint:staticcheck // SA1019
|
||||
}
|
||||
if pskGot == preSharedKeyRedactedSentinel {
|
||||
pskGot = ""
|
||||
}
|
||||
|
||||
return resolveConflicts(policy, []conflictCheck{
|
||||
conflictString(mdm.KeyManagementURL, msg.ManagementUrl),
|
||||
conflictString(mdm.KeyPreSharedKey, pskGot),
|
||||
conflictBool(mdm.KeyRosenpassEnabled, msg.RosenpassEnabled),
|
||||
conflictBool(mdm.KeyRosenpassPermissive, msg.RosenpassPermissive),
|
||||
conflictBool(mdm.KeyDisableAutoConnect, msg.DisableAutoConnect),
|
||||
conflictBool(mdm.KeyAllowServerSSH, msg.ServerSSHAllowed),
|
||||
conflictBool(mdm.KeyDisableClientRoutes, msg.DisableClientRoutes),
|
||||
conflictBool(mdm.KeyDisableServerRoutes, msg.DisableServerRoutes),
|
||||
conflictBool(mdm.KeyBlockInbound, msg.BlockInbound),
|
||||
conflictInt64(mdm.KeyWireguardPort, msg.WireguardPort),
|
||||
})
|
||||
}
|
||||
|
||||
// rejectMDMManagedFieldConflicts returns a FailedPrecondition gRPC error
|
||||
// with an MDMManagedFieldsViolation detail when any of the requested
|
||||
// fields tries to change an MDM-enforced value to something else, and
|
||||
// nil otherwise. The whole request is rejected on any conflict; non-
|
||||
// conflicting fields in the same request are not applied either (no
|
||||
// partial apply).
|
||||
func rejectMDMManagedFieldConflicts(conflicts []string) error {
|
||||
if len(conflicts) == 0 {
|
||||
return nil
|
||||
}
|
||||
log.Warnf("MDM rejected request: tried to modify %d managed key(s): %v",
|
||||
len(conflicts), conflicts)
|
||||
st := gstatus.New(
|
||||
codes.FailedPrecondition,
|
||||
fmt.Sprintf("fields managed by MDM cannot be modified: %v", conflicts),
|
||||
)
|
||||
detailed, err := st.WithDetails(&proto.MDMManagedFieldsViolation{Fields: conflicts})
|
||||
if err != nil {
|
||||
// Detail attachment is best-effort; fall back to the plain status
|
||||
// so the caller still gets a usable FailedPrecondition.
|
||||
return st.Err()
|
||||
}
|
||||
return detailed.Err()
|
||||
}
|
||||
@@ -30,7 +30,7 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.checkNetworksDisabled() {
|
||||
if s.networksDisabled {
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
|
||||
}
|
||||
|
||||
@@ -143,7 +143,7 @@ func (s *Server) SelectNetworks(_ context.Context, req *proto.SelectNetworksRequ
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.checkNetworksDisabled() {
|
||||
if s.networksDisabled {
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
|
||||
}
|
||||
|
||||
@@ -195,7 +195,7 @@ func (s *Server) DeselectNetworks(_ context.Context, req *proto.SelectNetworksRe
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.checkNetworksDisabled() {
|
||||
if s.networksDisabled {
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
|
||||
}
|
||||
|
||||
|
||||
@@ -24,7 +24,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/expose"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
sleephandler "github.com/netbirdio/netbird/client/internal/sleep/handler"
|
||||
"github.com/netbirdio/netbird/client/mdm"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
@@ -72,13 +71,7 @@ type Server struct {
|
||||
mutex sync.Mutex
|
||||
config *profilemanager.Config
|
||||
proto.UnimplementedDaemonServiceServer
|
||||
// clientRunning tracks "the daemon wants to be connected" — set true by
|
||||
// Start / Up, cleared by Down / Logout. Persists across retry
|
||||
// loops, signal disconnects, and ErrResetConnection cycles. NOT
|
||||
// changed by connectWithRetryRuns goroutine exit — for that
|
||||
// (goroutine-still-alive) check, see connectionGoroutineRunning() which
|
||||
// derives from clientGiveUpChan close state. Protected by s.mutex.
|
||||
clientRunning bool
|
||||
clientRunning bool // protected by mutex
|
||||
clientRunningChan chan struct{}
|
||||
clientGiveUpChan chan struct{} // closed when connectWithRetryRuns goroutine exits
|
||||
|
||||
@@ -105,11 +98,6 @@ type Server struct {
|
||||
|
||||
sleepHandler *sleephandler.SleepHandler
|
||||
|
||||
// mdmTicker periodically re-reads the OS-native MDM policy and triggers
|
||||
// an engine restart when the policy changes. Launched once by Start;
|
||||
// stopped by the rootCtx cancellation.
|
||||
mdmTicker *mdm.Ticker
|
||||
|
||||
updateManager *updater.Manager
|
||||
|
||||
jwtCache *jwtCache
|
||||
@@ -167,17 +155,6 @@ func (s *Server) Start() error {
|
||||
s.updateManager.CheckUpdateSuccess(s.rootCtx)
|
||||
}
|
||||
|
||||
// MDM policy reload ticker: every minute the desktop daemon re-reads
|
||||
// the OS-native managed-config store and, on diff vs the previous
|
||||
// observation, cancels the active engine context so connectWithRetry-
|
||||
// Runs re-resolves Config (re-running profilemanager.Config.apply which
|
||||
// applies the freshly-read MDM policy as the last layer) and brings
|
||||
// the engine back with the new values.
|
||||
if s.mdmTicker == nil {
|
||||
s.mdmTicker = mdm.NewTicker(mdm.DefaultReloadInterval)
|
||||
go s.mdmTicker.Run(s.rootCtx, s.onMDMPolicyChange)
|
||||
}
|
||||
|
||||
// if current state contains any error, return it
|
||||
// in all other cases we can continue execution only if status is idle and up command was
|
||||
// not in the progress or already successfully established connection.
|
||||
@@ -236,27 +213,17 @@ func (s *Server) Start() error {
|
||||
s.clientRunningChan = make(chan struct{})
|
||||
s.clientGiveUpChan = make(chan struct{})
|
||||
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
|
||||
s.publishConfigChangedEvent("startup")
|
||||
return nil
|
||||
}
|
||||
|
||||
// connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional
|
||||
// mechanism to keep the client connected even when the connection is lost.
|
||||
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
|
||||
//
|
||||
// The goroutine's exit is signalled to the daemon via close(giveUpChan)
|
||||
// — placed in the function-scope defer so every return path (panic,
|
||||
// DisableAutoConnect early-exit, backoff exhausted, ctx cancel) closes
|
||||
// it. Callers that need to observe "is the goroutine still alive?" use
|
||||
// Server.connectionGoroutineRunning() which non-blockingly checks the close state
|
||||
// of clientGiveUpChan. The defer does NOT touch s.mutex; the daemon's
|
||||
// "intent" (clientRunning) is maintained by the RPC handlers, not by this
|
||||
// goroutine.
|
||||
func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}, giveUpChan chan struct{}) {
|
||||
defer func() {
|
||||
if giveUpChan != nil {
|
||||
close(giveUpChan)
|
||||
}
|
||||
s.mutex.Lock()
|
||||
s.clientRunning = false
|
||||
s.mutex.Unlock()
|
||||
}()
|
||||
|
||||
if s.config.DisableAutoConnect {
|
||||
@@ -302,26 +269,9 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil
|
||||
if err := backoff.Retry(runOperation, backOff); err != nil {
|
||||
log.Errorf("operation failed: %v", err)
|
||||
}
|
||||
// giveUpChan is closed by the function-scope defer.
|
||||
}
|
||||
|
||||
// connectionGoroutineRunning reports whether the connectWithRetryRuns goroutine is
|
||||
// still running. Returns false when no goroutine has ever been started
|
||||
// AND when the most recent one has already closed clientGiveUpChan on
|
||||
// exit (whether due to ctx cancel, DisableAutoConnect single-shot
|
||||
// completion, or backoff retry exhaustion).
|
||||
//
|
||||
// MUST be called with s.mutex held — accesses s.clientGiveUpChan which
|
||||
// is written by Start/Up under the same lock.
|
||||
func (s *Server) connectionGoroutineRunning() bool {
|
||||
if s.clientGiveUpChan == nil {
|
||||
return false
|
||||
}
|
||||
select {
|
||||
case <-s.clientGiveUpChan:
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
if giveUpChan != nil {
|
||||
close(giveUpChan)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -354,85 +304,54 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
// Skip the update-settings gate when the request carries no actual
|
||||
// overrides: the CLI builds a SetConfigRequest unconditionally on
|
||||
// every `netbird up` (setupSetConfigReq in cmd/up.go), so a plain
|
||||
// `netbird up` would otherwise always trip the gate and surface a
|
||||
// misleading "setConfig method is not available" warning, even when
|
||||
// the user did not pass any config flag.
|
||||
if setConfigRequestHasConfigOverrides(msg) {
|
||||
if s.checkUpdateSettingsDisabled() {
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errUpdateSettingsDisabled)
|
||||
}
|
||||
if s.checkUpdateSettingsDisabled() {
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errUpdateSettingsDisabled)
|
||||
}
|
||||
|
||||
// MDM gate: refuse the whole request if any of its fields is enforced
|
||||
// by the active MDM policy. The error carries an MDMManagedFields-
|
||||
// Violation detail listing the offending key names. Non-conflicting
|
||||
// fields in the same request are not applied either.
|
||||
policy := loadMDMPolicy()
|
||||
if err := rejectMDMManagedFieldConflicts(mdmManagedFieldConflicts(msg, policy)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
config, err := setConfigInputFromRequest(msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err := profilemanager.UpdateConfig(config); err != nil {
|
||||
log.Errorf("failed to update profile config: %v", err)
|
||||
return nil, fmt.Errorf("failed to update profile config: %w", err)
|
||||
}
|
||||
|
||||
return &proto.SetConfigResponse{}, nil
|
||||
}
|
||||
|
||||
// setConfigInputFromRequest translates a SetConfigRequest into the
|
||||
// profilemanager.ConfigInput that profilemanager.UpdateConfig consumes.
|
||||
// Pure mapping with no business logic beyond presence-aware copying of
|
||||
// optional fields and the "empty / clean" semantics for the two slice
|
||||
// fields (DNS labels, NAT external IPs). Extracted from SetConfig to
|
||||
// keep the handler's cognitive complexity below the SonarCube
|
||||
// threshold; the body is intentionally linear because each proto
|
||||
// field is its own optional case. Returns the resolved ConfigInput
|
||||
// and a non-nil error only when the active profile file path cannot
|
||||
// be determined.
|
||||
func setConfigInputFromRequest(msg *proto.SetConfigRequest) (profilemanager.ConfigInput, error) {
|
||||
var config profilemanager.ConfigInput
|
||||
|
||||
profState := profilemanager.ActiveProfileState{
|
||||
Name: msg.ProfileName,
|
||||
Username: msg.Username,
|
||||
}
|
||||
|
||||
profPath, err := profState.FilePath()
|
||||
if err != nil {
|
||||
log.Errorf("failed to get active profile file path: %v", err)
|
||||
return config, fmt.Errorf("failed to get active profile file path: %w", err)
|
||||
return nil, fmt.Errorf("failed to get active profile file path: %w", err)
|
||||
}
|
||||
|
||||
var config profilemanager.ConfigInput
|
||||
|
||||
config.ConfigPath = profPath
|
||||
|
||||
if msg.ManagementUrl != "" {
|
||||
config.ManagementURL = msg.ManagementUrl
|
||||
}
|
||||
|
||||
if msg.AdminURL != "" {
|
||||
config.AdminURL = msg.AdminURL
|
||||
}
|
||||
|
||||
if msg.InterfaceName != nil {
|
||||
config.InterfaceName = msg.InterfaceName
|
||||
}
|
||||
|
||||
if msg.WireguardPort != nil {
|
||||
wgPort := int(*msg.WireguardPort)
|
||||
config.WireguardPort = &wgPort
|
||||
}
|
||||
if msg.OptionalPreSharedKey != nil && *msg.OptionalPreSharedKey != "" {
|
||||
config.PreSharedKey = msg.OptionalPreSharedKey
|
||||
|
||||
if msg.OptionalPreSharedKey != nil {
|
||||
if *msg.OptionalPreSharedKey != "" {
|
||||
config.PreSharedKey = msg.OptionalPreSharedKey
|
||||
}
|
||||
}
|
||||
|
||||
if msg.CleanDNSLabels {
|
||||
config.DNSLabels = domain.List{}
|
||||
|
||||
} else if msg.DnsLabels != nil {
|
||||
config.DNSLabels = domain.FromPunycodeList(msg.DnsLabels)
|
||||
dnsLabels := domain.FromPunycodeList(msg.DnsLabels)
|
||||
config.DNSLabels = dnsLabels
|
||||
}
|
||||
|
||||
if msg.CleanNATExternalIPs {
|
||||
@@ -445,6 +364,7 @@ func setConfigInputFromRequest(msg *proto.SetConfigRequest) (profilemanager.Conf
|
||||
if string(msg.CustomDNSAddress) == "empty" {
|
||||
config.CustomDNSAddress = []byte{}
|
||||
}
|
||||
|
||||
config.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist
|
||||
|
||||
if msg.DnsRouteInterval != nil {
|
||||
@@ -477,31 +397,22 @@ func setConfigInputFromRequest(msg *proto.SetConfigRequest) (profilemanager.Conf
|
||||
ttl := int(*msg.SshJWTCacheTTL)
|
||||
config.SSHJWTCacheTTL = &ttl
|
||||
}
|
||||
|
||||
if msg.Mtu != nil {
|
||||
mtu := uint16(*msg.Mtu)
|
||||
config.MTU = &mtu
|
||||
}
|
||||
return config, nil
|
||||
|
||||
if _, err := profilemanager.UpdateConfig(config); err != nil {
|
||||
log.Errorf("failed to update profile config: %v", err)
|
||||
return nil, fmt.Errorf("failed to update profile config: %w", err)
|
||||
}
|
||||
|
||||
return &proto.SetConfigResponse{}, nil
|
||||
}
|
||||
|
||||
// Login uses setup key to prepare configuration for the daemon.
|
||||
func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*proto.LoginResponse, error) {
|
||||
// Config-override gates. LoginRequest carries the same surface as
|
||||
// SetConfigRequest (managementUrl, PSK, ssh/rosenpass/port toggles,
|
||||
// ...), so the same protections must apply. Without these the CLI
|
||||
// command `netbird up --management-url=X` (which falls through to
|
||||
// Login when SetConfig is rejected — see cmd/up.go) would silently
|
||||
// bypass `--disable-update-settings` and any MDM policy.
|
||||
if loginRequestHasConfigOverrides(msg) {
|
||||
if s.checkUpdateSettingsDisabled() {
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errUpdateSettingsDisabled)
|
||||
}
|
||||
policy := loadMDMPolicy()
|
||||
if err := rejectMDMManagedFieldConflicts(loginRequestMDMConflicts(msg, policy)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
if s.actCancel != nil {
|
||||
s.actCancel()
|
||||
@@ -741,13 +652,7 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
|
||||
// Up starts engine work in the daemon.
|
||||
func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpResponse, error) {
|
||||
s.mutex.Lock()
|
||||
// clientRunning is the daemon-intent flag (set by previous Up/Start, cleared
|
||||
// by Down). connectionGoroutineRunning() reports whether the previous retry-loop
|
||||
// goroutine is still trying. When intent is up AND goroutine is alive,
|
||||
// the existing engine is on the job — just wait for it. When intent
|
||||
// is up but the goroutine has given up (backoff exhausted) OR when
|
||||
// intent is down, fall through to spawn a fresh retry loop.
|
||||
if s.clientRunning && s.connectionGoroutineRunning() {
|
||||
if s.clientRunning {
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
status, err := state.Status()
|
||||
if err != nil {
|
||||
@@ -838,7 +743,6 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
|
||||
s.clientGiveUpChan = make(chan struct{})
|
||||
|
||||
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
|
||||
s.publishConfigChangedEvent("up_rpc")
|
||||
|
||||
s.mutex.Unlock()
|
||||
return s.waitForUp(callerCtx)
|
||||
@@ -967,12 +871,6 @@ func (s *Server) cleanupConnection() error {
|
||||
return ErrServiceNotUp
|
||||
}
|
||||
|
||||
// Daemon intent flips to "down" — all callers (Down RPC,
|
||||
// Logout RPC handlers) tear down the connection because the user
|
||||
// explicitly asked for it. MDM restart does NOT go through this
|
||||
// path, so its clientRunning stays true.
|
||||
s.clientRunning = false
|
||||
|
||||
// Capture the engine reference before cancelling the context.
|
||||
// After actCancel(), the connectWithRetryRuns goroutine wakes up
|
||||
// and sets connectClient.engine = nil, causing connectClient.Stop()
|
||||
@@ -1176,14 +1074,10 @@ func (s *Server) Status(
|
||||
msg *proto.StatusRequest,
|
||||
) (*proto.StatusResponse, error) {
|
||||
s.mutex.Lock()
|
||||
// Only wait if the retry-loop goroutine is alive and making
|
||||
// progress. clientRunning=true with connectionGoroutineRunning=false means the
|
||||
// backoff has given up — there is nothing to wait for; let the
|
||||
// caller observe the failed status directly.
|
||||
alive := s.connectionGoroutineRunning()
|
||||
clientRunning := s.clientRunning
|
||||
s.mutex.Unlock()
|
||||
|
||||
if msg.WaitForReady != nil && *msg.WaitForReady && alive {
|
||||
if msg.WaitForReady != nil && *msg.WaitForReady && clientRunning {
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
status, err := state.Status()
|
||||
if err != nil {
|
||||
@@ -1654,7 +1548,6 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
|
||||
EnableSSHRemotePortForwarding: enableSSHRemotePortForwarding,
|
||||
DisableSSHAuth: disableSSHAuth,
|
||||
SshJWTCacheTTL: sshJWTCacheTTL,
|
||||
MDMManagedFields: cfg.Policy().ManagedKeys(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1753,7 +1646,7 @@ func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest)
|
||||
features := &proto.GetFeaturesResponse{
|
||||
DisableProfiles: s.checkProfilesDisabled(),
|
||||
DisableUpdateSettings: s.checkUpdateSettingsDisabled(),
|
||||
DisableNetworks: s.checkNetworksDisabled(),
|
||||
DisableNetworks: s.networksDisabled,
|
||||
}
|
||||
|
||||
return features, nil
|
||||
@@ -1775,46 +1668,22 @@ func (s *Server) connect(ctx context.Context, config *profilemanager.Config, sta
|
||||
return nil
|
||||
}
|
||||
|
||||
// MDM authority: when the platform-native MDM source sets a kill switch
|
||||
// key (regardless of true/false value), that value wins. The CLI flag
|
||||
// supplied at service install time is the fallback used only when the
|
||||
// MDM source is silent on the key. This honors the "MDM decides
|
||||
// everything" semantic agreed for NET-1214 — an admin pushing
|
||||
// disableX=false via MDM explicitly re-enables the feature even on a
|
||||
// box installed with --disable-X.
|
||||
func (s *Server) checkProfilesDisabled() bool {
|
||||
if s.config != nil {
|
||||
if v, ok := s.config.Policy().GetBool(mdm.KeyDisableProfiles); ok {
|
||||
return v
|
||||
}
|
||||
// Check if the environment variable is set to disable profiles
|
||||
if s.profilesDisabled {
|
||||
return true
|
||||
}
|
||||
return s.profilesDisabled
|
||||
}
|
||||
|
||||
// checkNetworksDisabled reports whether the networks/exit-node feature
|
||||
// is disabled on this daemon instance. Resolved MDM-first: when the
|
||||
// active policy declares mdm.KeyDisableNetworks the policy value wins
|
||||
// (regardless of true/false), so an admin can re-enable the feature
|
||||
// via MDM even on a host that was installed with --disable-networks.
|
||||
// Falls back to the s.networksDisabled CLI flag when the policy is
|
||||
// silent on the key. Mirrors checkProfilesDisabled and
|
||||
// checkUpdateSettingsDisabled.
|
||||
func (s *Server) checkNetworksDisabled() bool {
|
||||
if s.config != nil {
|
||||
if v, ok := s.config.Policy().GetBool(mdm.KeyDisableNetworks); ok {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return s.networksDisabled
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *Server) checkUpdateSettingsDisabled() bool {
|
||||
if s.config != nil {
|
||||
if v, ok := s.config.Policy().GetBool(mdm.KeyDisableUpdateSettings); ok {
|
||||
return v
|
||||
}
|
||||
// Check if the environment variable is set to disable profiles
|
||||
if s.updateSettingsDisabled {
|
||||
return true
|
||||
}
|
||||
return s.updateSettingsDisabled
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *Server) startUpdateManagerForGUI() {
|
||||
|
||||
@@ -101,7 +101,6 @@ func TestCleanupConnection_ClearsConnectClient(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Nil(t, s.connectClient, "connectClient should be nil after cleanup")
|
||||
assert.False(t, s.clientRunning, "clientRunning should be cleared after cleanup (intent = down)")
|
||||
}
|
||||
|
||||
// TestCleanState_NilConnectClient validates that CleanState doesn't panic
|
||||
@@ -145,20 +144,17 @@ func TestDownThenUp_StaleRunningChan(t *testing.T) {
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
s.actCancel = cancel
|
||||
|
||||
// Simulate Down(): cleanupConnection sets connectClient = nil and
|
||||
// flips clientRunning to false (intent = down). The connectionGoroutineRunning state
|
||||
// remains independent of intent — derived from clientGiveUpChan.
|
||||
// Simulate Down(): cleanupConnection sets connectClient = nil
|
||||
s.mutex.Lock()
|
||||
err := s.cleanupConnection()
|
||||
s.mutex.Unlock()
|
||||
require.NoError(t, err)
|
||||
|
||||
// After cleanup: connectClient is nil, clientRunning is false (intent
|
||||
// cleared by cleanupConnection), connectionGoroutineRunning may still be true
|
||||
// (goroutine teardown is independent of the intent flag).
|
||||
// After cleanup: connectClient is nil, clientRunning still true
|
||||
// (goroutine hasn't exited yet)
|
||||
s.mutex.Lock()
|
||||
assert.Nil(t, s.connectClient, "connectClient should be nil after cleanup")
|
||||
assert.False(t, s.clientRunning, "clientRunning should be cleared by cleanupConnection (intent = down)")
|
||||
assert.True(t, s.clientRunning, "clientRunning still true until goroutine exits")
|
||||
s.mutex.Unlock()
|
||||
|
||||
// waitForUp() returns immediately due to stale closed clientRunningChan
|
||||
|
||||
@@ -1,198 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/mdm"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
// withMDMPolicy temporarily overrides the server-package loadMDMPolicy hook
|
||||
// so SetConfig observes the supplied Policy. Restores the original loader
|
||||
// at test cleanup.
|
||||
func withMDMPolicy(t *testing.T, policy *mdm.Policy) {
|
||||
t.Helper()
|
||||
prev := loadMDMPolicy
|
||||
loadMDMPolicy = func() *mdm.Policy { return policy }
|
||||
t.Cleanup(func() { loadMDMPolicy = prev })
|
||||
}
|
||||
|
||||
// setupServerWithProfile mirrors the boilerplate of TestSetConfig_AllFieldsSaved:
|
||||
// overrides profilemanager paths to a temp dir, seeds a profile, sets it
|
||||
// active, and constructs a Server instance. Returns the constructed server
|
||||
// plus context + profile name + username + cfgPath for the seeded profile.
|
||||
func setupServerWithProfile(t *testing.T) (s *Server, ctx context.Context, profName, username, cfgPath string) {
|
||||
t.Helper()
|
||||
tempDir := t.TempDir()
|
||||
|
||||
origDefaultProfileDir := profilemanager.DefaultConfigPathDir
|
||||
origDefaultConfigPath := profilemanager.DefaultConfigPath
|
||||
origActiveProfileStatePath := profilemanager.ActiveProfileStatePath
|
||||
profilemanager.ConfigDirOverride = tempDir
|
||||
profilemanager.DefaultConfigPathDir = tempDir
|
||||
profilemanager.ActiveProfileStatePath = tempDir + "/active_profile.json"
|
||||
profilemanager.DefaultConfigPath = filepath.Join(tempDir, "default.json")
|
||||
t.Cleanup(func() {
|
||||
profilemanager.DefaultConfigPathDir = origDefaultProfileDir
|
||||
profilemanager.ActiveProfileStatePath = origActiveProfileStatePath
|
||||
profilemanager.DefaultConfigPath = origDefaultConfigPath
|
||||
profilemanager.ConfigDirOverride = ""
|
||||
})
|
||||
|
||||
currUser, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
profName = "test-profile-mdm"
|
||||
cfgPath = filepath.Join(tempDir, profName+".json")
|
||||
|
||||
_, err = profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||
ConfigPath: cfgPath,
|
||||
ManagementURL: "https://api.netbird.io:443",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
pm := profilemanager.ServiceManager{}
|
||||
require.NoError(t, pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
Name: profName,
|
||||
Username: currUser.Username,
|
||||
}))
|
||||
|
||||
ctx = context.Background()
|
||||
s = New(ctx, "console", "", false, false, false, false)
|
||||
return s, ctx, profName, currUser.Username, cfgPath
|
||||
}
|
||||
|
||||
// extractViolation pulls the MDMManagedFieldsViolation detail from a
|
||||
// FailedPrecondition error. Fails the test if absent or malformed.
|
||||
func extractViolation(t *testing.T, err error) *proto.MDMManagedFieldsViolation {
|
||||
t.Helper()
|
||||
require.Error(t, err)
|
||||
st, ok := gstatus.FromError(err)
|
||||
require.True(t, ok, "error must be a gRPC status: %v", err)
|
||||
require.Equal(t, codes.FailedPrecondition, st.Code(), "expected FailedPrecondition, got %s", st.Code())
|
||||
for _, d := range st.Details() {
|
||||
if v, ok := d.(*proto.MDMManagedFieldsViolation); ok {
|
||||
return v
|
||||
}
|
||||
}
|
||||
t.Fatalf("MDMManagedFieldsViolation detail not found on status; details: %v", st.Details())
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestSetConfig_MDMReject_SingleField(t *testing.T) {
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: "https://mdm.example.com:443",
|
||||
}))
|
||||
|
||||
s, ctx, profName, username, _ := setupServerWithProfile(t)
|
||||
|
||||
_, err := s.SetConfig(ctx, &proto.SetConfigRequest{
|
||||
ProfileName: profName,
|
||||
Username: username,
|
||||
ManagementUrl: "https://user.tried.this.com:443",
|
||||
})
|
||||
|
||||
v := extractViolation(t, err)
|
||||
assert.Equal(t, []string{mdm.KeyManagementURL}, v.GetFields())
|
||||
}
|
||||
|
||||
func TestSetConfig_MDMReject_MultipleFields(t *testing.T) {
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: "https://mdm.example.com:443",
|
||||
mdm.KeyBlockInbound: true,
|
||||
mdm.KeyRosenpassEnabled: true,
|
||||
}))
|
||||
|
||||
s, ctx, profName, username, _ := setupServerWithProfile(t)
|
||||
|
||||
blockInbound := false
|
||||
rosenpassEnabled := false
|
||||
_, err := s.SetConfig(ctx, &proto.SetConfigRequest{
|
||||
ProfileName: profName,
|
||||
Username: username,
|
||||
ManagementUrl: "https://user.tried.this.com:443",
|
||||
BlockInbound: &blockInbound,
|
||||
RosenpassEnabled: &rosenpassEnabled,
|
||||
})
|
||||
|
||||
v := extractViolation(t, err)
|
||||
assert.ElementsMatch(t, []string{
|
||||
mdm.KeyManagementURL,
|
||||
mdm.KeyBlockInbound,
|
||||
mdm.KeyRosenpassEnabled,
|
||||
}, v.GetFields())
|
||||
}
|
||||
|
||||
func TestSetConfig_MDMReject_AllOrNothing(t *testing.T) {
|
||||
// MDM enforces ManagementURL only; user request touches both the
|
||||
// enforced field AND a non-enforced field (RosenpassEnabled).
|
||||
// The whole request must be rejected — non-conflicting fields are not
|
||||
// applied either.
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: "https://mdm.example.com:443",
|
||||
}))
|
||||
|
||||
s, ctx, profName, username, cfgPath := setupServerWithProfile(t)
|
||||
|
||||
rosenpassEnabled := true
|
||||
_, err := s.SetConfig(ctx, &proto.SetConfigRequest{
|
||||
ProfileName: profName,
|
||||
Username: username,
|
||||
ManagementUrl: "https://user.tried.this.com:443",
|
||||
RosenpassEnabled: &rosenpassEnabled,
|
||||
})
|
||||
|
||||
v := extractViolation(t, err)
|
||||
assert.Equal(t, []string{mdm.KeyManagementURL}, v.GetFields())
|
||||
|
||||
// Confirm RosenpassEnabled was NOT applied even though it was not
|
||||
// in the conflict list: the request was rejected as a whole.
|
||||
reloaded, err := profilemanager.GetConfig(cfgPath)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, reloaded.RosenpassEnabled, "non-conflicting field must not be applied when request is rejected")
|
||||
}
|
||||
|
||||
func TestSetConfig_MDMAllow_NonManagedFields(t *testing.T) {
|
||||
// MDM enforces ManagementURL but the user only writes RosenpassEnabled.
|
||||
// Request must succeed.
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: "https://mdm.example.com:443",
|
||||
}))
|
||||
|
||||
s, ctx, profName, username, _ := setupServerWithProfile(t)
|
||||
|
||||
rosenpassEnabled := true
|
||||
resp, err := s.SetConfig(ctx, &proto.SetConfigRequest{
|
||||
ProfileName: profName,
|
||||
Username: username,
|
||||
RosenpassEnabled: &rosenpassEnabled,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
}
|
||||
|
||||
func TestSetConfig_MDMEmpty_NoEnforcement(t *testing.T) {
|
||||
// No MDM policy active: any field can be written.
|
||||
withMDMPolicy(t, mdm.NewPolicy(nil))
|
||||
|
||||
s, ctx, profName, username, _ := setupServerWithProfile(t)
|
||||
|
||||
resp, err := s.SetConfig(ctx, &proto.SetConfigRequest{
|
||||
ProfileName: profName,
|
||||
Username: username,
|
||||
ManagementUrl: "https://user.changed.url.com:443",
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
}
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
|
||||
"github.com/creack/pty"
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/netbirdio/netbird/client/internal/shell"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@@ -146,10 +147,10 @@ func (s *Server) createShellCommand(ctx context.Context, shell string, args []st
|
||||
|
||||
// prepareCommandEnv prepares environment variables for command execution on Unix
|
||||
func (s *Server) prepareCommandEnv(_ *log.Entry, localUser *user.User, session ssh.Session) []string {
|
||||
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
||||
env := shell.PrepareUserEnv(localUser, shell.GetUserShell(localUser.Uid))
|
||||
env = append(env, prepareSSHEnv(session)...)
|
||||
for _, v := range session.Environ() {
|
||||
if acceptEnv(v) {
|
||||
if shell.AcceptEnv(v) {
|
||||
env = append(env, v)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -247,10 +247,10 @@ func (s *Server) prepareCommandEnv(logger *log.Entry, localUser *user.User, sess
|
||||
userEnv, err := s.getUserEnvironment(logger, username, domain)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get user environment for %s\\%s, using fallback: %v", domain, username, err)
|
||||
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
||||
env := shell.PrepareUserEnv(localUser, shell.GetUserShell(localUser.Uid))
|
||||
env = append(env, prepareSSHEnv(session)...)
|
||||
for _, v := range session.Environ() {
|
||||
if acceptEnv(v) {
|
||||
if shell.AcceptEnv(v) {
|
||||
env = append(env, v)
|
||||
}
|
||||
}
|
||||
@@ -260,7 +260,7 @@ func (s *Server) prepareCommandEnv(logger *log.Entry, localUser *user.User, sess
|
||||
env := userEnv
|
||||
env = append(env, prepareSSHEnv(session)...)
|
||||
for _, v := range session.Environ() {
|
||||
if acceptEnv(v) {
|
||||
if shell.AcceptEnv(v) {
|
||||
env = append(env, v)
|
||||
}
|
||||
}
|
||||
@@ -273,7 +273,7 @@ func (s *Server) handlePtyLogin(logger *log.Entry, session ssh.Session, privileg
|
||||
return false
|
||||
}
|
||||
|
||||
shell := getUserShell(privilegeResult.User.Uid)
|
||||
shell := shell.GetUserShell(privilegeResult.User.Uid)
|
||||
logger.Infof("starting interactive shell: %s", shell)
|
||||
|
||||
s.executeCommandWithPty(logger, session, nil, privilegeResult, ptyReq, nil)
|
||||
@@ -384,7 +384,7 @@ func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, _
|
||||
}
|
||||
|
||||
username, domain := s.parseUsername(localUser.Username)
|
||||
shell := getUserShell(localUser.Uid)
|
||||
shell := shell.GetUserShell(localUser.Uid)
|
||||
|
||||
req := PtyExecutionRequest{
|
||||
Shell: shell,
|
||||
|
||||
@@ -3,11 +3,15 @@ package server
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/netbirdio/netbird/client/internal/shell"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@@ -23,8 +27,8 @@ func isPlatformUnix() bool {
|
||||
|
||||
// Dependency injection variables for testing - allows mocking dynamic runtime checks
|
||||
var (
|
||||
getCurrentUser = currentUserWithGetent
|
||||
lookupUser = lookupWithGetent
|
||||
getCurrentUser = shell.CurrentUserWithGetent
|
||||
lookupUser = shell.LookupWithGetent
|
||||
getCurrentOS = func() string { return runtime.GOOS }
|
||||
getIsProcessPrivileged = isCurrentProcessPrivileged
|
||||
|
||||
@@ -409,3 +413,29 @@ func isWindowsElevated() bool {
|
||||
log.Debugf("Windows user switching not supported: not running as privileged user (current: %s)", currentUser.Uid)
|
||||
return false
|
||||
}
|
||||
|
||||
// prepareSSHEnv prepares SSH protocol-specific environment variables
|
||||
// These variables provide information about the SSH connection itself
|
||||
func prepareSSHEnv(session ssh.Session) []string {
|
||||
remoteAddr := session.RemoteAddr()
|
||||
localAddr := session.LocalAddr()
|
||||
|
||||
remoteHost, remotePort, err := net.SplitHostPort(remoteAddr.String())
|
||||
if err != nil {
|
||||
remoteHost = remoteAddr.String()
|
||||
remotePort = "0"
|
||||
}
|
||||
|
||||
localHost, localPort, err := net.SplitHostPort(localAddr.String())
|
||||
if err != nil {
|
||||
localHost = localAddr.String()
|
||||
localPort = strconv.Itoa(InternalSSHPort)
|
||||
}
|
||||
|
||||
return []string{
|
||||
// SSH_CLIENT format: "client_ip client_port server_port"
|
||||
fmt.Sprintf("SSH_CLIENT=%s %s %s", remoteHost, remotePort, localPort),
|
||||
// SSH_CONNECTION format: "client_ip client_port server_ip server_port"
|
||||
fmt.Sprintf("SSH_CONNECTION=%s %s %s %s", remoteHost, remotePort, localHost, localPort),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"strconv"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/netbirdio/netbird/client/internal/shell"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@@ -160,7 +161,7 @@ func (s *Server) parseUserCredentials(localUser *user.User) (uint32, uint32, []u
|
||||
// getSupplementaryGroups retrieves supplementary group IDs for a user.
|
||||
// Uses id/getent fallback for NSS users in CGO_ENABLED=0 builds.
|
||||
func (s *Server) getSupplementaryGroups(u *user.User) ([]uint32, error) {
|
||||
groupIDStrings, err := groupIdsWithFallback(u)
|
||||
groupIDStrings, err := shell.GroupIdsWithFallback(u)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group IDs for user %s: %w", u.Username, err)
|
||||
}
|
||||
@@ -196,7 +197,7 @@ func (s *Server) createExecutorCommand(logger *log.Entry, session ssh.Session, l
|
||||
GID: gid,
|
||||
Groups: groups,
|
||||
WorkingDir: localUser.HomeDir,
|
||||
Shell: getUserShell(localUser.Uid),
|
||||
Shell: shell.GetUserShell(localUser.Uid),
|
||||
Command: session.RawCommand(),
|
||||
PTY: hasPty,
|
||||
}
|
||||
@@ -228,7 +229,7 @@ func (s *Server) createPtyCommand(privilegeResult PrivilegeCheckResult, ptyReq s
|
||||
func (s *Server) createDirectPtyCommand(session ssh.Session, localUser *user.User, ptyReq ssh.Pty) *exec.Cmd {
|
||||
log.Debugf("creating direct Pty command for user %s (no user switching needed)", localUser.Username)
|
||||
|
||||
shell := getUserShell(localUser.Uid)
|
||||
shell := shell.GetUserShell(localUser.Uid)
|
||||
args := s.getShellCommandArgs(shell, session.RawCommand())
|
||||
|
||||
cmd := s.createShellCommand(session.Context(), shell, args)
|
||||
@@ -245,12 +246,12 @@ func (s *Server) preparePtyEnv(localUser *user.User, ptyReq ssh.Pty, session ssh
|
||||
termType = "xterm-256color"
|
||||
}
|
||||
|
||||
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
||||
env := shell.PrepareUserEnv(localUser, shell.GetUserShell(localUser.Uid))
|
||||
env = append(env, prepareSSHEnv(session)...)
|
||||
env = append(env, fmt.Sprintf("TERM=%s", termType))
|
||||
|
||||
for _, v := range session.Environ() {
|
||||
if acceptEnv(v) {
|
||||
if shell.AcceptEnv(v) {
|
||||
env = append(env, v)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,8 @@ import (
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/shell"
|
||||
)
|
||||
|
||||
// validateUsername validates Windows usernames according to SAM Account Name rules
|
||||
@@ -104,7 +106,7 @@ func (s *Server) createExecutorCommand(logger *log.Entry, session ssh.Session, l
|
||||
func (s *Server) createUserSwitchCommand(logger *log.Entry, session ssh.Session, localUser *user.User) (*exec.Cmd, func(), error) {
|
||||
username, domain := s.parseUsername(localUser.Username)
|
||||
|
||||
shell := getUserShell(localUser.Uid)
|
||||
sh := shell.GetUserShell(localUser.Uid)
|
||||
|
||||
rawCmd := session.RawCommand()
|
||||
var command string
|
||||
@@ -116,7 +118,7 @@ func (s *Server) createUserSwitchCommand(logger *log.Entry, session ssh.Session,
|
||||
Username: username,
|
||||
Domain: domain,
|
||||
WorkingDir: localUser.HomeDir,
|
||||
Shell: shell,
|
||||
Shell: sh,
|
||||
Command: command,
|
||||
}
|
||||
|
||||
|
||||
@@ -38,7 +38,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/mdm"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/ui/desktop"
|
||||
"github.com/netbirdio/netbird/client/ui/event"
|
||||
@@ -57,22 +56,8 @@ const (
|
||||
const (
|
||||
censoredPreSharedKey = "**********"
|
||||
maxSSHJWTCacheTTL = 86_400 // 24 hours in seconds
|
||||
// mdmFieldSuffix is appended to plain-text Entry widgets in the
|
||||
// advanced Settings window when the underlying field is enforced
|
||||
// by MDM, so the user sees the lock indicator inline next to the
|
||||
// value. Stripped before any read site that feeds the value back
|
||||
// into a SetConfig request (saveSettings / parseNumericSettings).
|
||||
mdmFieldSuffix = " (MDM)"
|
||||
)
|
||||
|
||||
// main is the entry point for the UI tray/client binary. Parses CLI
|
||||
// flags, initialises logging, builds the Fyne application and tray
|
||||
// icons, and constructs the service client (which may open a
|
||||
// requested UI window). When a window-mode flag is set the Fyne event
|
||||
// loop runs and main returns; otherwise main enforces single-instance
|
||||
// behaviour (signalling an existing instance to show its window when
|
||||
// present), sets up signal handling + default fonts, and runs the
|
||||
// system tray loop.
|
||||
func main() {
|
||||
flags := parseFlags()
|
||||
|
||||
@@ -330,13 +315,9 @@ type serviceClient struct {
|
||||
isUpdateIconActive bool
|
||||
isEnforcedUpdate bool
|
||||
lastNotifiedVersion string
|
||||
settingsEnabled bool
|
||||
profilesEnabled bool
|
||||
networksEnabled bool
|
||||
// networksMenuEnabled caches the last applied enabled-state of the
|
||||
// mNetworks + mExitNode submenu items. Combines features.DisableNetworks
|
||||
// AND s.connected — both must be true for the menus to be active.
|
||||
// Zero value (false) matches the Disable() call at AddMenuItem time.
|
||||
networksMenuEnabled bool
|
||||
showNetworks bool
|
||||
wNetworks fyne.Window
|
||||
wProfiles fyne.Window
|
||||
@@ -355,13 +336,6 @@ type serviceClient struct {
|
||||
updateContextCancel context.CancelFunc
|
||||
|
||||
connectCancel context.CancelFunc
|
||||
|
||||
// mdmManagedFields caches the names of MDM-enforced policy keys
|
||||
// surfaced by the daemon in GetConfigResponse. Each refresh of
|
||||
// daemon config (loadSettings, getSrvConfig, config_changed event)
|
||||
// updates this set and re-applies the lock/badge to the affected
|
||||
// menu items and settings-form widgets.
|
||||
mdmManagedFields map[string]bool
|
||||
}
|
||||
|
||||
type menuHandler struct {
|
||||
@@ -467,12 +441,15 @@ func (s *serviceClient) updateIcon() {
|
||||
}
|
||||
|
||||
func (s *serviceClient) showSettingsUI() {
|
||||
// DisableUpdateSettings no longer gates the window from opening:
|
||||
// the daemon blocks every actual mutation at SetConfig / Login,
|
||||
// so the window is safe to show as a read-only view. The previous
|
||||
// early-return also blocked Advanced Settings whenever update
|
||||
// editing was off, which conflated two distinct kill switches
|
||||
// (see comment in checkAndUpdateFeatures).
|
||||
// Check if update settings are disabled by daemon
|
||||
features, err := s.getFeatures()
|
||||
if err != nil {
|
||||
log.Errorf("failed to get features from daemon: %v", err)
|
||||
// Continue with default behavior if features can't be retrieved
|
||||
} else if features != nil && features.DisableUpdateSettings {
|
||||
log.Warn("Update settings are disabled by daemon")
|
||||
return
|
||||
}
|
||||
|
||||
// add settings window UI elements.
|
||||
s.wSettings = s.app.NewWindow("NetBird Settings")
|
||||
@@ -555,7 +532,7 @@ func (s *serviceClient) saveSettings() {
|
||||
return
|
||||
}
|
||||
|
||||
iMngURL := strings.TrimSpace(strings.TrimSuffix(s.iMngURL.Text, mdmFieldSuffix))
|
||||
iMngURL := strings.TrimSpace(s.iMngURL.Text)
|
||||
|
||||
if s.hasSettingsChanged(iMngURL, port, mtu) {
|
||||
if err := s.applySettingsChanges(iMngURL, port, mtu); err != nil {
|
||||
@@ -577,7 +554,7 @@ func (s *serviceClient) validateSettings() error {
|
||||
}
|
||||
|
||||
func (s *serviceClient) parseNumericSettings() (int64, int64, error) {
|
||||
port, err := strconv.ParseInt(strings.TrimSpace(strings.TrimSuffix(s.iInterfacePort.Text, mdmFieldSuffix)), 10, 64)
|
||||
port, err := strconv.ParseInt(s.iInterfacePort.Text, 10, 64)
|
||||
if err != nil {
|
||||
return 0, 0, errors.New("invalid interface port")
|
||||
}
|
||||
@@ -686,15 +663,7 @@ func (s *serviceClient) buildSetConfigRequest(iMngURL string, port, mtu int64) (
|
||||
req.SshJWTCacheTTL = &sshJWTCacheTTL32
|
||||
}
|
||||
|
||||
// Only attach the PSK when the user actually typed something:
|
||||
// - "" means the field was left untouched (we deliberately render
|
||||
// an empty Text + placeholder hint to avoid leaking the daemon's
|
||||
// "**********" redaction through the password reveal toggle);
|
||||
// sending an empty pointer would tell the daemon to clear / overwrite
|
||||
// the on-disk or MDM-enforced PSK, which then trips the MDM
|
||||
// conflict gate when PSK is policy-managed.
|
||||
// - "**********" is the redacted echo (legacy non-MDM path); also a no-op.
|
||||
if s.iPreSharedKey.Text != "" && s.iPreSharedKey.Text != censoredPreSharedKey {
|
||||
if s.iPreSharedKey.Text != censoredPreSharedKey {
|
||||
req.OptionalPreSharedKey = &s.iPreSharedKey.Text
|
||||
}
|
||||
|
||||
@@ -1067,13 +1036,6 @@ func (s *serviceClient) onTrayReady() {
|
||||
}
|
||||
|
||||
s.mProfile = newProfileMenu(*newProfileMenuArgs)
|
||||
// Seed the transition cache to match the actual default menu
|
||||
// state (visible / enabled). Without this, the first
|
||||
// checkAndUpdateFeatures tick that observes DisableProfiles=true
|
||||
// is a no-op (cache zero-value == desired-false) and the menu
|
||||
// never gets hidden — symptom: MDM enforces the kill switch but
|
||||
// the profile menu stays clickable.
|
||||
s.profilesEnabled = true
|
||||
|
||||
systray.AddSeparator()
|
||||
s.mUp = systray.AddMenuItem("Connect", "Connect")
|
||||
@@ -1093,18 +1055,18 @@ func (s *serviceClient) onTrayReady() {
|
||||
s.mCreateDebugBundle = s.mSettings.AddSubMenuItem("Create Debug Bundle", debugBundleMenuDescr)
|
||||
s.loadSettings()
|
||||
|
||||
// Disable profile menu if profiles are disabled by daemon.
|
||||
// DisableUpdateSettings is enforced at the daemon's SetConfig /
|
||||
// Login gates, not by hiding the UI — so the Settings menu (and
|
||||
// its Advanced Settings submenu, which has its own kill switch)
|
||||
// stays visible and the user can still inspect current values.
|
||||
// Disable settings menu if update settings are disabled by daemon
|
||||
features, err := s.getFeatures()
|
||||
if err != nil {
|
||||
log.Errorf("failed to get features from daemon: %v", err)
|
||||
// Continue with default behavior if features can't be retrieved
|
||||
} else if features != nil && features.DisableProfiles {
|
||||
s.mProfile.setEnabled(false)
|
||||
s.profilesEnabled = false
|
||||
} else {
|
||||
if features != nil && features.DisableUpdateSettings {
|
||||
s.setSettingsEnabled(false)
|
||||
}
|
||||
if features != nil && features.DisableProfiles {
|
||||
s.mProfile.setEnabled(false)
|
||||
}
|
||||
}
|
||||
|
||||
s.exitNodeMu.Lock()
|
||||
@@ -1138,20 +1100,13 @@ func (s *serviceClient) onTrayReady() {
|
||||
// update exit node menu in case service is already connected
|
||||
go s.updateExitNodes()
|
||||
|
||||
// Features (DisableProfiles, DisableUpdateSettings, DisableNetworks,
|
||||
// ...) only change in two ways: at service install time (CLI flag,
|
||||
// static) and at MDM ticker diff time. The daemon already publishes
|
||||
// a SystemEvent{type=config_changed} on every MDM-driven engine
|
||||
// restart, so the UI no longer needs to poll GetFeatures every 2 s.
|
||||
// A single fetch at startup covers the static CLI-flag case; the
|
||||
// event handler below covers MDM transitions. updateStatus stays in
|
||||
// the 2 s loop because connection / peer state genuinely change
|
||||
// continuously and have no event yet.
|
||||
s.checkAndUpdateFeatures()
|
||||
go func() {
|
||||
s.getSrvConfig()
|
||||
time.Sleep(100 * time.Millisecond) // To prevent race condition caused by systray not being fully initialized and ignoring setIcon
|
||||
for {
|
||||
// Check features before status so menus respect disable flags before being enabled
|
||||
s.checkAndUpdateFeatures()
|
||||
|
||||
err := s.updateStatus()
|
||||
if err != nil {
|
||||
log.Errorf("error while updating status: %v", err)
|
||||
@@ -1195,23 +1150,6 @@ func (s *serviceClient) onTrayReady() {
|
||||
s.onUpdateAvailable(newVersion, enforced)
|
||||
}
|
||||
})
|
||||
s.eventManager.AddHandler(func(event *proto.SystemEvent) {
|
||||
// Daemon emits a config_changed event after every engine spawn
|
||||
// (Server.Start, Server.Up, MDM ticker restart). Re-sync the
|
||||
// tray submenu checkboxes from the fresh daemon-side config so
|
||||
// the user does not have to restart the tray to see CLI- or
|
||||
// MDM-driven changes.
|
||||
if event.Category == proto.SystemEvent_SYSTEM && event.Metadata["type"] == "config_changed" {
|
||||
log.Infof("config_changed event received (source=%s); refreshing settings + features", event.Metadata["source"])
|
||||
s.loadSettings()
|
||||
// MDM-driven feature kill switches (DisableProfiles /
|
||||
// DisableUpdateSettings / DisableNetworks) ride the same
|
||||
// config_changed signal because the daemon re-applies its
|
||||
// MDM policy on every engine spawn. Pull them in here so
|
||||
// the UI is up to date without a periodic GetFeatures poll.
|
||||
s.checkAndUpdateFeatures()
|
||||
}
|
||||
})
|
||||
|
||||
go s.eventManager.Start(s.ctx)
|
||||
go s.eventHandler.listen(s.ctx)
|
||||
@@ -1275,6 +1213,18 @@ func (s *serviceClient) getSrvClient(timeout time.Duration) (proto.DaemonService
|
||||
return s.conn, nil
|
||||
}
|
||||
|
||||
// setSettingsEnabled enables or disables the settings menu based on the provided state
|
||||
func (s *serviceClient) setSettingsEnabled(enabled bool) {
|
||||
if s.mSettings != nil {
|
||||
if enabled {
|
||||
s.mSettings.Enable()
|
||||
} else {
|
||||
s.mSettings.Hide()
|
||||
s.mSettings.SetTooltip("Settings are disabled by daemon")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// checkAndUpdateFeatures checks the current features and updates the UI accordingly
|
||||
func (s *serviceClient) checkAndUpdateFeatures() {
|
||||
features, err := s.getFeatures()
|
||||
@@ -1286,11 +1236,12 @@ func (s *serviceClient) checkAndUpdateFeatures() {
|
||||
s.updateIndicationLock.Lock()
|
||||
defer s.updateIndicationLock.Unlock()
|
||||
|
||||
// DisableUpdateSettings is enforced server-side by the daemon gates
|
||||
// on SetConfig + Login: any attempt to mutate config from UI or
|
||||
// CLI is rejected at that layer. The UI deliberately keeps the
|
||||
// Settings menu visible so the user can still inspect current
|
||||
// values — read-only by virtue of the daemon refusing edits.
|
||||
// Update settings menu based on current features
|
||||
settingsEnabled := features == nil || !features.DisableUpdateSettings
|
||||
if s.settingsEnabled != settingsEnabled {
|
||||
s.settingsEnabled = settingsEnabled
|
||||
s.setSettingsEnabled(settingsEnabled)
|
||||
}
|
||||
|
||||
// Update profile menu based on current features
|
||||
if s.mProfile != nil {
|
||||
@@ -1301,23 +1252,14 @@ func (s *serviceClient) checkAndUpdateFeatures() {
|
||||
}
|
||||
}
|
||||
|
||||
// Update networks and exit node menus based on current features.
|
||||
// `networksEnabled` is the bare feature flag (read elsewhere, e.g. at
|
||||
// connection-status transitions). `networksMenuEnabled` is the
|
||||
// transition-cached state actually applied to the menu items —
|
||||
// it folds in the connection state so a Connected client with the
|
||||
// kill switch off shows the menus active, and only flips on diff.
|
||||
// Update networks and exit node menus based on current features
|
||||
s.networksEnabled = features == nil || !features.DisableNetworks
|
||||
desiredNetworksMenu := s.networksEnabled && s.connected
|
||||
if desiredNetworksMenu != s.networksMenuEnabled {
|
||||
s.networksMenuEnabled = desiredNetworksMenu
|
||||
if desiredNetworksMenu {
|
||||
s.mNetworks.Enable()
|
||||
s.mExitNode.Enable()
|
||||
} else {
|
||||
s.mNetworks.Disable()
|
||||
s.mExitNode.Disable()
|
||||
}
|
||||
if s.networksEnabled && s.connected {
|
||||
s.mNetworks.Enable()
|
||||
s.mExitNode.Enable()
|
||||
} else {
|
||||
s.mNetworks.Disable()
|
||||
s.mExitNode.Disable()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1414,14 +1356,7 @@ func (s *serviceClient) getSrvConfig() {
|
||||
|
||||
if s.showAdvancedSettings {
|
||||
s.iMngURL.SetText(s.managementURL)
|
||||
// PSK is rendered with an empty Text and a hint via the
|
||||
// placeholder so the eye toggle never reveals literal asterisks
|
||||
// (the daemon returns the "**********" sentinel — writing that
|
||||
// into a PasswordEntry would surface the literal sentinel when
|
||||
// the user unmasks the field). The placeholder communicates the
|
||||
// configured / MDM-managed state without exposing any value.
|
||||
s.iPreSharedKey.SetText("")
|
||||
s.iPreSharedKey.SetPlaceHolder(preSharedKeyPlaceholder(srvCfg))
|
||||
s.iPreSharedKey.SetText(cfg.PreSharedKey)
|
||||
s.iInterfaceName.SetText(cfg.WgIface)
|
||||
s.iInterfacePort.SetText(strconv.Itoa(cfg.WgPort))
|
||||
if cfg.MTU != 0 {
|
||||
@@ -1431,15 +1366,7 @@ func (s *serviceClient) getSrvConfig() {
|
||||
s.iMTU.SetPlaceHolder(strconv.Itoa(int(iface.DefaultMTU)))
|
||||
}
|
||||
s.sRosenpassPermissive.SetChecked(cfg.RosenpassPermissive)
|
||||
// Re-baseline the enabled state on every refresh: when Rosenpass
|
||||
// is on the checkbox is editable, when it's off the field is
|
||||
// inert. Without an explicit Enable() here the control stays
|
||||
// stuck disabled after a previous refresh (or an MDM unlock) had
|
||||
// turned it off — applyMDMLocksToSettingsForm below adds the
|
||||
// MDM lock on top of this baseline.
|
||||
if cfg.RosenpassEnabled {
|
||||
s.sRosenpassPermissive.Enable()
|
||||
} else {
|
||||
if !cfg.RosenpassEnabled {
|
||||
s.sRosenpassPermissive.Disable()
|
||||
}
|
||||
s.sNetworkMonitor.SetChecked(*cfg.NetworkMonitor)
|
||||
@@ -1468,13 +1395,6 @@ func (s *serviceClient) getSrvConfig() {
|
||||
}
|
||||
}
|
||||
|
||||
// MDM locks must run before the mNotifications-nil early return:
|
||||
// the Settings window is rendered by a separate UI process launched
|
||||
// with --settings (see handleAdvancedSettingsClick), and that child
|
||||
// process does NOT run onReady — so its mNotifications is nil and
|
||||
// the early return below skipped the lock pass entirely.
|
||||
s.applyMDMLocks(srvCfg.MDMManagedFields)
|
||||
|
||||
if s.mNotifications == nil {
|
||||
return
|
||||
}
|
||||
@@ -1659,129 +1579,6 @@ func (s *serviceClient) loadSettings() {
|
||||
if s.eventManager != nil {
|
||||
s.eventManager.SetNotificationsEnabled(s.mNotifications.Checked())
|
||||
}
|
||||
s.applyMDMLocks(cfg.MDMManagedFields)
|
||||
}
|
||||
|
||||
// applyMDMLocks disables and badges any tray submenu item or settings-
|
||||
// form widget whose underlying field is enforced by the active MDM
|
||||
// policy. Called from loadSettings (submenu refresh) and from
|
||||
// getSrvConfig (settings-window refresh). Locked items keep their value
|
||||
// already set by the surrounding refresh code — this routine only
|
||||
// flips the enabled state and the title suffix, never the value.
|
||||
func (s *serviceClient) applyMDMLocks(managed []string) {
|
||||
set := make(map[string]bool, len(managed))
|
||||
for _, k := range managed {
|
||||
set[k] = true
|
||||
}
|
||||
s.mdmManagedFields = set
|
||||
if len(managed) > 0 {
|
||||
log.Infof("MDM-managed UI fields: %v", managed)
|
||||
}
|
||||
|
||||
type submenuTarget struct {
|
||||
item *systray.MenuItem
|
||||
title string
|
||||
key string
|
||||
}
|
||||
for _, t := range []submenuTarget{
|
||||
{s.mAllowSSH, "Allow SSH", mdm.KeyAllowServerSSH},
|
||||
{s.mAutoConnect, "Connect on Startup", mdm.KeyDisableAutoConnect},
|
||||
{s.mEnableRosenpass, "Enable Quantum-Resistance", mdm.KeyRosenpassEnabled},
|
||||
{s.mBlockInbound, "Block Inbound Connections", mdm.KeyBlockInbound},
|
||||
} {
|
||||
if t.item == nil {
|
||||
continue
|
||||
}
|
||||
if set[t.key] {
|
||||
t.item.SetTitle(t.title + " (MDM)")
|
||||
t.item.Disable()
|
||||
} else {
|
||||
t.item.SetTitle(t.title)
|
||||
t.item.Enable()
|
||||
}
|
||||
}
|
||||
|
||||
s.applyMDMLocksToSettingsForm(set)
|
||||
}
|
||||
|
||||
// preSharedKeyPlaceholder returns the hint string shown in the PSK
|
||||
// Entry's placeholder slot. The placeholder is the only signal the
|
||||
// user gets that a PSK is configured, because the entry's Text is
|
||||
// forced to empty to keep the password reveal toggle from leaking
|
||||
// the daemon-returned "**********" redaction sentinel. Returns "" if
|
||||
// no PSK is present, "MDM-managed" if the key is enforced by MDM,
|
||||
// and "configured" otherwise.
|
||||
func preSharedKeyPlaceholder(cfg *proto.GetConfigResponse) string {
|
||||
if cfg == nil || cfg.PreSharedKey == "" {
|
||||
return ""
|
||||
}
|
||||
for _, k := range cfg.MDMManagedFields {
|
||||
if k == mdm.KeyPreSharedKey {
|
||||
return "MDM-managed"
|
||||
}
|
||||
}
|
||||
return "configured"
|
||||
}
|
||||
|
||||
// applyMDMLocksToSettingsForm disables the per-field input widgets in
|
||||
// the advanced Settings window when the corresponding MDM key is set.
|
||||
// For plain-text entries (Management URL, Interface Port) the visible
|
||||
// value is suffixed with " (MDM)" so the user sees the lock indicator
|
||||
// inline; for the password entry the suffix is skipped (a password
|
||||
// widget renders every char as a dot and the indicator would not be
|
||||
// readable). The widgets are created lazily by showSettingsUI, so
|
||||
// guard each ref against nil.
|
||||
func (s *serviceClient) applyMDMLocksToSettingsForm(set map[string]bool) {
|
||||
type entryTarget struct {
|
||||
entry *widget.Entry
|
||||
key string
|
||||
inlineTag bool
|
||||
}
|
||||
for _, t := range []entryTarget{
|
||||
{s.iMngURL, mdm.KeyManagementURL, true},
|
||||
{s.iPreSharedKey, mdm.KeyPreSharedKey, false},
|
||||
{s.iInterfacePort, mdm.KeyWireguardPort, true},
|
||||
} {
|
||||
if t.entry == nil {
|
||||
continue
|
||||
}
|
||||
if set[t.key] {
|
||||
if t.inlineTag && t.entry.Text != "" && !strings.HasSuffix(t.entry.Text, mdmFieldSuffix) {
|
||||
t.entry.SetText(t.entry.Text + mdmFieldSuffix)
|
||||
}
|
||||
t.entry.Disable()
|
||||
} else {
|
||||
if t.inlineTag {
|
||||
t.entry.SetText(strings.TrimSuffix(t.entry.Text, mdmFieldSuffix))
|
||||
}
|
||||
t.entry.Enable()
|
||||
}
|
||||
}
|
||||
type checkTarget struct {
|
||||
check *widget.Check
|
||||
key string
|
||||
}
|
||||
for _, t := range []checkTarget{
|
||||
{s.sDisableClientRoutes, mdm.KeyDisableClientRoutes},
|
||||
{s.sDisableServerRoutes, mdm.KeyDisableServerRoutes},
|
||||
} {
|
||||
if t.check == nil {
|
||||
continue
|
||||
}
|
||||
if set[t.key] {
|
||||
t.check.Disable()
|
||||
} else {
|
||||
t.check.Enable()
|
||||
}
|
||||
}
|
||||
if s.sRosenpassPermissive != nil && set[mdm.KeyRosenpassPermissive] {
|
||||
// MDM lock layered on top of the Rosenpass-on/off baseline
|
||||
// applied by getSrvConfig. No Enable() branch here: when the
|
||||
// MDM key is removed, the next getSrvConfig refresh re-baselines
|
||||
// the control on cfg.RosenpassEnabled and brings it back if
|
||||
// Rosenpass is on.
|
||||
s.sRosenpassPermissive.Disable()
|
||||
}
|
||||
}
|
||||
|
||||
// updateConfig updates the configuration parameters
|
||||
|
||||
@@ -666,49 +666,17 @@ func (p *profileMenu) clear(profiles []Profile) {
|
||||
}
|
||||
}
|
||||
|
||||
// setEnabled greys out (Disable) the profile menu and every existing
|
||||
// sub-item when the daemon reports the kill switch active, so the user
|
||||
// sees the menu but cannot enter "Manage Profiles" or switch profile.
|
||||
// Previously this used Hide() on the parent, but Fyne's systray on
|
||||
// Windows does not propagate Hide() to a parent that already has
|
||||
// children — the submenu kept popping up and accepting clicks. Disable
|
||||
// is the reliable visual lock.
|
||||
// setEnabled enables or disables the profile menu based on the provided state
|
||||
func (p *profileMenu) setEnabled(enabled bool) {
|
||||
if p.profileMenuItem == nil {
|
||||
return
|
||||
}
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if enabled {
|
||||
p.profileMenuItem.Enable()
|
||||
p.profileMenuItem.SetTooltip("")
|
||||
} else {
|
||||
p.profileMenuItem.Disable()
|
||||
p.profileMenuItem.SetTooltip("Profiles are disabled by daemon")
|
||||
}
|
||||
|
||||
apply := func(item *systray.MenuItem) {
|
||||
if item == nil {
|
||||
return
|
||||
}
|
||||
if p.profileMenuItem != nil {
|
||||
if enabled {
|
||||
item.Enable()
|
||||
p.profileMenuItem.Enable()
|
||||
p.profileMenuItem.SetTooltip("")
|
||||
} else {
|
||||
item.Disable()
|
||||
p.profileMenuItem.Hide()
|
||||
p.profileMenuItem.SetTooltip("Profiles are disabled by daemon")
|
||||
}
|
||||
}
|
||||
for _, sub := range p.profileSubItems {
|
||||
if sub != nil {
|
||||
apply(sub.MenuItem)
|
||||
}
|
||||
}
|
||||
if p.manageProfilesSubItem != nil {
|
||||
apply(p.manageProfilesSubItem.MenuItem)
|
||||
}
|
||||
if p.logoutSubItem != nil {
|
||||
apply(p.logoutSubItem.MenuItem)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *profileMenu) updateMenu() {
|
||||
|
||||
@@ -1,126 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<!--
|
||||
NetBird MDM preferences (macOS) — bare plist for MDM platforms that
|
||||
accept a managed-preferences plist tied to a bundle identifier
|
||||
(e.g. JumpCloud "Mac Application Custom Settings", Mosyle "Custom
|
||||
Settings", Jamf "Application & Custom Settings" → External
|
||||
Application).
|
||||
|
||||
Bundle identifier (preference domain): io.netbird.client
|
||||
|
||||
The MDM provider will wrap this plist into a Configuration Profile
|
||||
payload of type com.apple.ManagedClient.preferences and push it to
|
||||
target devices via the Apple MDM protocol. The OS materializes the
|
||||
final file at:
|
||||
/Library/Managed Preferences/io.netbird.client.plist
|
||||
which is what the NetBird daemon's client/mdm/policy_darwin.go
|
||||
loader reads on every 1-minute MDM reload tick.
|
||||
|
||||
For MDM platforms that expect a full Configuration Profile instead
|
||||
of a bare plist (Custom Configuration Profile / .mobileconfig upload),
|
||||
use docs/netbird-macos.mobileconfig — same keys, additional Payload*
|
||||
envelope.
|
||||
|
||||
Editing this file:
|
||||
- Remove or comment out any key you do NOT want to enforce. The
|
||||
daemon treats an absent key as "no enforcement" for that field.
|
||||
- Keep the document well-formed XML. Validate locally with:
|
||||
plutil -lint docs/io.netbird.client.plist
|
||||
- Keys are camelCase; values are typed (<string>, <true/>, <false/>,
|
||||
<integer>). See docs/src/pages/client/mdm-integration.mdx (the
|
||||
public docs page) for the full reference.
|
||||
|
||||
Persistence caveat:
|
||||
macOS wipes /Library/Managed Preferences/ at every boot on
|
||||
devices that are NOT MDM-enrolled. This plist only sticks across
|
||||
reboots when delivered through a real MDM channel. For local
|
||||
testing on an un-enrolled host, write the file manually as root
|
||||
and accept it will not survive the next boot.
|
||||
-->
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
|
||||
<!-- ===== Identity / auth ===== -->
|
||||
<key>managementURL</key>
|
||||
<string>https://api.netbird.io:443</string>
|
||||
|
||||
<!--
|
||||
Pre-shared key: secret. Remove the entry entirely when not used;
|
||||
do NOT leave an empty <string></string>, which the daemon would
|
||||
otherwise treat as a deliberate empty-PSK enforcement.
|
||||
-->
|
||||
<!--
|
||||
<key>preSharedKey</key>
|
||||
<string>REPLACE_ME</string>
|
||||
-->
|
||||
|
||||
<!-- ===== Engine / runtime behavior =====
|
||||
Each key is optional. Remove or comment out to leave the
|
||||
field unmanaged on the client. -->
|
||||
|
||||
<key>allowServerSSH</key>
|
||||
<true/>
|
||||
|
||||
<!--
|
||||
<key>disableAutoConnect</key>
|
||||
<false/>
|
||||
|
||||
<key>disableClientRoutes</key>
|
||||
<false/>
|
||||
|
||||
<key>disableServerRoutes</key>
|
||||
<false/>
|
||||
|
||||
<key>blockInbound</key>
|
||||
<false/>
|
||||
|
||||
<key>rosenpassEnabled</key>
|
||||
<true/>
|
||||
|
||||
<key>rosenpassPermissive</key>
|
||||
<false/>
|
||||
-->
|
||||
|
||||
<!-- ===== WireGuard UDP port =====
|
||||
Range 1-65535. Omit to keep the daemon default. -->
|
||||
<!--
|
||||
<key>wireguardPort</key>
|
||||
<integer>51820</integer>
|
||||
-->
|
||||
|
||||
<!-- ===== UI / lockdown kill switches =====
|
||||
disableUpdateSettings : block every config change from UI and CLI
|
||||
on this device (Settings view stays
|
||||
readable but read-only).
|
||||
disableProfiles : hide the profile menu, reject profile CRUD.
|
||||
disableNetworks : hide the Networks / Exit Node menus,
|
||||
reject the related RPCs.
|
||||
disableMetricsCollection: opt out of anonymous usage telemetry. -->
|
||||
<!--
|
||||
<key>disableUpdateSettings</key>
|
||||
<true/>
|
||||
|
||||
<key>disableProfiles</key>
|
||||
<true/>
|
||||
|
||||
<key>disableNetworks</key>
|
||||
<true/>
|
||||
|
||||
<key>disableMetricsCollection</key>
|
||||
<false/>
|
||||
-->
|
||||
|
||||
<!-- ===== Split tunnel =====
|
||||
Android-only at the client level. Safe to ship on macOS for
|
||||
mixed-platform fleets; the macOS daemon parses and ignores. -->
|
||||
<!--
|
||||
<key>splitTunnelMode</key>
|
||||
<string>allow</string>
|
||||
|
||||
<key>splitTunnelApps</key>
|
||||
<string>com.acme.app1,com.acme.app2</string>
|
||||
-->
|
||||
|
||||
</dict>
|
||||
</plist>
|
||||
@@ -1,159 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<!--
|
||||
NetBird MDM configuration profile (macOS).
|
||||
|
||||
Wraps a `com.apple.ManagedClient.preferences` payload that pushes the
|
||||
NetBird MDM policy into:
|
||||
/Library/Managed Preferences/io.netbird.client.plist
|
||||
|
||||
Read at runtime by the netbird daemon's macOS loader
|
||||
(client/mdm/policy_darwin.go — Phase 2). Key names match the canonical
|
||||
lowerCamelCase form used in docs/netbird.admx and the mdm.Key*
|
||||
constants in client/mdm/policy.go.
|
||||
|
||||
Bundle identifier: io.netbird.client
|
||||
(confirm against the signed pkg before fleet roll-out)
|
||||
|
||||
Distribution:
|
||||
- sign with `productsign --sign "Developer ID Installer: ..." ...`
|
||||
before fleet roll-out (Apple-Configurator-2 won't install an
|
||||
unsigned profile on Sonoma+ without user override).
|
||||
- For local dev install: `sudo profiles install -path netbird-macos.mobileconfig`.
|
||||
- For MDM (Jamf/Kandji/Mosyle/Intune): upload as a Custom Profile.
|
||||
|
||||
Editing:
|
||||
- Replace UUID placeholders below with fresh UUIDs (`uuidgen` on
|
||||
macOS) when forking this template for a real fleet — each
|
||||
deployment should have unique UUIDs so the OS treats it as a
|
||||
distinct profile.
|
||||
- Tune the PayloadContent values to the policy you want to enforce.
|
||||
- Remove any key you do NOT want to enforce (the daemon treats an
|
||||
absent key as "no enforcement" for that field).
|
||||
|
||||
iOS note:
|
||||
This file is macOS-specific. iOS uses managed app config via
|
||||
UserDefaults[com.apple.configuration.managed] under a different
|
||||
payload type (com.apple.app.configuration.managed); the wrapper
|
||||
structure is the same but the inner payload dictionary differs.
|
||||
See docs/netbird-ios.mobileconfig (Phase 5) when shipped.
|
||||
-->
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<!-- Outer profile envelope -->
|
||||
<key>PayloadType</key>
|
||||
<string>Configuration</string>
|
||||
<key>PayloadVersion</key>
|
||||
<integer>1</integer>
|
||||
<key>PayloadIdentifier</key>
|
||||
<string>io.netbird.client.mdm</string>
|
||||
<key>PayloadUUID</key>
|
||||
<string>11111111-1111-1111-1111-111111111111</string>
|
||||
<key>PayloadDisplayName</key>
|
||||
<string>NetBird MDM Policy</string>
|
||||
<key>PayloadDescription</key>
|
||||
<string>Enforces NetBird client configuration. Values written here override any local user / CLI / on-disk setting and are re-applied at every daemon boot and on every 1-minute MDM reload tick.</string>
|
||||
<key>PayloadOrganization</key>
|
||||
<string>NetBird</string>
|
||||
<key>PayloadScope</key>
|
||||
<string>System</string>
|
||||
<key>PayloadRemovalDisallowed</key>
|
||||
<false/>
|
||||
|
||||
<key>PayloadContent</key>
|
||||
<array>
|
||||
<dict>
|
||||
<!-- Managed preferences payload: writes /Library/Managed Preferences/io.netbird.client.plist -->
|
||||
<key>PayloadType</key>
|
||||
<string>com.apple.ManagedClient.preferences</string>
|
||||
<key>PayloadVersion</key>
|
||||
<integer>1</integer>
|
||||
<key>PayloadIdentifier</key>
|
||||
<string>io.netbird.client.mdm.preferences</string>
|
||||
<key>PayloadUUID</key>
|
||||
<string>22222222-2222-2222-2222-222222222222</string>
|
||||
<key>PayloadDisplayName</key>
|
||||
<string>NetBird Managed Preferences</string>
|
||||
<key>PayloadEnabled</key>
|
||||
<true/>
|
||||
|
||||
<key>PayloadContent</key>
|
||||
<dict>
|
||||
<key>io.netbird.client</key>
|
||||
<dict>
|
||||
<key>Forced</key>
|
||||
<array>
|
||||
<dict>
|
||||
<key>mcx_preference_settings</key>
|
||||
<dict>
|
||||
|
||||
<!-- ===== Identity / auth (strings) ===== -->
|
||||
<key>managementURL</key>
|
||||
<string>https://api.netbird.io:443</string>
|
||||
|
||||
<!-- Pre-shared key: secret. Remove the entry entirely
|
||||
when not used; do NOT leave an empty string. -->
|
||||
<!--
|
||||
<key>preSharedKey</key>
|
||||
<string>REPLACE_ME</string>
|
||||
-->
|
||||
|
||||
<!-- ===== Engine / runtime behavior (bool) =====
|
||||
Remove any key to leave the field unmanaged. -->
|
||||
<!--
|
||||
<key>disableAutoConnect</key>
|
||||
<false/>
|
||||
<key>disableClientRoutes</key>
|
||||
<false/>
|
||||
<key>disableServerRoutes</key>
|
||||
<false/>
|
||||
<key>blockInbound</key>
|
||||
<false/>
|
||||
-->
|
||||
<key>allowServerSSH</key>
|
||||
<true/>
|
||||
<!--
|
||||
<key>rosenpassEnabled</key>
|
||||
<true/>
|
||||
<key>rosenpassPermissive</key>
|
||||
<false/>
|
||||
-->
|
||||
|
||||
<!-- ===== WireGuard UDP port (int) =====
|
||||
Range 1-65535. Omit to keep the default. -->
|
||||
<!--
|
||||
<key>wireguardPort</key>
|
||||
<integer>51820</integer>
|
||||
-->
|
||||
|
||||
<!-- ===== Split tunnel (Android-only at the daemon level)
|
||||
Pushed harmlessly on macOS for fleets with mixed
|
||||
desktop+mobile devices; the macOS daemon ignores it. -->
|
||||
<!--
|
||||
<key>splitTunnelMode</key>
|
||||
<string>allow</string>
|
||||
<key>splitTunnelApps</key>
|
||||
<string>com.acme.app1,com.acme.app2</string>
|
||||
-->
|
||||
|
||||
<!-- ===== UI / kill switches (bool) ===== -->
|
||||
<!--
|
||||
<key>disableUpdateSettings</key>
|
||||
<true/>
|
||||
<key>disableProfiles</key>
|
||||
<true/>
|
||||
<key>disableNetworks</key>
|
||||
<true/>
|
||||
<key>disableMetricsCollection</key>
|
||||
<false/>
|
||||
-->
|
||||
|
||||
</dict>
|
||||
</dict>
|
||||
</array>
|
||||
</dict>
|
||||
</dict>
|
||||
</dict>
|
||||
</array>
|
||||
</dict>
|
||||
</plist>
|
||||
@@ -1,189 +0,0 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# SYNOPSIS
|
||||
# Push the NetBird MDM policy to a macOS device via JumpCloud Commands.
|
||||
#
|
||||
# DESCRIPTION
|
||||
# This is the macOS counterpart of docs/netbird-policy.reg.ps1.
|
||||
# It writes the values declared in the "POLICY VALUES" block below to
|
||||
# the managed-preferences plist that the NetBird daemon's
|
||||
# client/mdm/policy_darwin.go loader reads on every 1-minute MDM
|
||||
# reload tick:
|
||||
#
|
||||
# /Library/Managed Preferences/io.netbird.client.plist
|
||||
#
|
||||
# Once the plist lands, the daemon picks up the new values without
|
||||
# restart (the ticker calls Config.apply() → applyMDMPolicy() and
|
||||
# restarts the engine on diff).
|
||||
#
|
||||
# DEPLOYMENT (JumpCloud)
|
||||
# 1. Admin Console -> Device Management -> Commands -> +.
|
||||
# 2. Type: Mac, Shell, Run as: root.
|
||||
# 3. Paste this file verbatim into the command body.
|
||||
# 4. Bind to the target system group, save, run.
|
||||
#
|
||||
# IMPORTANT: PERSISTENCE
|
||||
# macOS wipes /Library/Managed Preferences/ at every boot on devices
|
||||
# that are NOT MDM-enrolled. For a persistent fleet rollout, push the
|
||||
# companion docs/netbird-macos.mobileconfig as a Custom Configuration
|
||||
# Profile (Admin Console -> MDM -> Mac Custom Configuration Profiles)
|
||||
# instead of this script. Use this script when:
|
||||
# - the device is MDM-enrolled (file survives reboots), or
|
||||
# - you need a one-shot test push before reboot, or
|
||||
# - you orchestrate via JumpCloud Commands and want the same
|
||||
# variable-driven workflow as the Windows .ps1 sibling.
|
||||
#
|
||||
# IDEMPOTENCY: re-running with the same values is a no-op from the
|
||||
# daemon's point of view (the 1-minute reload ticker diff returns empty).
|
||||
#
|
||||
# SECURITY: PreSharedKey is redacted in this script's log output.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
### POLICY VALUES — EDIT THIS BLOCK ###########################################
|
||||
#
|
||||
# Set each variable below to the desired value. Set to empty string ""
|
||||
# or to NULL to omit a key entirely (the daemon treats an absent key
|
||||
# as "no enforcement" for that field). Booleans use "true"/"false"
|
||||
# (lowercase). Integers as decimal.
|
||||
#
|
||||
# Reference for key names + accepted values:
|
||||
# client/mdm/policy.go (Key* constants)
|
||||
# docs/netbird-macos.mobileconfig (sample profile)
|
||||
# docs/netbird.admx + .adml (Windows ADMX schema)
|
||||
#
|
||||
NULL='__UNSET__'
|
||||
managementURL='https://api.netbird.io:443'
|
||||
preSharedKey="$NULL" # secret; redacted in log
|
||||
allowServerSSH='true'
|
||||
blockInbound="$NULL"
|
||||
disableAutoConnect="$NULL"
|
||||
disableClientRoutes="$NULL"
|
||||
disableServerRoutes="$NULL"
|
||||
disableMetricsCollection="$NULL"
|
||||
disableUpdateSettings="$NULL"
|
||||
disableProfiles="$NULL"
|
||||
disableNetworks="$NULL"
|
||||
rosenpassEnabled="$NULL"
|
||||
rosenpassPermissive="$NULL"
|
||||
wireguardPort='51820'
|
||||
splitTunnelMode="$NULL" # "allow" or "disallow", Android-only at the daemon level
|
||||
splitTunnelApps="$NULL" # comma-separated app IDs, Android-only
|
||||
##############################################################################
|
||||
|
||||
readonly PLIST_DIR='/Library/Managed Preferences'
|
||||
readonly PLIST_PATH="$PLIST_DIR/io.netbird.client.plist"
|
||||
readonly LOG_TAG='netbird-mdm'
|
||||
|
||||
# log sends a message to the system logger using the configured tag and echoes the message to stdout prefixed by an ISO 8601 UTC timestamp and the tag.
|
||||
log() {
|
||||
/usr/bin/logger -t "$LOG_TAG" "$*"
|
||||
printf '%s [%s] %s\n' "$(date -u '+%Y-%m-%dT%H:%M:%SZ')" "$LOG_TAG" "$*"
|
||||
}
|
||||
|
||||
# is_set returns success if the provided value is non-empty and is not equal to the special NULL marker.
|
||||
is_set() {
|
||||
local value="$1"
|
||||
[[ -n "$value" && "$value" != "$NULL" ]]
|
||||
}
|
||||
|
||||
# start_plist creates the temporary plist file at "$PLIST_PATH.tmp" containing the XML plist header and opening `<dict>` for the policy plist.
|
||||
start_plist() {
|
||||
cat > "$PLIST_PATH.tmp" <<'EOF'
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
EOF
|
||||
}
|
||||
|
||||
# end_plist appends the closing `</dict>` and `</plist>` tags to the temporary plist file.
|
||||
end_plist() {
|
||||
cat >> "$PLIST_PATH.tmp" <<'EOF'
|
||||
</dict>
|
||||
</plist>
|
||||
EOF
|
||||
}
|
||||
|
||||
# emit_string appends a plist `<key>`/`<string>` entry for the given key and value to "$PLIST_PATH.tmp", XML-escaping `&`, `<`, and `>`, and logs the assignment (masking the logged value as `********** (secret)` when the key is `preSharedKey`).
|
||||
emit_string() {
|
||||
local key="$1" value="$2" log_value="$2"
|
||||
# Escape XML entities in the value
|
||||
local escaped
|
||||
escaped="$(printf '%s' "$value" | sed -e 's/&/\&/g' -e 's/</\</g' -e 's/>/\>/g')"
|
||||
printf ' <key>%s</key>\n <string>%s</string>\n' "$key" "$escaped" >> "$PLIST_PATH.tmp"
|
||||
if [[ "$key" == "preSharedKey" ]]; then
|
||||
log_value='********** (secret)'
|
||||
fi
|
||||
log "set $key = $log_value"
|
||||
}
|
||||
|
||||
# emit_bool writes a boolean plist entry for a given key into the temporary plist file.
|
||||
# emit_bool writes a boolean plist entry for a key when the provided value matches an accepted boolean token; logs an error and skips the key on invalid input.
|
||||
emit_bool() {
|
||||
local key="$1" value="$2"
|
||||
local xml_bool
|
||||
case "$value" in
|
||||
true|True|TRUE|1|yes) xml_bool='<true/>' ; value='true' ;;
|
||||
false|False|FALSE|0|no) xml_bool='<false/>' ; value='false' ;;
|
||||
*) log "invalid boolean for $key: $value (must be true/false); skipping"; return ;;
|
||||
esac
|
||||
printf ' <key>%s</key>\n %s\n' "$key" "$xml_bool" >> "$PLIST_PATH.tmp"
|
||||
log "set $key = $value"
|
||||
}
|
||||
|
||||
# emit_int validates that VALUE contains only decimal digits and, if valid, appends an `<integer>` plist entry for KEY to the temporary plist (`$PLIST_PATH.tmp`) and logs the assignment; on invalid input it logs a skip and does not emit the key.
|
||||
emit_int() {
|
||||
local key="$1" value="$2"
|
||||
if ! [[ "$value" =~ ^[0-9]+$ ]]; then
|
||||
log "invalid integer for $key: $value (must be decimal); skipping"
|
||||
return
|
||||
fi
|
||||
printf ' <key>%s</key>\n <integer>%s</integer>\n' "$key" "$value" >> "$PLIST_PATH.tmp"
|
||||
log "set $key = $value"
|
||||
}
|
||||
|
||||
# main builds the NetBird MDM plist from configured policy variables, validates and installs it to /Library/Managed Preferences/io.netbird.client.plist (root:wheel, 644) and optionally triggers the NetBird daemon to reload.
|
||||
main() {
|
||||
log "applying NetBird MDM policy to $PLIST_PATH"
|
||||
/bin/mkdir -p "$PLIST_DIR"
|
||||
start_plist
|
||||
|
||||
is_set "$managementURL" && emit_string managementURL "$managementURL"
|
||||
is_set "$preSharedKey" && emit_string preSharedKey "$preSharedKey"
|
||||
is_set "$allowServerSSH" && emit_bool allowServerSSH "$allowServerSSH"
|
||||
is_set "$blockInbound" && emit_bool blockInbound "$blockInbound"
|
||||
is_set "$disableAutoConnect" && emit_bool disableAutoConnect "$disableAutoConnect"
|
||||
is_set "$disableClientRoutes" && emit_bool disableClientRoutes "$disableClientRoutes"
|
||||
is_set "$disableServerRoutes" && emit_bool disableServerRoutes "$disableServerRoutes"
|
||||
is_set "$disableMetricsCollection" && emit_bool disableMetricsCollection "$disableMetricsCollection"
|
||||
is_set "$disableUpdateSettings" && emit_bool disableUpdateSettings "$disableUpdateSettings"
|
||||
is_set "$disableProfiles" && emit_bool disableProfiles "$disableProfiles"
|
||||
is_set "$disableNetworks" && emit_bool disableNetworks "$disableNetworks"
|
||||
is_set "$rosenpassEnabled" && emit_bool rosenpassEnabled "$rosenpassEnabled"
|
||||
is_set "$rosenpassPermissive" && emit_bool rosenpassPermissive "$rosenpassPermissive"
|
||||
is_set "$wireguardPort" && emit_int wireguardPort "$wireguardPort"
|
||||
is_set "$splitTunnelMode" && emit_string splitTunnelMode "$splitTunnelMode"
|
||||
is_set "$splitTunnelApps" && emit_string splitTunnelApps "$splitTunnelApps"
|
||||
|
||||
end_plist
|
||||
|
||||
if ! /usr/bin/plutil -lint "$PLIST_PATH.tmp" >/dev/null 2>&1; then
|
||||
log "ERROR: generated plist failed plutil lint; not installing"
|
||||
/usr/bin/plutil -lint "$PLIST_PATH.tmp" >&2 || true
|
||||
/bin/rm -f "$PLIST_PATH.tmp"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
/bin/mv -f "$PLIST_PATH.tmp" "$PLIST_PATH"
|
||||
/usr/sbin/chown root:wheel "$PLIST_PATH"
|
||||
/bin/chmod 644 "$PLIST_PATH"
|
||||
|
||||
log "policy installed; NetBird daemon will pick it up within the next 1-minute reload tick"
|
||||
|
||||
# Optional: kick the daemon for an immediate apply. Safe — does
|
||||
# nothing on a host where NetBird is not yet installed.
|
||||
/bin/launchctl kickstart -k system/io.netbird.client 2>/dev/null || true
|
||||
}
|
||||
|
||||
main "$@"
|
||||
Binary file not shown.
@@ -1,94 +0,0 @@
|
||||
#requires -Version 5.1
|
||||
<#
|
||||
.SYNOPSIS
|
||||
Push the NetBird MDM policy to a Windows device via JumpCloud Commands
|
||||
by importing a sidecar netbird-policy.reg file.
|
||||
|
||||
.DESCRIPTION
|
||||
Windows counterpart of docs/netbird-macos.sh. Outcome:
|
||||
HKLM\Software\Policies\NetBird populated from the attached
|
||||
netbird-policy.reg file, daemon picks up the change via the
|
||||
1-minute MDM reload ticker.
|
||||
|
||||
Deployment:
|
||||
1. Admin Console -> Device Management -> Commands -> +.
|
||||
2. Type: Windows PowerShell. Run as: SYSTEM.
|
||||
3. Paste this file verbatim into the command body.
|
||||
4. In the same command, attach `netbird-policy.reg` as a file.
|
||||
JumpCloud copies attached files into the command's working
|
||||
directory before invoking the script, so `$PSScriptRoot` or
|
||||
Get-Location resolves to where the .reg lives.
|
||||
5. Bind to the target system group, save, run.
|
||||
|
||||
Producing the .reg file:
|
||||
On a reference machine, after configuring the policy values either
|
||||
via gpedit (GPO) or manual `reg add`, export with:
|
||||
|
||||
reg export "HKLM\Software\Policies\NetBird" netbird-policy.reg /y
|
||||
|
||||
Then attach the resulting file to the JumpCloud command.
|
||||
|
||||
Semantics:
|
||||
- The script nukes the existing HKLM\Software\Policies\NetBird key
|
||||
before importing the .reg, so the .reg is the SINGLE SOURCE OF
|
||||
TRUTH. Any value present in the registry but absent from the .reg
|
||||
is removed. This is what an MDM admin almost always wants.
|
||||
- Setting the .reg to an empty (header-only) file effectively unsets
|
||||
the policy.
|
||||
|
||||
Idempotency: re-running the script with the same .reg is a no-op from
|
||||
the daemon's perspective (values identical → 1-min ticker sees no
|
||||
diff → engine not restarted).
|
||||
|
||||
Exit codes: 0 = success; 1 = .reg missing or reg.exe error.
|
||||
#>
|
||||
|
||||
$ErrorActionPreference = "Stop"
|
||||
|
||||
$RegFileName = "netbird-policy.reg"
|
||||
$RegKey = "HKLM\Software\Policies\NetBird"
|
||||
|
||||
# Resolve the attached .reg file: JumpCloud copies command attachments
|
||||
# into C:\Windows\Temp\ before invoking the script. Cwd / $PSScriptRoot
|
||||
# fallbacks cover the local-dev case where you might dot-source this
|
||||
# from elsewhere.
|
||||
$candidates = @(
|
||||
(Join-Path "$env:WINDIR\Temp" $RegFileName)
|
||||
(Join-Path (Get-Location) $RegFileName)
|
||||
(Join-Path $PSScriptRoot $RegFileName)
|
||||
) | Where-Object { Test-Path $_ }
|
||||
|
||||
if ($candidates.Count -eq 0) {
|
||||
Write-Error "[netbird-mdm] $RegFileName not found in working directory or `$PSScriptRoot. Attach the file to the JumpCloud command."
|
||||
exit 1
|
||||
}
|
||||
$regFile = $candidates[0]
|
||||
Write-Host "[netbird-mdm] using $regFile"
|
||||
|
||||
# Wipe the existing policy key so the .reg is authoritative.
|
||||
$existed = Test-Path "Registry::HKEY_LOCAL_MACHINE\Software\Policies\NetBird"
|
||||
if ($existed) {
|
||||
& reg.exe delete $RegKey /f | Out-Null
|
||||
if ($LASTEXITCODE -ne 0) {
|
||||
Write-Error "[netbird-mdm] failed to clear $RegKey before import (exit $LASTEXITCODE)"
|
||||
exit 1
|
||||
}
|
||||
Write-Host "[netbird-mdm] cleared previous values under $RegKey"
|
||||
}
|
||||
|
||||
# Import. reg.exe writes both data and (re-)creates the key if needed.
|
||||
& reg.exe import $regFile
|
||||
if ($LASTEXITCODE -ne 0) {
|
||||
Write-Error "[netbird-mdm] reg import failed (exit $LASTEXITCODE)"
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Audit dump so the JumpCloud per-execution log captures the applied state.
|
||||
Write-Host "[netbird-mdm] final policy state under $RegKey :"
|
||||
& reg.exe query $RegKey /s
|
||||
|
||||
# Daemon's 1-min reload ticker picks up the change automatically.
|
||||
# Uncomment to force immediate convergence (skips the ticker wait):
|
||||
# Restart-Service netbird -Force -ErrorAction SilentlyContinue
|
||||
|
||||
exit 0
|
||||
@@ -1,95 +0,0 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<policyDefinitionResources xmlns:xsd="http://www.w3.org/2001/XMLSchema"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
revision="1.0"
|
||||
schemaVersion="1.0"
|
||||
xmlns="http://schemas.microsoft.com/GroupPolicy/2006/07/PolicyDefinitions">
|
||||
<displayName>NetBird Client Policies</displayName>
|
||||
<description>Group Policy template for NetBird client MDM-managed settings. Values are written under HKLM\Software\Policies\NetBird and consumed by the netbird daemon at startup and every 1-minute reload tick.</description>
|
||||
<resources>
|
||||
<stringTable>
|
||||
|
||||
<!-- Categories -->
|
||||
<string id="NetBird_Category">NetBird</string>
|
||||
<string id="SUPPORTED_NetBird_All">NetBird Client 0.40+</string>
|
||||
|
||||
<!-- Identity / auth -->
|
||||
<string id="ManagementURL_Name">Management URL</string>
|
||||
<string id="ManagementURL_Help">URL of the NetBird management server. Format: https://host[:port]. When set, users cannot override this value via UI or CLI.</string>
|
||||
|
||||
<string id="PreSharedKey_Name">Pre-shared key</string>
|
||||
<string id="PreSharedKey_Help">WireGuard pre-shared key used as an additional symmetric secret on every peer-to-peer tunnel. Secret value.</string>
|
||||
|
||||
<!-- Settings: engine / runtime behavior -->
|
||||
<string id="DisableAutoConnect_Name">Disable auto-connect</string>
|
||||
<string id="DisableAutoConnect_Help">When enabled, the NetBird tunnel does not auto-connect at daemon startup. Equivalent to --disable-auto-connect.</string>
|
||||
|
||||
<string id="DisableClientRoutes_Name">Disable client routes</string>
|
||||
<string id="DisableClientRoutes_Help">When enabled, this client will not consume routes advertised by routing peers. Equivalent to --disable-client-routes.</string>
|
||||
|
||||
<string id="DisableServerRoutes_Name">Disable server routes</string>
|
||||
<string id="DisableServerRoutes_Help">When enabled, this client will not act as a routing peer for other clients. Equivalent to --disable-server-routes.</string>
|
||||
|
||||
<string id="BlockInbound_Name">Block inbound</string>
|
||||
<string id="BlockInbound_Help">When enabled, the client firewall blocks all inbound peer traffic on the WireGuard interface. Equivalent to --block-inbound.</string>
|
||||
|
||||
<string id="AllowServerSSH_Name">Allow server SSH</string>
|
||||
<string id="AllowServerSSH_Help">When enabled, this client accepts incoming SSH sessions via NetBird SSH. Equivalent to --allow-server-ssh.</string>
|
||||
|
||||
<string id="RosenpassEnabled_Name">Enable Rosenpass</string>
|
||||
<string id="RosenpassEnabled_Help">Enables Rosenpass post-quantum key exchange on WireGuard tunnels. Both peers must support it.</string>
|
||||
|
||||
<string id="RosenpassPermissive_Name">Rosenpass permissive</string>
|
||||
<string id="RosenpassPermissive_Help">When enabled, the client falls back to plain WireGuard if a peer does not support Rosenpass; otherwise it refuses the connection.</string>
|
||||
|
||||
<string id="WireguardPort_Name">WireGuard port</string>
|
||||
<string id="WireguardPort_Help">UDP port used by the local WireGuard interface. Allowed range: 1-65535.</string>
|
||||
|
||||
<string id="SplitTunnel_Name">Split tunnel</string>
|
||||
<string id="SplitTunnel_Help">Restrict the NetBird tunnel to or from a chosen list of application package names. Choose either the allow mode (only the listed apps route through NetBird) or the disallow mode (the listed apps bypass NetBird; everything else routes through). The mode is mutually exclusive — only one can be active at a time. Android-only at the daemon level; Windows/macOS/iOS clients ignore this policy.</string>
|
||||
<string id="SplitTunnel_Allow">Allow only listed apps (everything else bypasses)</string>
|
||||
<string id="SplitTunnel_Disallow">Disallow listed apps (everything else routes)</string>
|
||||
|
||||
<!-- UI -->
|
||||
<string id="DisableUpdateSettings_Name">Disable update settings</string>
|
||||
<string id="DisableUpdateSettings_Help">When enabled, blocks every configuration change from the client UI and from the CLI (netbird up / login / setconfig). The Settings view stays viewable but read-only. Equivalent to --disable-update-settings.</string>
|
||||
|
||||
<string id="DisableProfiles_Name">Disable profiles</string>
|
||||
<string id="DisableProfiles_Help">When enabled, the client UI/CLI cannot list, create, switch or remove NetBird connection profiles. Equivalent to --disable-profiles.</string>
|
||||
|
||||
<string id="DisableNetworks_Name">Disable networks</string>
|
||||
<string id="DisableNetworks_Help">When enabled, the client UI/CLI cannot list, select or deselect NetBird networks (the corresponding daemon RPCs return Unavailable). Equivalent to --disable-networks.</string>
|
||||
|
||||
<string id="DisableMetricsCollection_Name">Disable metrics collection</string>
|
||||
<string id="DisableMetricsCollection_Help">When enabled, the client does not collect or report local usage metrics.</string>
|
||||
|
||||
</stringTable>
|
||||
<presentationTable>
|
||||
|
||||
<presentation id="ManagementURL_Pres">
|
||||
<textBox refId="ManagementURL_Text">
|
||||
<label>Management URL:</label>
|
||||
<defaultValue>https://api.netbird.io:443</defaultValue>
|
||||
</textBox>
|
||||
</presentation>
|
||||
|
||||
<presentation id="PreSharedKey_Pres">
|
||||
<textBox refId="PreSharedKey_Text">
|
||||
<label>Pre-shared key:</label>
|
||||
</textBox>
|
||||
</presentation>
|
||||
|
||||
<presentation id="WireguardPort_Pres">
|
||||
<decimalTextBox refId="WireguardPort_Decimal" defaultValue="51820">WireGuard UDP port:</decimalTextBox>
|
||||
</presentation>
|
||||
|
||||
<presentation id="SplitTunnel_Pres">
|
||||
<dropdownList refId="SplitTunnel_Mode" defaultItem="0">Mode:</dropdownList>
|
||||
<textBox refId="SplitTunnel_Apps">
|
||||
<label>Package names (comma-separated):</label>
|
||||
</textBox>
|
||||
</presentation>
|
||||
|
||||
</presentationTable>
|
||||
</resources>
|
||||
</policyDefinitionResources>
|
||||
@@ -1,223 +0,0 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<policyDefinitions xmlns:xsd="http://www.w3.org/2001/XMLSchema"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
revision="1.0"
|
||||
schemaVersion="1.0"
|
||||
xmlns="http://schemas.microsoft.com/GroupPolicy/2006/07/PolicyDefinitions">
|
||||
<policyNamespaces>
|
||||
<target prefix="netbird" namespace="NetBird.Policies.Client" />
|
||||
</policyNamespaces>
|
||||
<resources minRequiredRevision="1.0" />
|
||||
<supportedOn>
|
||||
<definitions>
|
||||
<definition name="SUPPORTED_NetBird_All" displayName="$(string.SUPPORTED_NetBird_All)" />
|
||||
</definitions>
|
||||
</supportedOn>
|
||||
<categories>
|
||||
<category name="NetBird" displayName="$(string.NetBird_Category)" />
|
||||
</categories>
|
||||
<policies>
|
||||
|
||||
<!-- ============================================================ -->
|
||||
<!-- TOP-LEVEL: foundational identity / authentication -->
|
||||
<!-- ============================================================ -->
|
||||
|
||||
<policy name="ManagementURL"
|
||||
class="Machine"
|
||||
displayName="$(string.ManagementURL_Name)"
|
||||
explainText="$(string.ManagementURL_Help)"
|
||||
key="Software\Policies\NetBird"
|
||||
presentation="$(presentation.ManagementURL_Pres)">
|
||||
<parentCategory ref="NetBird" />
|
||||
<supportedOn ref="SUPPORTED_NetBird_All" />
|
||||
<elements>
|
||||
<text id="ManagementURL_Text" valueName="ManagementURL" required="true" />
|
||||
</elements>
|
||||
</policy>
|
||||
|
||||
<policy name="PreSharedKey"
|
||||
class="Machine"
|
||||
displayName="$(string.PreSharedKey_Name)"
|
||||
explainText="$(string.PreSharedKey_Help)"
|
||||
key="Software\Policies\NetBird"
|
||||
presentation="$(presentation.PreSharedKey_Pres)">
|
||||
<parentCategory ref="NetBird" />
|
||||
<supportedOn ref="SUPPORTED_NetBird_All" />
|
||||
<elements>
|
||||
<text id="PreSharedKey_Text" valueName="PreSharedKey" />
|
||||
</elements>
|
||||
</policy>
|
||||
|
||||
<!-- ============================================================ -->
|
||||
<!-- SETTINGS: engine / runtime / connection behavior -->
|
||||
<!-- ============================================================ -->
|
||||
|
||||
<policy name="DisableAutoConnect"
|
||||
class="Machine"
|
||||
displayName="$(string.DisableAutoConnect_Name)"
|
||||
explainText="$(string.DisableAutoConnect_Help)"
|
||||
key="Software\Policies\NetBird"
|
||||
valueName="DisableAutoConnect">
|
||||
<parentCategory ref="NetBird" />
|
||||
<supportedOn ref="SUPPORTED_NetBird_All" />
|
||||
<enabledValue><decimal value="1" /></enabledValue>
|
||||
<disabledValue><decimal value="0" /></disabledValue>
|
||||
</policy>
|
||||
|
||||
<policy name="DisableClientRoutes"
|
||||
class="Machine"
|
||||
displayName="$(string.DisableClientRoutes_Name)"
|
||||
explainText="$(string.DisableClientRoutes_Help)"
|
||||
key="Software\Policies\NetBird"
|
||||
valueName="DisableClientRoutes">
|
||||
<parentCategory ref="NetBird" />
|
||||
<supportedOn ref="SUPPORTED_NetBird_All" />
|
||||
<enabledValue><decimal value="1" /></enabledValue>
|
||||
<disabledValue><decimal value="0" /></disabledValue>
|
||||
</policy>
|
||||
|
||||
<policy name="DisableServerRoutes"
|
||||
class="Machine"
|
||||
displayName="$(string.DisableServerRoutes_Name)"
|
||||
explainText="$(string.DisableServerRoutes_Help)"
|
||||
key="Software\Policies\NetBird"
|
||||
valueName="DisableServerRoutes">
|
||||
<parentCategory ref="NetBird" />
|
||||
<supportedOn ref="SUPPORTED_NetBird_All" />
|
||||
<enabledValue><decimal value="1" /></enabledValue>
|
||||
<disabledValue><decimal value="0" /></disabledValue>
|
||||
</policy>
|
||||
|
||||
<policy name="BlockInbound"
|
||||
class="Machine"
|
||||
displayName="$(string.BlockInbound_Name)"
|
||||
explainText="$(string.BlockInbound_Help)"
|
||||
key="Software\Policies\NetBird"
|
||||
valueName="BlockInbound">
|
||||
<parentCategory ref="NetBird" />
|
||||
<supportedOn ref="SUPPORTED_NetBird_All" />
|
||||
<enabledValue><decimal value="1" /></enabledValue>
|
||||
<disabledValue><decimal value="0" /></disabledValue>
|
||||
</policy>
|
||||
|
||||
<policy name="AllowServerSSH"
|
||||
class="Machine"
|
||||
displayName="$(string.AllowServerSSH_Name)"
|
||||
explainText="$(string.AllowServerSSH_Help)"
|
||||
key="Software\Policies\NetBird"
|
||||
valueName="AllowServerSSH">
|
||||
<parentCategory ref="NetBird" />
|
||||
<supportedOn ref="SUPPORTED_NetBird_All" />
|
||||
<enabledValue><decimal value="1" /></enabledValue>
|
||||
<disabledValue><decimal value="0" /></disabledValue>
|
||||
</policy>
|
||||
|
||||
<policy name="RosenpassEnabled"
|
||||
class="Machine"
|
||||
displayName="$(string.RosenpassEnabled_Name)"
|
||||
explainText="$(string.RosenpassEnabled_Help)"
|
||||
key="Software\Policies\NetBird"
|
||||
valueName="RosenpassEnabled">
|
||||
<parentCategory ref="NetBird" />
|
||||
<supportedOn ref="SUPPORTED_NetBird_All" />
|
||||
<enabledValue><decimal value="1" /></enabledValue>
|
||||
<disabledValue><decimal value="0" /></disabledValue>
|
||||
</policy>
|
||||
|
||||
<policy name="RosenpassPermissive"
|
||||
class="Machine"
|
||||
displayName="$(string.RosenpassPermissive_Name)"
|
||||
explainText="$(string.RosenpassPermissive_Help)"
|
||||
key="Software\Policies\NetBird"
|
||||
valueName="RosenpassPermissive">
|
||||
<parentCategory ref="NetBird" />
|
||||
<supportedOn ref="SUPPORTED_NetBird_All" />
|
||||
<enabledValue><decimal value="1" /></enabledValue>
|
||||
<disabledValue><decimal value="0" /></disabledValue>
|
||||
</policy>
|
||||
|
||||
<policy name="WireguardPort"
|
||||
class="Machine"
|
||||
displayName="$(string.WireguardPort_Name)"
|
||||
explainText="$(string.WireguardPort_Help)"
|
||||
key="Software\Policies\NetBird"
|
||||
presentation="$(presentation.WireguardPort_Pres)">
|
||||
<parentCategory ref="NetBird" />
|
||||
<supportedOn ref="SUPPORTED_NetBird_All" />
|
||||
<elements>
|
||||
<decimal id="WireguardPort_Decimal" valueName="WireguardPort"
|
||||
minValue="1" maxValue="65535" required="true" />
|
||||
</elements>
|
||||
</policy>
|
||||
|
||||
<policy name="SplitTunnel"
|
||||
class="Machine"
|
||||
displayName="$(string.SplitTunnel_Name)"
|
||||
explainText="$(string.SplitTunnel_Help)"
|
||||
key="Software\Policies\NetBird"
|
||||
presentation="$(presentation.SplitTunnel_Pres)">
|
||||
<parentCategory ref="NetBird" />
|
||||
<supportedOn ref="SUPPORTED_NetBird_All" />
|
||||
<elements>
|
||||
<enum id="SplitTunnel_Mode" valueName="SplitTunnelMode" required="true">
|
||||
<item displayName="$(string.SplitTunnel_Allow)"><value><string>allow</string></value></item>
|
||||
<item displayName="$(string.SplitTunnel_Disallow)"><value><string>disallow</string></value></item>
|
||||
</enum>
|
||||
<text id="SplitTunnel_Apps" valueName="SplitTunnelApps" required="true" />
|
||||
</elements>
|
||||
</policy>
|
||||
|
||||
<!-- ============================================================ -->
|
||||
<!-- UI: visibility / UX kill switches -->
|
||||
<!-- ============================================================ -->
|
||||
|
||||
<policy name="DisableUpdateSettings"
|
||||
class="Machine"
|
||||
displayName="$(string.DisableUpdateSettings_Name)"
|
||||
explainText="$(string.DisableUpdateSettings_Help)"
|
||||
key="Software\Policies\NetBird"
|
||||
valueName="DisableUpdateSettings">
|
||||
<parentCategory ref="NetBird" />
|
||||
<supportedOn ref="SUPPORTED_NetBird_All" />
|
||||
<enabledValue><decimal value="1" /></enabledValue>
|
||||
<disabledValue><decimal value="0" /></disabledValue>
|
||||
</policy>
|
||||
|
||||
<policy name="DisableProfiles"
|
||||
class="Machine"
|
||||
displayName="$(string.DisableProfiles_Name)"
|
||||
explainText="$(string.DisableProfiles_Help)"
|
||||
key="Software\Policies\NetBird"
|
||||
valueName="DisableProfiles">
|
||||
<parentCategory ref="NetBird" />
|
||||
<supportedOn ref="SUPPORTED_NetBird_All" />
|
||||
<enabledValue><decimal value="1" /></enabledValue>
|
||||
<disabledValue><decimal value="0" /></disabledValue>
|
||||
</policy>
|
||||
|
||||
<policy name="DisableNetworks"
|
||||
class="Machine"
|
||||
displayName="$(string.DisableNetworks_Name)"
|
||||
explainText="$(string.DisableNetworks_Help)"
|
||||
key="Software\Policies\NetBird"
|
||||
valueName="DisableNetworks">
|
||||
<parentCategory ref="NetBird" />
|
||||
<supportedOn ref="SUPPORTED_NetBird_All" />
|
||||
<enabledValue><decimal value="1" /></enabledValue>
|
||||
<disabledValue><decimal value="0" /></disabledValue>
|
||||
</policy>
|
||||
|
||||
<policy name="DisableMetricsCollection"
|
||||
class="Machine"
|
||||
displayName="$(string.DisableMetricsCollection_Name)"
|
||||
explainText="$(string.DisableMetricsCollection_Help)"
|
||||
key="Software\Policies\NetBird"
|
||||
valueName="DisableMetricsCollection">
|
||||
<parentCategory ref="NetBird" />
|
||||
<supportedOn ref="SUPPORTED_NetBird_All" />
|
||||
<enabledValue><decimal value="1" /></enabledValue>
|
||||
<disabledValue><decimal value="0" /></disabledValue>
|
||||
</policy>
|
||||
|
||||
</policies>
|
||||
</policyDefinitions>
|
||||
@@ -99,9 +99,6 @@ func addFields(entry *logrus.Entry) {
|
||||
if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok {
|
||||
entry.Data[context.AccountIDKey] = ctxAccountID
|
||||
}
|
||||
if ctxUserAgent, ok := entry.Context.Value(context.UserAgentKey).(string); ok {
|
||||
entry.Data[context.UserAgentKey] = ctxUserAgent
|
||||
}
|
||||
if ctxInitiatorID, ok := entry.Context.Value(context.UserIDKey).(string); ok {
|
||||
entry.Data[context.UserIDKey] = ctxInitiatorID
|
||||
}
|
||||
|
||||
7
go.mod
7
go.mod
@@ -2,8 +2,6 @@ module github.com/netbirdio/netbird
|
||||
|
||||
go 1.25.5
|
||||
|
||||
toolchain go1.25.11
|
||||
|
||||
require (
|
||||
cunicu.li/go-rosenpass v0.5.42
|
||||
github.com/cenkalti/backoff/v4 v4.3.0
|
||||
@@ -56,7 +54,6 @@ require (
|
||||
github.com/fsnotify/fsnotify v1.9.0
|
||||
github.com/gliderlabs/ssh v0.3.8
|
||||
github.com/go-jose/go-jose/v4 v4.1.4
|
||||
github.com/goccy/go-yaml v1.18.0
|
||||
github.com/godbus/dbus/v5 v5.1.0
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1
|
||||
github.com/golang/mock v1.6.0
|
||||
@@ -134,7 +131,6 @@ require (
|
||||
gorm.io/driver/sqlite v1.5.7
|
||||
gorm.io/gorm v1.25.12
|
||||
gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89
|
||||
howett.net/plist v1.0.1
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -215,9 +211,10 @@ require (
|
||||
github.com/go-viper/mapstructure/v2 v2.5.0 // indirect
|
||||
github.com/go-webauthn/webauthn v0.16.4 // indirect
|
||||
github.com/go-webauthn/x v0.2.3 // indirect
|
||||
github.com/goccy/go-yaml v1.18.0 // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/golang-jwt/jwt/v4 v4.5.2 // indirect
|
||||
github.com/google/btree v1.1.3 // indirect
|
||||
github.com/google/btree v1.1.2 // indirect
|
||||
github.com/google/go-querystring v1.1.0 // indirect
|
||||
github.com/google/go-tpm v0.9.8 // indirect
|
||||
github.com/google/s2a-go v0.1.9 // indirect
|
||||
|
||||
8
go.sum
8
go.sum
@@ -275,8 +275,8 @@ github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiu
|
||||
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
|
||||
github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
|
||||
github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU=
|
||||
github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
|
||||
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
@@ -380,7 +380,6 @@ github.com/jcmturner/rpc/v2 v2.0.3 h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZ
|
||||
github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc=
|
||||
github.com/jeandeaual/go-locale v0.0.0-20250612000132-0ef82f21eade h1:FmusiCI1wHw+XQbvL9M+1r/C3SPqKrmBaIOYwVfQoDE=
|
||||
github.com/jeandeaual/go-locale v0.0.0-20250612000132-0ef82f21eade/go.mod h1:ZDXo8KHryOWSIqnsb/CiDq7hQUYryCgdVnxbj8tDG7o=
|
||||
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
@@ -947,7 +946,6 @@ gopkg.in/square/go-jose.v2 v2.6.0 h1:NGk74WTnPKBNUhNzQX7PYcTLUjoq7mzKk2OKbvwk2iI
|
||||
gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
|
||||
gopkg.in/yaml.v1 v1.0.0-20140924161607-9f9df34309c0/go.mod h1:WDnlLJ4WF5VGsH/HVa3CI79GS0ol3YnhVnKP89i0kNg=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
@@ -970,7 +968,5 @@ gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU=
|
||||
gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
|
||||
gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89 h1:mGJaeA61P8dEHTqdvAgc70ZIV3QoUoJcXCRyyjO26OA=
|
||||
gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89/go.mod h1:QkHjoMIBaYtpVufgwv3keYAbln78mBoCuShZrPrer1Q=
|
||||
howett.net/plist v1.0.1 h1:37GdZ8tP09Q35o9ych3ehygcsL+HqKSwzctveSlarvM=
|
||||
howett.net/plist v1.0.1/go.mod h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g=
|
||||
rsc.io/qr v0.2.0 h1:6vBLea5/NRMVTz8V66gipeLycZMl/+UlFmk8DvqQ6WY=
|
||||
rsc.io/qr v0.2.0/go.mod h1:IF+uZjkb9fqyeF/4tlBoynqmQxUoPfWEKh921coOuXs=
|
||||
|
||||
@@ -45,7 +45,7 @@ type Controller struct {
|
||||
EphemeralPeersManager ephemeral.Manager
|
||||
|
||||
accountUpdateLocks sync.Map
|
||||
affectedPeerUpdateLocks sync.Map
|
||||
sendAccountUpdateLocks sync.Map
|
||||
updateAccountPeersBufferInterval atomic.Int64
|
||||
// dnsDomain is used for peer resolution. This is appended to the peer's name
|
||||
dnsDomain string
|
||||
@@ -64,13 +64,6 @@ type bufferUpdate struct {
|
||||
update atomic.Bool
|
||||
}
|
||||
|
||||
type bufferAffectedUpdate struct {
|
||||
sendMu sync.Mutex
|
||||
dataMu sync.Mutex
|
||||
next *time.Timer
|
||||
peerIDs map[string]struct{}
|
||||
}
|
||||
|
||||
var _ network_map.Controller = (*Controller)(nil)
|
||||
|
||||
func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller, ephemeralPeersManager ephemeral.Manager, config *config.Config) *Controller {
|
||||
@@ -208,7 +201,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
|
||||
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[p.ID]
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
if ok {
|
||||
remotePeerNetworkMap.Merge(proxyNetworkMap)
|
||||
}
|
||||
@@ -233,6 +226,44 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error {
|
||||
log.WithContext(ctx).Tracef("buffer sending update peers for account %s from %s", accountID, util.GetCallerName())
|
||||
|
||||
if c.accountManagerMetrics != nil {
|
||||
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
|
||||
}
|
||||
|
||||
bufUpd, _ := c.sendAccountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{})
|
||||
b := bufUpd.(*bufferUpdate)
|
||||
|
||||
if !b.mu.TryLock() {
|
||||
b.update.Store(true)
|
||||
return nil
|
||||
}
|
||||
|
||||
if b.next != nil {
|
||||
b.next.Stop()
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer b.mu.Unlock()
|
||||
_ = c.sendUpdateAccountPeers(ctx, accountID, reason)
|
||||
if !b.update.Load() {
|
||||
return
|
||||
}
|
||||
b.update.Store(false)
|
||||
if b.next == nil {
|
||||
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
|
||||
_ = c.sendUpdateAccountPeers(ctx, accountID, reason)
|
||||
})
|
||||
return
|
||||
}
|
||||
b.next.Reset(time.Duration(c.updateAccountPeersBufferInterval.Load()))
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdatePeers updates all peers that belong to an account.
|
||||
// Should be called when changes have to be synced to peers.
|
||||
func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error {
|
||||
@@ -242,143 +273,6 @@ func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string, r
|
||||
return c.sendUpdateAccountPeers(ctx, accountID, reason)
|
||||
}
|
||||
|
||||
// UpdateAffectedPeers updates only the specified peers that belong to an account.
|
||||
func (c *Controller) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
if len(peerIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
return c.sendUpdateForAffectedPeers(ctx, accountID, peerIDs)
|
||||
}
|
||||
|
||||
func (c *Controller) sendUpdateForAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: account %s, %d affected peers: %v (caller: %s)", accountID, len(peerIDs), peerIDs, util.GetCallerName())
|
||||
|
||||
if !c.hasConnectedPeers(peerIDs) {
|
||||
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: no connected peers among %v, skipping", peerIDs)
|
||||
return nil
|
||||
}
|
||||
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get account: %v", err)
|
||||
}
|
||||
|
||||
globalStart := time.Now()
|
||||
|
||||
peersToUpdate := c.filterConnectedAffectedPeers(account, peerIDs)
|
||||
if len(peersToUpdate) == 0 {
|
||||
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: no peers to update (affected peers not found in account or no channels)")
|
||||
return nil
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: sending network map to %d connected peers", len(peersToUpdate))
|
||||
|
||||
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get validate peers: %v", err)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
semaphore := make(chan struct{}, 10)
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
dnsCache := &cache.DNSConfigCache{}
|
||||
dnsDomain := c.GetDNSDomain(account.Settings)
|
||||
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
||||
return fmt.Errorf("failed to get proxy network maps: %v", err)
|
||||
}
|
||||
|
||||
extraSetting, err := c.settingsManager.GetExtraSettings(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get flow enabled status: %v", err)
|
||||
}
|
||||
|
||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
||||
|
||||
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
|
||||
return fmt.Errorf("failed to get account zones: %v", err)
|
||||
}
|
||||
|
||||
for _, peer := range peersToUpdate {
|
||||
wg.Add(1)
|
||||
semaphore <- struct{}{}
|
||||
go func(p *nbpeer.Peer) {
|
||||
defer wg.Done()
|
||||
defer func() { <-semaphore }()
|
||||
|
||||
start := time.Now()
|
||||
|
||||
postureChecks, err := c.getPeerPostureChecks(account, p.ID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", p.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.metrics.CountCalcPostureChecksDuration(time.Since(start))
|
||||
start = time.Now()
|
||||
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||
|
||||
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[p.ID]
|
||||
if ok {
|
||||
remotePeerNetworkMap.Merge(proxyNetworkMap)
|
||||
}
|
||||
|
||||
peerGroups := account.GetPeerGroups(p.ID)
|
||||
start = time.Now()
|
||||
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
|
||||
c.metrics.CountToSyncResponseDuration(time.Since(start))
|
||||
|
||||
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{
|
||||
Update: update,
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
})
|
||||
}(peer)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
if c.accountManagerMetrics != nil {
|
||||
c.accountManagerMetrics.CountUpdateAccountPeersDuration(time.Since(globalStart))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Controller) hasConnectedPeers(peerIDs []string) bool {
|
||||
for _, id := range peerIDs {
|
||||
if c.peersUpdateManager.HasChannel(id) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *Controller) filterConnectedAffectedPeers(account *types.Account, peerIDs []string) []*nbpeer.Peer {
|
||||
affected := make(map[string]struct{}, len(peerIDs))
|
||||
for _, id := range peerIDs {
|
||||
affected[id] = struct{}{}
|
||||
}
|
||||
|
||||
var result []*nbpeer.Peer
|
||||
for _, peer := range account.Peers {
|
||||
if _, ok := affected[peer.ID]; ok && c.peersUpdateManager.HasChannel(peer.ID) {
|
||||
result = append(result, peer)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error {
|
||||
if !c.peersUpdateManager.HasChannel(peerId) {
|
||||
return fmt.Errorf("peer %s doesn't have a channel, skipping network map update", peerId)
|
||||
@@ -487,104 +381,6 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str
|
||||
return nil
|
||||
}
|
||||
|
||||
// BufferUpdateAffectedPeers accumulates peer IDs and flushes them after the buffer interval.
|
||||
func (c *Controller) BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason) error {
|
||||
if len(peerIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if c.accountManagerMetrics != nil {
|
||||
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Tracef("buffer updating %d affected peers for account %s from %s", len(peerIDs), accountID, util.GetCallerName())
|
||||
|
||||
bufUpd, _ := c.affectedPeerUpdateLocks.LoadOrStore(accountID, &bufferAffectedUpdate{
|
||||
peerIDs: make(map[string]struct{}),
|
||||
})
|
||||
b := bufUpd.(*bufferAffectedUpdate)
|
||||
|
||||
b.addPeerIDs(peerIDs)
|
||||
|
||||
if !b.sendMu.TryLock() {
|
||||
// Another goroutine is already sending; it will pick up our IDs on its next drain.
|
||||
return nil
|
||||
}
|
||||
|
||||
b.stopTimer()
|
||||
|
||||
// The send and the debounced timer outlive the calling request, so detach from
|
||||
// its context to avoid sending with a cancelled context once the handler returns.
|
||||
bgCtx := context.WithoutCancel(ctx)
|
||||
|
||||
collected := b.drainPeerIDs()
|
||||
go func() {
|
||||
defer b.sendMu.Unlock()
|
||||
_ = c.sendUpdateForAffectedPeers(bgCtx, accountID, collected)
|
||||
|
||||
// Check if more peer IDs accumulated while we were sending.
|
||||
if !b.hasPending() {
|
||||
return
|
||||
}
|
||||
|
||||
// Schedule a debounced flush for the newly accumulated IDs.
|
||||
b.setTimer(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
|
||||
ids := b.drainPeerIDs()
|
||||
if len(ids) > 0 {
|
||||
_ = c.sendUpdateForAffectedPeers(bgCtx, accountID, ids)
|
||||
}
|
||||
})
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *bufferAffectedUpdate) addPeerIDs(ids []string) {
|
||||
b.dataMu.Lock()
|
||||
for _, id := range ids {
|
||||
b.peerIDs[id] = struct{}{}
|
||||
}
|
||||
b.dataMu.Unlock()
|
||||
}
|
||||
|
||||
func (b *bufferAffectedUpdate) drainPeerIDs() []string {
|
||||
b.dataMu.Lock()
|
||||
defer b.dataMu.Unlock()
|
||||
if len(b.peerIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
ids := make([]string, 0, len(b.peerIDs))
|
||||
for id := range b.peerIDs {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
b.peerIDs = make(map[string]struct{})
|
||||
return ids
|
||||
}
|
||||
|
||||
func (b *bufferAffectedUpdate) hasPending() bool {
|
||||
b.dataMu.Lock()
|
||||
defer b.dataMu.Unlock()
|
||||
return len(b.peerIDs) > 0
|
||||
}
|
||||
|
||||
func (b *bufferAffectedUpdate) stopTimer() {
|
||||
b.dataMu.Lock()
|
||||
defer b.dataMu.Unlock()
|
||||
if b.next != nil {
|
||||
b.next.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func (b *bufferAffectedUpdate) setTimer(d time.Duration, f func()) {
|
||||
b.dataMu.Lock()
|
||||
defer b.dataMu.Unlock()
|
||||
if b.next == nil {
|
||||
b.next = time.AfterFunc(d, f)
|
||||
return
|
||||
}
|
||||
b.next.Reset(d)
|
||||
}
|
||||
|
||||
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
if isRequiresApproval {
|
||||
network, err := c.repo.GetAccountNetwork(ctx, accountID)
|
||||
@@ -782,24 +578,21 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
|
||||
if len(affectedPeerIDs) == 0 {
|
||||
log.WithContext(ctx).Tracef("no affected peers for peer update in account %s, skipping", accountID)
|
||||
return nil
|
||||
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
err := c.bufferSendUpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationUpdate})
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err)
|
||||
}
|
||||
return c.BufferUpdateAffectedPeers(ctx, accountID, affectedPeerIDs, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationUpdate})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
|
||||
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
log.WithContext(ctx).Debugf("OnPeersAdded call to add peers: %v", peerIDs)
|
||||
if len(affectedPeerIDs) == 0 {
|
||||
log.WithContext(ctx).Tracef("no affected peers for peer add in account %s, skipping", accountID)
|
||||
return nil
|
||||
}
|
||||
return c.BufferUpdateAffectedPeers(ctx, accountID, affectedPeerIDs, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationCreate})
|
||||
return c.bufferSendUpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationCreate})
|
||||
}
|
||||
|
||||
func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
|
||||
func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
network, err := c.repo.GetAccountNetwork(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -832,11 +625,7 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI
|
||||
c.peersUpdateManager.CloseChannel(ctx, peerID)
|
||||
}
|
||||
|
||||
if len(affectedPeerIDs) == 0 {
|
||||
log.WithContext(ctx).Tracef("no affected peers for peer delete in account %s, skipping", accountID)
|
||||
return nil
|
||||
}
|
||||
return c.BufferUpdateAffectedPeers(ctx, accountID, affectedPeerIDs, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationDelete})
|
||||
return c.bufferSendUpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationDelete})
|
||||
}
|
||||
|
||||
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)
|
||||
|
||||
@@ -19,8 +19,6 @@ const (
|
||||
|
||||
type Controller interface {
|
||||
UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error
|
||||
UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error
|
||||
BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason) error
|
||||
UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error
|
||||
BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error
|
||||
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||
@@ -29,9 +27,9 @@ type Controller interface {
|
||||
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
|
||||
CountStreams() int
|
||||
|
||||
OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string, affectedPeerIDs []string) error
|
||||
OnPeersAdded(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error
|
||||
OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error
|
||||
OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string) error
|
||||
OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error
|
||||
OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error
|
||||
DisconnectPeers(ctx context.Context, accountId string, peerIDs []string)
|
||||
OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *UpdateMessage, error)
|
||||
OnPeerDisconnected(ctx context.Context, accountID string, peerID string)
|
||||
|
||||
@@ -57,20 +57,6 @@ func (mr *MockControllerMockRecorder) BufferUpdateAccountPeers(ctx, accountID, r
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAccountPeers), ctx, accountID, reason)
|
||||
}
|
||||
|
||||
// BufferUpdateAffectedPeers mocks base method.
|
||||
func (m *MockController) BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "BufferUpdateAffectedPeers", ctx, accountID, peerIDs, reason)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// BufferUpdateAffectedPeers indicates an expected call of BufferUpdateAffectedPeers.
|
||||
func (mr *MockControllerMockRecorder) BufferUpdateAffectedPeers(ctx, accountID, peerIDs, reason any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAffectedPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAffectedPeers), ctx, accountID, peerIDs, reason)
|
||||
}
|
||||
|
||||
// CountStreams mocks base method.
|
||||
func (m *MockController) CountStreams() int {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -172,45 +158,45 @@ func (mr *MockControllerMockRecorder) OnPeerDisconnected(ctx, accountID, peerID
|
||||
}
|
||||
|
||||
// OnPeersAdded mocks base method.
|
||||
func (m *MockController) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
|
||||
func (m *MockController) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "OnPeersAdded", ctx, accountID, peerIDs, affectedPeerIDs)
|
||||
ret := m.ctrl.Call(m, "OnPeersAdded", ctx, accountID, peerIDs)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// OnPeersAdded indicates an expected call of OnPeersAdded.
|
||||
func (mr *MockControllerMockRecorder) OnPeersAdded(ctx, accountID, peerIDs, affectedPeerIDs any) *gomock.Call {
|
||||
func (mr *MockControllerMockRecorder) OnPeersAdded(ctx, accountID, peerIDs any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersAdded", reflect.TypeOf((*MockController)(nil).OnPeersAdded), ctx, accountID, peerIDs, affectedPeerIDs)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersAdded", reflect.TypeOf((*MockController)(nil).OnPeersAdded), ctx, accountID, peerIDs)
|
||||
}
|
||||
|
||||
// OnPeersDeleted mocks base method.
|
||||
func (m *MockController) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
|
||||
func (m *MockController) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "OnPeersDeleted", ctx, accountID, peerIDs, affectedPeerIDs)
|
||||
ret := m.ctrl.Call(m, "OnPeersDeleted", ctx, accountID, peerIDs)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// OnPeersDeleted indicates an expected call of OnPeersDeleted.
|
||||
func (mr *MockControllerMockRecorder) OnPeersDeleted(ctx, accountID, peerIDs, affectedPeerIDs any) *gomock.Call {
|
||||
func (mr *MockControllerMockRecorder) OnPeersDeleted(ctx, accountID, peerIDs any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersDeleted", reflect.TypeOf((*MockController)(nil).OnPeersDeleted), ctx, accountID, peerIDs, affectedPeerIDs)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersDeleted", reflect.TypeOf((*MockController)(nil).OnPeersDeleted), ctx, accountID, peerIDs)
|
||||
}
|
||||
|
||||
// OnPeersUpdated mocks base method.
|
||||
func (m *MockController) OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string, affectedPeerIDs []string) error {
|
||||
func (m *MockController) OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "OnPeersUpdated", ctx, accountId, peerIDs, affectedPeerIDs)
|
||||
ret := m.ctrl.Call(m, "OnPeersUpdated", ctx, accountId, peerIDs)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// OnPeersUpdated indicates an expected call of OnPeersUpdated.
|
||||
func (mr *MockControllerMockRecorder) OnPeersUpdated(ctx, accountId, peerIDs, affectedPeerIDs any) *gomock.Call {
|
||||
func (mr *MockControllerMockRecorder) OnPeersUpdated(ctx, accountId, peerIDs any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersUpdated", reflect.TypeOf((*MockController)(nil).OnPeersUpdated), ctx, accountId, peerIDs, affectedPeerIDs)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersUpdated", reflect.TypeOf((*MockController)(nil).OnPeersUpdated), ctx, accountId, peerIDs)
|
||||
}
|
||||
|
||||
// StartWarmup mocks base method.
|
||||
@@ -264,17 +250,3 @@ func (mr *MockControllerMockRecorder) UpdateAccountPeers(ctx, accountID, reason
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockController)(nil).UpdateAccountPeers), ctx, accountID, reason)
|
||||
}
|
||||
|
||||
// UpdateAffectedPeers mocks base method.
|
||||
func (m *MockController) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateAffectedPeers", ctx, accountID, peerIDs)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpdateAffectedPeers indicates an expected call of UpdateAffectedPeers.
|
||||
func (mr *MockControllerMockRecorder) UpdateAffectedPeers(ctx, accountID, peerIDs any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAffectedPeers", reflect.TypeOf((*MockController)(nil).UpdateAffectedPeers), ctx, accountID, peerIDs)
|
||||
}
|
||||
|
||||
@@ -488,195 +488,6 @@ func TestUpdate_AllowsPortChange(t *testing.T) {
|
||||
assert.Equal(t, uint16(54321), updated.ListenPort, "explicit port change should be applied")
|
||||
}
|
||||
|
||||
func TestUpdate_PreservesPortWhenCustomPortsNotSupported(t *testing.T) {
|
||||
mgr, testStore, _ := setupL4Test(t, boolPtr(false))
|
||||
ctx := context.Background()
|
||||
|
||||
existing := seedService(t, testStore, "tcp-svc", "tcp", testCluster, testCluster, 12345)
|
||||
|
||||
updated := &rpservice.Service{
|
||||
ID: existing.ID,
|
||||
AccountID: testAccountID,
|
||||
Name: "tcp-svc-renamed",
|
||||
Mode: "tcp",
|
||||
Domain: testCluster,
|
||||
ProxyCluster: testCluster,
|
||||
ListenPort: 0,
|
||||
Enabled: true,
|
||||
Targets: []*rpservice.Target{
|
||||
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
|
||||
require.NoError(t, err, "update must not be rejected by the custom-port capability check")
|
||||
assert.Equal(t, uint16(12345), updated.ListenPort, "existing listen port should be preserved on unsupported cluster")
|
||||
}
|
||||
|
||||
func TestUpdate_PreservesPortWhenCustomPortsUnknown(t *testing.T) {
|
||||
mgr, testStore, _ := setupL4Test(t, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
existing := seedService(t, testStore, "tcp-svc", "tcp", testCluster, testCluster, 12345)
|
||||
|
||||
updated := &rpservice.Service{
|
||||
ID: existing.ID,
|
||||
AccountID: testAccountID,
|
||||
Name: "tcp-svc-renamed",
|
||||
Mode: "tcp",
|
||||
Domain: testCluster,
|
||||
ProxyCluster: testCluster,
|
||||
ListenPort: 0,
|
||||
Enabled: true,
|
||||
Targets: []*rpservice.Target{
|
||||
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
|
||||
require.NoError(t, err, "update must not be rejected when cluster capability is unknown")
|
||||
assert.Equal(t, uint16(12345), updated.ListenPort, "existing listen port should be preserved when capability is unknown")
|
||||
}
|
||||
|
||||
func TestUpdate_RejectsPortChangeWhenCustomPortsNotSupported(t *testing.T) {
|
||||
mgr, testStore, _ := setupL4Test(t, boolPtr(false))
|
||||
ctx := context.Background()
|
||||
|
||||
existing := seedService(t, testStore, "tcp-svc", "tcp", testCluster, testCluster, 12345)
|
||||
|
||||
updated := &rpservice.Service{
|
||||
ID: existing.ID,
|
||||
AccountID: testAccountID,
|
||||
Name: "tcp-svc",
|
||||
Mode: "tcp",
|
||||
Domain: testCluster,
|
||||
ProxyCluster: testCluster,
|
||||
ListenPort: 54321,
|
||||
Enabled: true,
|
||||
Targets: []*rpservice.Target{
|
||||
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
|
||||
require.Error(t, err, "explicit port change on update must be rejected on unsupported clusters")
|
||||
assert.Contains(t, err.Error(), "custom ports not supported on target cluster")
|
||||
}
|
||||
|
||||
func TestUpdate_TLSPortChangeAllowedWhenNotSupported(t *testing.T) {
|
||||
mgr, testStore, _ := setupL4Test(t, boolPtr(false))
|
||||
ctx := context.Background()
|
||||
|
||||
existing := seedService(t, testStore, "tls-svc", "tls", "app.example.com", testCluster, 443)
|
||||
|
||||
updated := &rpservice.Service{
|
||||
ID: existing.ID,
|
||||
AccountID: testAccountID,
|
||||
Name: "tls-svc",
|
||||
Mode: "tls",
|
||||
Domain: "app.example.com",
|
||||
ProxyCluster: testCluster,
|
||||
ListenPort: 9999,
|
||||
Enabled: true,
|
||||
Targets: []*rpservice.Target{
|
||||
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8443, Enabled: true},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
|
||||
require.NoError(t, err, "TLS port change uses SNI routing and is exempt from the custom-port check")
|
||||
assert.Equal(t, uint16(9999), updated.ListenPort, "TLS port change should be applied")
|
||||
}
|
||||
|
||||
func TestValidateL4PortDiffOnClusterDiff(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mode string
|
||||
customPorts *bool
|
||||
newPort uint16
|
||||
oldPort uint16
|
||||
wantErr bool
|
||||
}{
|
||||
{"tcp port change unsupported", "tcp", boolPtr(false), 54321, 12345, true},
|
||||
{"tcp port change unknown capability", "tcp", nil, 54321, 12345, true},
|
||||
{"udp port change unsupported", "udp", boolPtr(false), 54321, 12345, true},
|
||||
{"tcp first port assignment unsupported", "tcp", boolPtr(false), 54321, 0, true},
|
||||
{"tcp port change supported", "tcp", boolPtr(true), 54321, 12345, false},
|
||||
{"tcp port unchanged unsupported", "tcp", boolPtr(false), 12345, 12345, false},
|
||||
{"tcp zero port unsupported", "tcp", boolPtr(false), 0, 12345, false},
|
||||
{"tls port change unsupported", "tls", boolPtr(false), 9999, 443, false},
|
||||
{"http mode ignored", "http", boolPtr(false), 54321, 12345, false},
|
||||
{"empty mode ignored", "", boolPtr(false), 54321, 12345, false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
newSvc := &rpservice.Service{Mode: tc.mode, ListenPort: tc.newPort, ProxyCluster: testCluster}
|
||||
oldSvc := &rpservice.Service{Mode: tc.mode, ListenPort: tc.oldPort, ProxyCluster: testCluster}
|
||||
|
||||
err := validateL4PortDiffOnClusterDiff(tc.customPorts, newSvc, oldSvc)
|
||||
if tc.wantErr {
|
||||
assert.Error(t, err, "port diff should be rejected for %s", tc.name)
|
||||
} else {
|
||||
assert.NoError(t, err, "port diff should be allowed for %s", tc.name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdate_PortConflictRejected(t *testing.T) {
|
||||
mgr, testStore, _ := setupL4Test(t, boolPtr(true))
|
||||
ctx := context.Background()
|
||||
|
||||
seedService(t, testStore, "tcp-a", "tcp", "tcp-a."+testCluster, testCluster, 5432)
|
||||
svcB := seedService(t, testStore, "tcp-b", "tcp", "tcp-b."+testCluster, testCluster, 6543)
|
||||
|
||||
updated := &rpservice.Service{
|
||||
ID: svcB.ID,
|
||||
AccountID: testAccountID,
|
||||
Name: "tcp-b",
|
||||
Mode: "tcp",
|
||||
Domain: "tcp-b." + testCluster,
|
||||
ProxyCluster: testCluster,
|
||||
ListenPort: 5432,
|
||||
Enabled: true,
|
||||
Targets: []*rpservice.Target{
|
||||
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
|
||||
require.Error(t, err, "updating to a port held by another service should be rejected")
|
||||
assert.Contains(t, err.Error(), "already in use")
|
||||
}
|
||||
|
||||
func TestUpdate_AutoAssignsWhenNoPort(t *testing.T) {
|
||||
mgr, testStore, _ := setupL4Test(t, boolPtr(false))
|
||||
ctx := context.Background()
|
||||
|
||||
existing := seedService(t, testStore, "tcp-svc", "tcp", testCluster, testCluster, 0)
|
||||
|
||||
updated := &rpservice.Service{
|
||||
ID: existing.ID,
|
||||
AccountID: testAccountID,
|
||||
Name: "tcp-svc",
|
||||
Mode: "tcp",
|
||||
Domain: testCluster,
|
||||
ProxyCluster: testCluster,
|
||||
ListenPort: 0,
|
||||
Enabled: true,
|
||||
Targets: []*rpservice.Target{
|
||||
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, updated.ListenPort >= autoAssignPortMin && updated.ListenPort <= autoAssignPortMax,
|
||||
"auto-assigned port %d should be in range [%d, %d]", updated.ListenPort, autoAssignPortMin, autoAssignPortMax)
|
||||
assert.True(t, updated.PortAutoAssigned, "PortAutoAssigned should be set when update triggers auto-assignment")
|
||||
}
|
||||
|
||||
func TestCreateServiceFromPeer_TCP(t *testing.T) {
|
||||
mgr, _, _ := setupL4Test(t, boolPtr(false))
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -338,7 +338,7 @@ func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.ensureL4Port(ctx, transaction, svc, customPorts, false); err != nil {
|
||||
if err := m.ensureL4Port(ctx, transaction, svc, customPorts); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -367,11 +367,11 @@ func (m *Manager) clusterCustomPorts(ctx context.Context, svc *service.Service)
|
||||
|
||||
// ensureL4Port auto-assigns a listen port when needed and validates cluster support.
|
||||
// customPorts must be pre-computed via clusterCustomPorts before entering a transaction.
|
||||
func (m *Manager) ensureL4Port(ctx context.Context, tx store.Store, svc *service.Service, customPorts *bool, serviceUpdate bool) error {
|
||||
func (m *Manager) ensureL4Port(ctx context.Context, tx store.Store, svc *service.Service, customPorts *bool) error {
|
||||
if !service.IsL4Protocol(svc.Mode) {
|
||||
return nil
|
||||
}
|
||||
if service.IsPortBasedProtocol(svc.Mode) && svc.ListenPort > 0 && !serviceUpdate && (customPorts == nil || !*customPorts) {
|
||||
if service.IsPortBasedProtocol(svc.Mode) && svc.ListenPort > 0 && (customPorts == nil || !*customPorts) {
|
||||
if svc.Source != service.SourceEphemeral {
|
||||
return status.Errorf(status.InvalidArgument, "custom ports not supported on cluster %s", svc.ProxyCluster)
|
||||
}
|
||||
@@ -465,7 +465,7 @@ func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, pee
|
||||
return err
|
||||
}
|
||||
|
||||
if err := m.ensureL4Port(ctx, transaction, svc, customPorts, false); err != nil {
|
||||
if err := m.ensureL4Port(ctx, transaction, svc, customPorts); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -651,22 +651,12 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
|
||||
m.preserveListenPort(service, existingService)
|
||||
updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled
|
||||
|
||||
// if the service is being updated, and we decide in the future to allow mode update,
|
||||
// we should reconsider the currently assigned port if not 0 for clusters that don't support custom ports
|
||||
if err := validateL4PortDiffOnClusterDiff(customPorts, service, existingService); err != nil {
|
||||
if err := m.ensureL4Port(ctx, transaction, service, customPorts); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := m.ensureL4Port(ctx, transaction, service, customPorts, true); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// we can try carrying the previous service port into a new cluster, if this becomes a problem for multiple users,
|
||||
// we should reconsider adding another check
|
||||
if err := m.checkPortConflict(ctx, transaction, service); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := transaction.UpdateService(ctx, service); err != nil {
|
||||
return fmt.Errorf("update service: %w", err)
|
||||
}
|
||||
@@ -674,21 +664,6 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateL4PortDiffOnClusterDiff checks if custom L4 ports are configured and validates port changes across clusters.
|
||||
// It ensures no port changes if custom ports are unsupported for a given cluster and protocol mode.
|
||||
// Returns an error if validation fails, otherwise returns nil.
|
||||
func validateL4PortDiffOnClusterDiff(customPorts *bool, newSVC, oldSVC *service.Service) error {
|
||||
if !service.IsPortBasedProtocol(newSVC.Mode) || (customPorts != nil && *customPorts) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if newSVC.ListenPort != 0 && newSVC.ListenPort != oldSVC.ListenPort {
|
||||
return status.Errorf(status.InvalidArgument, "custom ports not supported on target cluster %s", newSVC.ProxyCluster)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleDomainChange validates the new domain is free inside the transaction
|
||||
// and applies the pre-resolved cluster (computed outside the tx by
|
||||
// resolveEffectiveCluster). It must NOT call clusterDeriver here: that talks
|
||||
|
||||
@@ -8,8 +8,6 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-version"
|
||||
nbversion "github.com/netbirdio/netbird/version"
|
||||
log "github.com/sirupsen/logrus"
|
||||
goproto "google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
@@ -30,23 +28,6 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/sshauth"
|
||||
)
|
||||
|
||||
const (
|
||||
// deprecatedRemotePeersVersion is the version of Netbird that introduced the NetworkMap.RemotePeers field, deprecated in favor of RemotePeers.
|
||||
deprecatedRemotePeersVersion = "0.29.3"
|
||||
)
|
||||
|
||||
// precomputedDeprecatedRemotePeersConstraint is the parsed ">= 0.29.3" constraint,
|
||||
// built once at init since the bound is a compile-time constant.
|
||||
var precomputedDeprecatedRemotePeersConstraint version.Constraints
|
||||
|
||||
func init() {
|
||||
constraint, err := version.NewConstraint(">= " + deprecatedRemotePeersVersion)
|
||||
if err != nil {
|
||||
panic("parse deprecated remote peers version constraint: " + err.Error())
|
||||
}
|
||||
precomputedDeprecatedRemotePeersConstraint = constraint
|
||||
}
|
||||
|
||||
func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
|
||||
if config == nil {
|
||||
return nil
|
||||
@@ -174,11 +155,7 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
||||
|
||||
remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
|
||||
remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName, includeIPv6)
|
||||
|
||||
if !shouldSkipSendingDeprecatedRemotePeers(peer.Meta.WtVersion) {
|
||||
response.RemotePeers = remotePeers
|
||||
}
|
||||
|
||||
response.RemotePeers = remotePeers
|
||||
response.NetworkMap.RemotePeers = remotePeers
|
||||
response.RemotePeersIsEmpty = len(remotePeers) == 0
|
||||
response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
|
||||
@@ -269,19 +246,6 @@ func buildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]m
|
||||
return hashedUsers, machineUsers
|
||||
}
|
||||
|
||||
func shouldSkipSendingDeprecatedRemotePeers(peerVersion string) bool {
|
||||
if nbversion.IsDevelopmentVersion(peerVersion) {
|
||||
return true
|
||||
}
|
||||
|
||||
peerNBVersion, err := version.NewVersion(peerVersion)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return precomputedDeprecatedRemotePeersConstraint.Check(peerNBVersion)
|
||||
}
|
||||
|
||||
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string, includeIPv6 bool) []*proto.RemotePeerConfig {
|
||||
for _, rPeer := range peers {
|
||||
allowedIPs := []string{rPeer.IP.String() + "/32"}
|
||||
@@ -399,6 +363,7 @@ func toProtocolFirewallRules(rules []*types.FirewallRule, includeIPv6, useSource
|
||||
return result
|
||||
}
|
||||
|
||||
|
||||
// populateSourcePrefixes sets SourcePrefixes on fwRule and returns any
|
||||
// additional rules needed (e.g. a v6 wildcard clone when the peer IP is unspecified).
|
||||
func populateSourcePrefixes(fwRule *proto.FirewallRule, rule *types.FirewallRule, includeIPv6 bool) []*proto.FirewallRule {
|
||||
|
||||
@@ -202,42 +202,6 @@ func TestBuildJWTConfig_Audiences(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestShouldSkipSendingDeprecatedRemotePeers covers the version gate that
|
||||
// stops populating the deprecated top-level SyncResponse.RemotePeers field for
|
||||
// peers new enough to read RemotePeers off the NetworkMap. Development builds
|
||||
// are treated as latest and skip the field. The gate otherwise fails safe: a
|
||||
// release version older than the boundary, or one that can't be parsed (empty,
|
||||
// garbage, prereleases of the boundary) still receives the deprecated field so
|
||||
// older/unknown clients keep working.
|
||||
func TestShouldSkipSendingDeprecatedRemotePeers(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
peerVersion string
|
||||
wantSkip bool
|
||||
}{
|
||||
{"exact boundary skips", "0.29.3", true},
|
||||
{"newer patch skips", "0.29.4", true},
|
||||
{"newer minor skips", "0.30.0", true},
|
||||
{"newer major skips", "1.0.0", true},
|
||||
{"v-prefixed newer skips", "v0.30.0", true},
|
||||
{"development build skips", "development", true},
|
||||
{"development build with commit skips", "development-abc123def456-dirty", true},
|
||||
{"older patch keeps field", "0.29.2", false},
|
||||
{"older minor keeps field", "0.28.0", false},
|
||||
{"prerelease of boundary keeps field", "0.29.3-SNAPSHOT", false},
|
||||
{"tagged dev prerelease keeps field", "v0.31.1-dev", false},
|
||||
{"empty version keeps field", "", false},
|
||||
{"garbage version keeps field", "not-a-version", false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := shouldSkipSendingDeprecatedRemotePeers(tc.peerVersion)
|
||||
assert.Equal(t, tc.wantSkip, got, "skip decision for peer version %q", tc.peerVersion)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncodeSessionExpiresAt pins the wire encoding the client's
|
||||
// applySessionDeadline depends on:
|
||||
//
|
||||
|
||||
@@ -666,10 +666,8 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error)
|
||||
case resp := <-conn.sendChan:
|
||||
if err := conn.sendResponse(resp); err != nil {
|
||||
errChan <- err
|
||||
log.WithContext(conn.ctx).Tracef("Failed to send response to proxy %s: %v", conn.proxyID, err)
|
||||
return
|
||||
}
|
||||
log.WithContext(conn.ctx).Tracef("Send response to proxy %s", conn.proxyID)
|
||||
case <-conn.ctx.Done():
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1890,7 +1890,7 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID
|
||||
return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
|
||||
}
|
||||
|
||||
if err := am.MarkPeerConnected(ctx, peerPubKey, realIP, accountID, syncTime.UnixNano(), netMap); err != nil {
|
||||
if err := am.MarkPeerConnected(ctx, peerPubKey, realIP, accountID, syncTime.UnixNano()); err != nil {
|
||||
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
|
||||
}
|
||||
|
||||
@@ -2573,9 +2573,7 @@ func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, us
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
changedPeerIDs := []string{peerID}
|
||||
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
|
||||
err = am.networkMapController.OnPeersUpdated(ctx, peer.AccountID, changedPeerIDs, affectedPeerIDs)
|
||||
err = am.networkMapController.OnPeersUpdated(ctx, peer.AccountID, []string{peerID})
|
||||
if err != nil {
|
||||
return fmt.Errorf("notify network map controller of peer update: %w", err)
|
||||
}
|
||||
@@ -2666,9 +2664,7 @@ func (am *DefaultAccountManager) UpdatePeerIPv6(ctx context.Context, accountID,
|
||||
}
|
||||
|
||||
if updateNetworkMap {
|
||||
changedPeerIDs := []string{peerID}
|
||||
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
|
||||
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
|
||||
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peerID}); err != nil {
|
||||
return fmt.Errorf("notify network map controller: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
@@ -62,7 +61,7 @@ type Manager interface {
|
||||
GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
|
||||
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error
|
||||
MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error
|
||||
MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error
|
||||
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
|
||||
UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
@@ -110,7 +109,7 @@ type Manager interface {
|
||||
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
|
||||
UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
|
||||
LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
|
||||
ExtendPeerSession(ctx context.Context, peerPubKey, userID string) (time.Time, error) // used by peer gRPC API for ExtendAuthSession
|
||||
ExtendPeerSession(ctx context.Context, peerPubKey, userID string) (time.Time, error) // used by peer gRPC API for ExtendAuthSession
|
||||
SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) // used by peer gRPC API
|
||||
GetExternalCacheManager() ExternalCacheManager
|
||||
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
|
||||
@@ -129,7 +128,6 @@ type Manager interface {
|
||||
GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error)
|
||||
DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error
|
||||
UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason)
|
||||
ExpandAndUpdateAffected(ctx context.Context, accountID string, snap *affectedpeers.Snapshot, change affectedpeers.Change)
|
||||
BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason)
|
||||
BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
|
||||
SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error
|
||||
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
dns "github.com/netbirdio/netbird/dns"
|
||||
service "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
activity "github.com/netbirdio/netbird/management/server/activity"
|
||||
affectedpeers "github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
idp "github.com/netbirdio/netbird/management/server/idp"
|
||||
peer "github.com/netbirdio/netbird/management/server/peer"
|
||||
posture "github.com/netbirdio/netbird/management/server/posture"
|
||||
@@ -1321,17 +1320,17 @@ func (mr *MockManagerMockRecorder) ExtendPeerSession(ctx, peerPubKey, userID int
|
||||
}
|
||||
|
||||
// MarkPeerConnected mocks base method.
|
||||
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
|
||||
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, realIP, accountID, sessionStartedAt, nmap)
|
||||
ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, realIP, accountID, sessionStartedAt)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// MarkPeerConnected indicates an expected call of MarkPeerConnected.
|
||||
func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, realIP, accountID, sessionStartedAt, nmap interface{}) *gomock.Call {
|
||||
func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, realIP, accountID, sessionStartedAt interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, realIP, accountID, sessionStartedAt, nmap)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, realIP, accountID, sessionStartedAt)
|
||||
}
|
||||
|
||||
// MarkPeerDisconnected mocks base method.
|
||||
@@ -1638,18 +1637,6 @@ func (mr *MockManagerMockRecorder) UpdateAccountPeers(ctx, accountID, reason int
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockManager)(nil).UpdateAccountPeers), ctx, accountID, reason)
|
||||
}
|
||||
|
||||
// ExpandAndUpdateAffected mocks base method.
|
||||
func (m *MockManager) ExpandAndUpdateAffected(ctx context.Context, accountID string, snap *affectedpeers.Snapshot, change affectedpeers.Change) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "ExpandAndUpdateAffected", ctx, accountID, snap, change)
|
||||
}
|
||||
|
||||
// ExpandAndUpdateAffected indicates an expected call of ExpandAndUpdateAffected.
|
||||
func (mr *MockManagerMockRecorder) ExpandAndUpdateAffected(ctx, accountID, snap, change interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExpandAndUpdateAffected", reflect.TypeOf((*MockManager)(nil).ExpandAndUpdateAffected), ctx, accountID, snap, change)
|
||||
}
|
||||
|
||||
// UpdateAccountSettings mocks base method.
|
||||
func (m *MockManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -1813,7 +1813,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano(), nil)
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano())
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
@@ -1884,7 +1884,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
// when we mark peer as connected, the peer login expiration routine should trigger
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano(), nil)
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano())
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
failed := waitTimeout(wg, time.Second)
|
||||
@@ -1912,7 +1912,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
|
||||
t.Run("disconnect peer when session token matches", func(t *testing.T) {
|
||||
streamStartTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano(), nil)
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano())
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
@@ -1933,7 +1933,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
t.Run("skip disconnect when stored session is newer (zombie stream protection)", func(t *testing.T) {
|
||||
// Newer stream wins on connect (sets SessionStartedAt = now ns).
|
||||
streamStartTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano(), nil)
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano())
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
@@ -1957,7 +1957,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
|
||||
t.Run("skip stale connect when stored session is newer (blocked goroutine protection)", func(t *testing.T) {
|
||||
node2SyncTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node2SyncTime.UnixNano(), nil)
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node2SyncTime.UnixNano())
|
||||
require.NoError(t, err, "node 2 should connect peer")
|
||||
|
||||
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
@@ -1967,7 +1967,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
"SessionStartedAt should equal node2SyncTime token")
|
||||
|
||||
node1StaleSyncTime := node2SyncTime.Add(-1 * time.Minute)
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node1StaleSyncTime.UnixNano(), nil)
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node1StaleSyncTime.UnixNano())
|
||||
require.NoError(t, err, "stale connect should not return error")
|
||||
|
||||
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
|
||||
@@ -2029,7 +2029,7 @@ func TestDefaultAccountManager_MarkPeerConnected_ConcurrentRace(t *testing.T) {
|
||||
defer done.Done()
|
||||
ready.Done()
|
||||
start.Wait()
|
||||
errs <- manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, token, nil)
|
||||
errs <- manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, token)
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -2070,7 +2070,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "unable to get the account")
|
||||
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano(), nil)
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano())
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
@@ -3282,19 +3282,6 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *update_channel.
|
||||
// when the channel delivers.
|
||||
const peerUpdateTimeout = 5 * time.Second
|
||||
|
||||
func drainPeerUpdates(ch <-chan *network_map.UpdateMessage) {
|
||||
for {
|
||||
select {
|
||||
case _, ok := <-ch:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) {
|
||||
t.Helper()
|
||||
select {
|
||||
|
||||
@@ -1,117 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// TestAffectedPeers_DependencyCoverageMatrix enumerates each network-map
|
||||
// dependency crossed with the change-type that can alter it, asserting the
|
||||
// resolver folds in exactly the peers whose map changes. A new dependency that
|
||||
// the resolver fails to walk should fail one of these rows; a new change-type
|
||||
// without a row is a coverage gap to add here.
|
||||
func TestAffectedPeers_DependencyCoverageMatrix(t *testing.T) {
|
||||
type row struct {
|
||||
name string
|
||||
build func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string)
|
||||
}
|
||||
|
||||
rows := []row{
|
||||
{
|
||||
name: "policy-groups/source-group-change refreshes source+routing, excludes unrelated",
|
||||
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
return affectedpeers.Change{ChangedGroupIDs: []string{s.sourceGroupID}},
|
||||
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "resource-routing-bridge/router-peer-change refreshes policy sources",
|
||||
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
return affectedpeers.Change{ChangedPeerIDs: []string{s.routerPeerID}},
|
||||
[]string{s.sourcePeerID}, []string{s.unrelatedPeerID}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "policy-change/explicit-policy refreshes source+routing",
|
||||
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
|
||||
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
|
||||
return affectedpeers.Change{Policies: []*types.Policy{policy}},
|
||||
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "policy-destinationresource/explicit-policy bridges to routing peer",
|
||||
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
|
||||
policy := peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID)
|
||||
return affectedpeers.Change{Policies: []*types.Policy{policy}},
|
||||
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "resource-change refreshes source+routing on its network",
|
||||
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
return affectedpeers.Change{Resources: []*resourceTypes.NetworkResource{
|
||||
{ID: s.resourceID, NetworkID: s.networkID, GroupIDs: []string{s.resourceGroupID}},
|
||||
}},
|
||||
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "network-change refreshes source+routing on that network",
|
||||
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
return affectedpeers.Change{Networks: []*networkTypes.Network{{ID: s.networkID}}},
|
||||
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "posture-check-change refreshes source+routing of gated policy",
|
||||
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
|
||||
check, err := s.manager.SavePostureChecks(ctx, s.accountID, userID, &posture.Checks{
|
||||
Name: "cov-min-version",
|
||||
Checks: posture.ChecksDefinition{NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.30.0"}},
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
|
||||
policy.SourcePostureChecks = []string{check.ID}
|
||||
_, err = s.manager.SavePolicy(ctx, s.accountID, userID, policy, true)
|
||||
require.NoError(t, err)
|
||||
return affectedpeers.Change{PostureCheckIDs: []string{check.ID}},
|
||||
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, r := range rows {
|
||||
t.Run(r.name, func(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
change, mustContain, mustExclude := r.build(t, s, ctx)
|
||||
affected := resolveAffected(t, s.manager.Store, s.accountID, change)
|
||||
|
||||
for _, id := range mustContain {
|
||||
assert.Contains(t, affected, id, "expected peer to be affected")
|
||||
}
|
||||
for _, id := range mustExclude {
|
||||
assert.NotContains(t, affected, id, "peer must not be affected")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,143 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// An update spans an old and a new state. The affected set must be the UNION of
|
||||
// peers reachable before and after the change; resolving only against the final
|
||||
// state drops peers that were reachable but no longer are. These tests pin the
|
||||
// two paths where the old state is reachable only by the changed object's
|
||||
// previous references: detaching a resource group, and re-pointing a router peer.
|
||||
|
||||
// TestAffectedPeers_E2E_UpdateResource_DetachGroup_RefreshesOldGroupSources:
|
||||
// a resource is reachable by a source group via two destination resource groups;
|
||||
// detaching one of them must still refresh that group's policy source peers, even
|
||||
// though the post-update resource no longer maps to it.
|
||||
func TestAffectedPeers_E2E_UpdateResource_DetachGroup_RefreshesOldGroupSources(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
// A second resource group + a second source group/peer that reaches the
|
||||
// resource only through that second group.
|
||||
const detachGroupID = "rs-detach-grp"
|
||||
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{ID: detachGroupID, Name: "rs-detach"}))
|
||||
|
||||
const secondSourceGroupID = "rs-source-grp-2"
|
||||
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "rs-detach-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
|
||||
require.NoError(t, err)
|
||||
secondSourcePeer := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
|
||||
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{
|
||||
ID: secondSourceGroupID, Name: "rs-source-2", Peers: []string{secondSourcePeer.ID},
|
||||
}))
|
||||
|
||||
resourcesManager, _, _ := s.managers()
|
||||
|
||||
// Attach the resource to the detach group as well: now in [resourceGroup, detachGroup].
|
||||
_, err = resourcesManager.UpdateResource(ctx, userID, &resourceTypes.NetworkResource{
|
||||
ID: s.resourceID,
|
||||
AccountID: s.accountID,
|
||||
NetworkID: s.networkID,
|
||||
Name: "rs-resource-host",
|
||||
Address: "10.20.30.0/24",
|
||||
GroupIDs: []string{s.resourceGroupID, detachGroupID},
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Policy granting the second source group access via the detach group.
|
||||
_, err = s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(secondSourceGroupID, detachGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
secondSrcCh := s.updateManager.CreateChannel(ctx, secondSourcePeer.ID)
|
||||
t.Cleanup(func() { s.updateManager.CloseChannel(ctx, secondSourcePeer.ID) })
|
||||
settleAffectedUpdates(secondSrcCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
// Detaching the resource from detachGroup removes the second source's
|
||||
// access; that source peer must be refreshed even though the post-update
|
||||
// resource no longer maps to detachGroup.
|
||||
peerShouldReceiveUpdate(t, secondSrcCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = resourcesManager.UpdateResource(ctx, userID, &resourceTypes.NetworkResource{
|
||||
ID: s.resourceID,
|
||||
AccountID: s.accountID,
|
||||
NetworkID: s.networkID,
|
||||
Name: "rs-resource-host",
|
||||
Address: "10.20.30.0/24",
|
||||
GroupIDs: []string{s.resourceGroupID}, // detached detachGroup
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: detaching a resource group did not refresh the old group's policy source peer")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAffectedPeers_E2E_UpdateRouter_RepointPeer_RefreshesOldRoutingPeer:
|
||||
// changing router.Peer within the same network must still refresh the OLD routing
|
||||
// peer, which loses its routing role.
|
||||
func TestAffectedPeers_E2E_UpdateRouter_RepointPeer_RefreshesOldRoutingPeer(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, routersManager, _ := s.managers()
|
||||
|
||||
routers, err := s.manager.Store.GetNetworkRoutersByNetID(ctx, store.LockingStrengthNone, s.accountID, s.networkID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, routers, 1)
|
||||
router := routers[0]
|
||||
oldRoutingPeer := router.Peer
|
||||
require.NotEmpty(t, oldRoutingPeer)
|
||||
|
||||
// A new peer to become the routing peer in place of the old one.
|
||||
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "rs-newrouter-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
|
||||
require.NoError(t, err)
|
||||
newRoutingPeer := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
|
||||
|
||||
oldCh := s.updateManager.CreateChannel(ctx, oldRoutingPeer)
|
||||
t.Cleanup(func() { s.updateManager.CloseChannel(ctx, oldRoutingPeer) })
|
||||
settleAffectedUpdates(oldCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
// The old routing peer stops serving the resource and must be refreshed.
|
||||
peerShouldReceiveUpdate(t, oldCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = routersManager.UpdateRouter(ctx, userID, &routerTypes.NetworkRouter{
|
||||
ID: router.ID,
|
||||
NetworkID: s.networkID,
|
||||
AccountID: s.accountID,
|
||||
Peer: newRoutingPeer.ID, // repoint within the same network
|
||||
Masquerade: true,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: re-pointing the router peer did not refresh the old routing peer")
|
||||
}
|
||||
}
|
||||
@@ -1,255 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// allPeerMaps computes the serialized per-peer network map for every peer in the
|
||||
// account, mirroring the controller's compute path so the property test compares
|
||||
// against real output.
|
||||
func allPeerMaps(t *testing.T, manager *DefaultAccountManager, accountID string) map[string]string {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
account, err := manager.Store.GetAccount(ctx, accountID)
|
||||
require.NoError(t, err)
|
||||
|
||||
account.InjectProxyPolicies(ctx)
|
||||
|
||||
validated := make(map[string]struct{}, len(account.Peers))
|
||||
for id := range account.Peers {
|
||||
validated[id] = struct{}{}
|
||||
}
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
out := make(map[string]string, len(account.Peers))
|
||||
for peerID := range account.Peers {
|
||||
nm := account.GetPeerNetworkMapFromComponents(ctx, peerID, nbdns.CustomZone{}, nil, validated, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||
// Network.Serial is an account-global counter bumped on every change; it
|
||||
// is not a per-peer dependency, so normalize it out of the comparison.
|
||||
if nm.Network != nil {
|
||||
nm.Network.Serial = 0
|
||||
}
|
||||
out[peerID] = canonicalJSON(t, nm)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// canonicalJSON marshals v and returns an order-insensitive string form: every
|
||||
// JSON array is sorted by the canonical form of its elements. The network map's
|
||||
// Peers/Routes/FirewallRules/SourceRanges slices have nondeterministic order, so
|
||||
// a raw JSON compare would report spurious changes.
|
||||
func canonicalJSON(t *testing.T, v interface{}) string {
|
||||
t.Helper()
|
||||
b, err := json.Marshal(v)
|
||||
require.NoError(t, err)
|
||||
var parsed interface{}
|
||||
require.NoError(t, json.Unmarshal(b, &parsed))
|
||||
canonicalized, err := json.Marshal(sortAny(parsed))
|
||||
require.NoError(t, err)
|
||||
return string(canonicalized)
|
||||
}
|
||||
|
||||
func sortAny(v interface{}) interface{} {
|
||||
switch val := v.(type) {
|
||||
case []interface{}:
|
||||
for i := range val {
|
||||
val[i] = sortAny(val[i])
|
||||
}
|
||||
sort.Slice(val, func(i, j int) bool {
|
||||
bi, _ := json.Marshal(val[i])
|
||||
bj, _ := json.Marshal(val[j])
|
||||
return string(bi) < string(bj)
|
||||
})
|
||||
return val
|
||||
case map[string]interface{}:
|
||||
for k := range val {
|
||||
val[k] = sortAny(val[k])
|
||||
}
|
||||
return val
|
||||
default:
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
// changedPeers returns the peer IDs whose serialized map differs between before
|
||||
// and after.
|
||||
func changedPeers(before, after map[string]string) []string {
|
||||
var changed []string
|
||||
for id, b := range before {
|
||||
a, ok := after[id]
|
||||
if !ok || a != b {
|
||||
changed = append(changed, id)
|
||||
}
|
||||
}
|
||||
for id := range after {
|
||||
if _, ok := before[id]; !ok {
|
||||
changed = append(changed, id)
|
||||
}
|
||||
}
|
||||
return changed
|
||||
}
|
||||
|
||||
// TestAffectedPeers_Property_ResolverSupersetsRealChanges builds a topology,
|
||||
// applies random changes, and asserts that the resolver's affected set is a
|
||||
// superset of the peers whose real network map actually changed. If the resolver
|
||||
// ever misses a dependency, a change will alter a peer's map without that peer
|
||||
// appearing in the affected set, failing here.
|
||||
func TestAffectedPeers_Property_ResolverSupersetsRealChanges(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
// A pre-existing peer->resource policy so the resource/router bridge is live.
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Extra peers and groups to give mutations room to move membership around.
|
||||
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "prop-key", types.SetupKeyReusable, 0, nil, 999, userID, false, false)
|
||||
require.NoError(t, err)
|
||||
extraPeers := make([]string, 0, 4)
|
||||
for i := 0; i < 4; i++ {
|
||||
p := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
|
||||
extraPeers = append(extraPeers, p.ID)
|
||||
}
|
||||
extraGroups := []string{"prop-grp-0", "prop-grp-1"}
|
||||
for _, g := range extraGroups {
|
||||
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{ID: g, Name: g}))
|
||||
}
|
||||
|
||||
rng := rand.New(rand.NewSource(1))
|
||||
allGroups := append([]string{s.sourceGroupID, s.resourceGroupID, s.routerPeerGroupID}, extraGroups...)
|
||||
allPeers := append([]string{s.sourcePeerID, s.routerPeerID, s.routerGroupPeerID, s.unrelatedPeerID}, extraPeers...)
|
||||
|
||||
for iter := 0; iter < 60; iter++ {
|
||||
change, apply := s.randomMutation(t, rng, allGroups, allPeers)
|
||||
if apply == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
before := allPeerMaps(t, s.manager, s.accountID)
|
||||
|
||||
resolvedSet := make(map[string]struct{})
|
||||
resolve := func() {
|
||||
require.NoError(t, s.manager.Store.ExecuteInTransaction(ctx, func(tx store.Store) error {
|
||||
snap, err := affectedpeers.Load(ctx, tx, s.accountID, change)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, id := range snap.Expand(ctx, s.accountID, change) {
|
||||
resolvedSet[id] = struct{}{}
|
||||
}
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
// Resolve on both sides of the mutation and union: removals are visible
|
||||
// only pre-apply (the leaving peer is still a member), additions only
|
||||
// post-apply (the joining peer is now a member). Production captures both
|
||||
// via per-path handling (e.g. UpdateGroup passes peersToRemove); the union
|
||||
// models that without coupling the test to each path's ordering.
|
||||
resolve()
|
||||
changedIDs := change.ChangedPeerIDs
|
||||
apply()
|
||||
resolve()
|
||||
|
||||
after := allPeerMaps(t, s.manager, s.accountID)
|
||||
|
||||
// The explicitly-changed peer's own map refresh is the caller's
|
||||
// responsibility (the resolver returns the peers to propagate to), so it
|
||||
// is allowed to be absent from the resolved set.
|
||||
changedExplicitly := make(map[string]struct{}, len(changedIDs))
|
||||
for _, id := range changedIDs {
|
||||
changedExplicitly[id] = struct{}{}
|
||||
}
|
||||
|
||||
for _, id := range changedPeers(before, after) {
|
||||
if _, stillExists := after[id]; !stillExists {
|
||||
continue
|
||||
}
|
||||
if _, isExplicit := changedExplicitly[id]; isExplicit {
|
||||
continue
|
||||
}
|
||||
_, ok := resolvedSet[id]
|
||||
require.Truef(t, ok,
|
||||
"iter %d: peer %s network map changed but was not in the resolver's affected set %v (change=%+v)",
|
||||
iter, id, maps.Keys(resolvedSet), change)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// randomMutation picks a random change, returns the Change to resolve and a
|
||||
// function that applies the underlying store mutation. apply is nil when the
|
||||
// drawn mutation is a no-op for the current state.
|
||||
func (s *routerScenario) randomMutation(t *testing.T, rng *rand.Rand, allGroups, allPeers []string) (affectedpeers.Change, func()) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
switch rng.Intn(3) {
|
||||
case 0:
|
||||
groupID := allGroups[rng.Intn(len(allGroups))]
|
||||
peerID := allPeers[rng.Intn(len(allPeers))]
|
||||
grp, err := s.manager.Store.GetGroupByID(ctx, store.LockingStrengthNone, s.accountID, groupID)
|
||||
require.NoError(t, err)
|
||||
if slicesContains(grp.Peers, peerID) {
|
||||
return affectedpeers.Change{}, nil
|
||||
}
|
||||
return affectedpeers.Change{ChangedGroupIDs: []string{groupID}, ChangedPeerIDs: []string{peerID}},
|
||||
func() {
|
||||
require.NoError(t, s.manager.GroupAddPeer(ctx, s.accountID, groupID, peerID))
|
||||
}
|
||||
case 1:
|
||||
groupID := allGroups[rng.Intn(len(allGroups))]
|
||||
grp, err := s.manager.Store.GetGroupByID(ctx, store.LockingStrengthNone, s.accountID, groupID)
|
||||
require.NoError(t, err)
|
||||
if len(grp.Peers) == 0 {
|
||||
return affectedpeers.Change{}, nil
|
||||
}
|
||||
peerID := grp.Peers[rng.Intn(len(grp.Peers))]
|
||||
return affectedpeers.Change{ChangedGroupIDs: []string{groupID}, ChangedPeerIDs: []string{peerID}},
|
||||
func() {
|
||||
require.NoError(t, s.manager.GroupDeletePeer(ctx, s.accountID, groupID, peerID))
|
||||
}
|
||||
default:
|
||||
src := allGroups[rng.Intn(len(allGroups))]
|
||||
dst := allGroups[rng.Intn(len(allGroups))]
|
||||
policy := &types.Policy{
|
||||
Enabled: true,
|
||||
Name: fmt.Sprintf("prop-policy-%d", rng.Int()),
|
||||
Rules: []*types.PolicyRule{{
|
||||
Enabled: true,
|
||||
Sources: []string{src},
|
||||
Destinations: []string{dst},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
}},
|
||||
}
|
||||
return affectedpeers.Change{Policies: []*types.Policy{policy}},
|
||||
func() {
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, policy, true)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func slicesContains(s []string, v string) bool {
|
||||
for _, x := range s {
|
||||
if x == v {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -1,164 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
// countingStore wraps a real store and counts the per-account collection loads
|
||||
// the resolver performs, so a test can assert each is read at most once and that
|
||||
// irrelevant collections are skipped entirely.
|
||||
type countingStore struct {
|
||||
store.Store
|
||||
mu sync.Mutex
|
||||
counts map[string]int
|
||||
}
|
||||
|
||||
func newCountingStore(s store.Store) *countingStore {
|
||||
return &countingStore{Store: s, counts: map[string]int{}}
|
||||
}
|
||||
|
||||
func (c *countingStore) bump(name string) {
|
||||
c.mu.Lock()
|
||||
c.counts[name]++
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func (c *countingStore) count(name string) int {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.counts[name]
|
||||
}
|
||||
|
||||
func (c *countingStore) total() int {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
n := 0
|
||||
for _, v := range c.counts {
|
||||
n += v
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func (c *countingStore) GetAccountPolicies(ctx context.Context, ls store.LockingStrength, accountID string) ([]*types.Policy, error) {
|
||||
c.bump("policies")
|
||||
return c.Store.GetAccountPolicies(ctx, ls, accountID)
|
||||
}
|
||||
|
||||
func (c *countingStore) GetAccountRoutes(ctx context.Context, ls store.LockingStrength, accountID string) ([]*route.Route, error) {
|
||||
c.bump("routes")
|
||||
return c.Store.GetAccountRoutes(ctx, ls, accountID)
|
||||
}
|
||||
|
||||
func (c *countingStore) GetAccountNameServerGroups(ctx context.Context, ls store.LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
|
||||
c.bump("nameservers")
|
||||
return c.Store.GetAccountNameServerGroups(ctx, ls, accountID)
|
||||
}
|
||||
|
||||
func (c *countingStore) GetAccountDNSSettings(ctx context.Context, ls store.LockingStrength, accountID string) (*types.DNSSettings, error) {
|
||||
c.bump("dnssettings")
|
||||
return c.Store.GetAccountDNSSettings(ctx, ls, accountID)
|
||||
}
|
||||
|
||||
func (c *countingStore) GetNetworkRoutersByAccountID(ctx context.Context, ls store.LockingStrength, accountID string) ([]*routerTypes.NetworkRouter, error) {
|
||||
c.bump("routers")
|
||||
return c.Store.GetNetworkRoutersByAccountID(ctx, ls, accountID)
|
||||
}
|
||||
|
||||
func (c *countingStore) GetNetworkResourcesByAccountID(ctx context.Context, ls store.LockingStrength, accountID string) ([]*resourceTypes.NetworkResource, error) {
|
||||
c.bump("resources")
|
||||
return c.Store.GetNetworkResourcesByAccountID(ctx, ls, accountID)
|
||||
}
|
||||
|
||||
func (c *countingStore) GetAccountServices(ctx context.Context, ls store.LockingStrength, accountID string) ([]*rpservice.Service, error) {
|
||||
c.bump("services")
|
||||
return c.Store.GetAccountServices(ctx, ls, accountID)
|
||||
}
|
||||
|
||||
// TestAffectedPeers_QueryCount_NoRedundantFullTableLoads asserts the resolver
|
||||
// loads each per-account collection at most once per Resolve (memoization) even
|
||||
// on a change that drives every bridge, and skips the services table when the
|
||||
// account has no embedded proxy peers.
|
||||
func TestAffectedPeers_QueryCount_NoRedundantFullTableLoads(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
cs := newCountingStore(s.manager.Store)
|
||||
|
||||
// A group change that exercises policies, routers, resources and the bridge.
|
||||
change := affectedpeers.Change{ChangedGroupIDs: []string{s.sourceGroupID}}
|
||||
snap, err := affectedpeers.Load(ctx, cs, s.accountID, change)
|
||||
require.NoError(t, err)
|
||||
affected := snap.Expand(ctx, s.accountID, change)
|
||||
assert.Contains(t, affected, s.routerPeerID, "bridge must still resolve the routing peer")
|
||||
|
||||
for _, name := range []string{"policies", "routes", "nameservers", "dnssettings", "routers", "resources"} {
|
||||
assert.LessOrEqualf(t, cs.count(name), 1,
|
||||
"%s must be loaded at most once per Resolve, got %d", name, cs.count(name))
|
||||
}
|
||||
assert.Equal(t, 0, cs.count("services"),
|
||||
"services must not be loaded when the account has no embedded proxy peers")
|
||||
}
|
||||
|
||||
// TestAffectedPeers_QueryCount_NarrowChangeSkipsLoads asserts that a change with
|
||||
// no group/peer signal touches no per-account collections beyond what its inputs
|
||||
// require.
|
||||
func TestAffectedPeers_QueryCount_NarrowChangeSkipsLoads(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
cs := newCountingStore(s.manager.Store)
|
||||
|
||||
// A bare network change drives only the router->source bridge: routers and
|
||||
// resources are needed, but routes/nameservers/dnssettings/services are not.
|
||||
_, err := affectedpeers.Load(ctx, cs, s.accountID, affectedpeers.Change{Networks: []*networkTypes.Network{{ID: s.networkID}}})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 0, cs.count("routes"), "routes must not be loaded for a network-only change")
|
||||
assert.Equal(t, 0, cs.count("nameservers"), "nameservers must not be loaded for a network-only change")
|
||||
assert.Equal(t, 0, cs.count("dnssettings"), "dnssettings must not be loaded for a network-only change")
|
||||
assert.Equal(t, 0, cs.count("services"), "services must not be loaded for a network-only change")
|
||||
}
|
||||
|
||||
// TestAffectedPeers_QueryCount_ExpandReadsNothing is the core invariant of the
|
||||
// Load/Expand split: Load (run inside the transaction) does all store reads;
|
||||
// Expand (run after commit) must touch the store ZERO times, so it never holds
|
||||
// the write lock and never reads post-commit state.
|
||||
func TestAffectedPeers_QueryCount_ExpandReadsNothing(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
change := affectedpeers.Change{ChangedGroupIDs: []string{s.sourceGroupID}}
|
||||
|
||||
cs := newCountingStore(s.manager.Store)
|
||||
snap, err := affectedpeers.Load(ctx, cs, s.accountID, change)
|
||||
require.NoError(t, err)
|
||||
require.Greater(t, cs.total(), 0, "Load must read the store")
|
||||
|
||||
// Any store access during Expand would increment the same counter. Expand
|
||||
// operates purely on the snapshot, so the count must not move.
|
||||
readsAfterLoad := cs.total()
|
||||
affected := snap.Expand(ctx, s.accountID, change)
|
||||
assert.Contains(t, affected, s.routerPeerID, "Expand must still produce the affected peers from the snapshot")
|
||||
assert.Equal(t, readsAfterLoad, cs.total(), "Expand must perform zero store reads — it operates purely on the loaded snapshot")
|
||||
}
|
||||
@@ -1,333 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
func (s *routerScenario) resolveGroupChangeAffected(ctx context.Context, changedGroupIDs []string) []string {
|
||||
change := affectedpeers.Change{ChangedGroupIDs: changedGroupIDs}
|
||||
snap, err := affectedpeers.Load(ctx, s.manager.Store, s.accountID, change)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return snap.Expand(ctx, s.accountID, change)
|
||||
}
|
||||
|
||||
func (s *routerScenario) resolvePeerChangeAffected(ctx context.Context, changedPeerIDs []string) []string {
|
||||
change := affectedpeers.Change{ChangedPeerIDs: changedPeerIDs}
|
||||
snap, err := affectedpeers.Load(ctx, s.manager.Store, s.accountID, change)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return snap.Expand(ctx, s.accountID, change)
|
||||
}
|
||||
|
||||
func TestAffectedPeers_GroupChange_SourceGroupMembership_RefreshesRoutingPeer_DirectRouter(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
affected := s.resolveGroupChangeAffected(ctx, []string{s.sourceGroupID})
|
||||
|
||||
assert.Contains(t, affected, s.sourcePeerID, "source group member must be affected")
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"changing the source group of a peer->resource policy must refresh the resource's routing peer")
|
||||
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_GroupChange_SourceGroupMembership_RefreshesRoutingPeer_RouterPeerGroups(t *testing.T) {
|
||||
s := setupRouterScenario(t, false)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
affected := s.resolveGroupChangeAffected(ctx, []string{s.sourceGroupID})
|
||||
|
||||
assert.Contains(t, affected, s.routerGroupPeerID,
|
||||
"changing the source group must refresh the routing peer defined via router.PeerGroups")
|
||||
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_GroupChange_RouterPeerGroupMembership_RefreshesPolicySources(t *testing.T) {
|
||||
s := setupRouterScenario(t, false)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
affected := s.resolveGroupChangeAffected(ctx, []string{s.routerPeerGroupID})
|
||||
|
||||
assert.Contains(t, affected, s.routerGroupPeerID, "the routing peer itself must be affected")
|
||||
assert.Contains(t, affected, s.sourcePeerID,
|
||||
"changing the router's PeerGroups must refresh the source peers of policies serving the resource")
|
||||
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_PeerChange_SourcePeer_RefreshesRoutingPeer(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
affected := s.resolvePeerChangeAffected(ctx, []string{s.sourcePeerID})
|
||||
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"a status change on a source peer must refresh the resource's routing peer that serves it")
|
||||
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_PeerChange_SourcePeer_ByDestinationResource_RefreshesRoutingPeer(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
affected := s.resolvePeerChangeAffected(ctx, []string{s.sourcePeerID})
|
||||
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"DestinationResource-targeted policy must still bridge a source-peer change to the routing peer")
|
||||
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_DeleteGroup_ResolvesAffectedPeers(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
const memberOnlyGroupID = "rs-memberonly-grp"
|
||||
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{
|
||||
ID: memberOnlyGroupID, Name: "rs-memberonly", Peers: []string{s.sourcePeerID},
|
||||
}))
|
||||
|
||||
affected := s.resolveGroupChangeAffected(ctx, []string{memberOnlyGroupID})
|
||||
assert.Empty(t, affected, "an unlinked group has no network-map impact, so no peer is affected")
|
||||
|
||||
require.NoError(t, s.manager.DeleteGroup(ctx, s.accountID, userID, memberOnlyGroupID))
|
||||
}
|
||||
|
||||
func TestAffectedPeers_GroupAddResource_RefreshesRoutingPeer(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
const extraResourceGroupID = "rs-resource-grp-extra"
|
||||
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{
|
||||
ID: extraResourceGroupID, Name: "rs-resource-extra",
|
||||
}))
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, extraResourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, s.manager.GroupAddResource(ctx, s.accountID, extraResourceGroupID, types.Resource{
|
||||
ID: s.resourceID,
|
||||
Type: types.ResourceTypeHost,
|
||||
}))
|
||||
|
||||
affected := s.resolveGroupChangeAffected(ctx, []string{extraResourceGroupID})
|
||||
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"attaching a resource to a policy destination group must refresh the resource's routing peer")
|
||||
assert.Contains(t, affected, s.sourcePeerID, "policy source peers must refresh")
|
||||
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
|
||||
}
|
||||
|
||||
func (s *routerScenario) createPostureCheckGatedPolicy(t *testing.T, ctx context.Context) string {
|
||||
t.Helper()
|
||||
|
||||
check, err := s.manager.SavePostureChecks(ctx, s.accountID, userID, &posture.Checks{
|
||||
Name: "rs-min-version",
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.30.0"},
|
||||
},
|
||||
}, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
|
||||
policy.SourcePostureChecks = []string{check.ID}
|
||||
_, err = s.manager.SavePolicy(ctx, s.accountID, userID, policy, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
return check.ID
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_SavePostureCheck_RefreshesRoutingPeer(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
checkID := s.createPostureCheckGatedPolicy(t, ctx)
|
||||
|
||||
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
|
||||
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
|
||||
unrelatedCh := s.updateManager.CreateChannel(ctx, s.unrelatedPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerPeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.unrelatedPeerID)
|
||||
})
|
||||
|
||||
settleAffectedUpdates(srcCh, routerCh, unrelatedCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, srcCh)
|
||||
peerShouldReceiveUpdate(t, routerCh)
|
||||
peerShouldNotReceiveUpdate(t, unrelatedCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err := s.manager.SavePostureChecks(ctx, s.accountID, userID, &posture.Checks{
|
||||
ID: checkID,
|
||||
Name: "rs-min-version",
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.31.0"},
|
||||
},
|
||||
}, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: editing a posture check did not refresh source + routing peers")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_UpdateResource_DestinationResourcePolicy_RefreshesSourcePeer(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
resourcesManager, _, _ := s.managers()
|
||||
|
||||
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
|
||||
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
|
||||
unrelatedCh := s.updateManager.CreateChannel(ctx, s.unrelatedPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerPeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.unrelatedPeerID)
|
||||
})
|
||||
|
||||
settleAffectedUpdates(srcCh, routerCh, unrelatedCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, srcCh)
|
||||
peerShouldReceiveUpdate(t, routerCh)
|
||||
peerShouldNotReceiveUpdate(t, unrelatedCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = resourcesManager.UpdateResource(ctx, userID, &resourceTypes.NetworkResource{
|
||||
ID: s.resourceID,
|
||||
AccountID: s.accountID,
|
||||
NetworkID: s.networkID,
|
||||
Name: "rs-resource-host",
|
||||
Address: "10.20.30.0/25",
|
||||
GroupIDs: []string{s.resourceGroupID},
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: updating a DestinationResource-targeted resource did not refresh its policy source peer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_UpdateResource_DisabledSiblingRouter_StillBridged(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
resourcesManager, routersManager, _ := s.managers()
|
||||
|
||||
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "rs-key-disabled", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
|
||||
require.NoError(t, err)
|
||||
disabledRouterPeer := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
|
||||
_, err = routersManager.CreateRouter(ctx, userID, &routerTypes.NetworkRouter{
|
||||
NetworkID: s.networkID,
|
||||
AccountID: s.accountID,
|
||||
Peer: disabledRouterPeer.ID,
|
||||
Masquerade: true,
|
||||
Metric: 9000,
|
||||
Enabled: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
disabledCh := s.updateManager.CreateChannel(ctx, disabledRouterPeer.ID)
|
||||
t.Cleanup(func() { s.updateManager.CloseChannel(ctx, disabledRouterPeer.ID) })
|
||||
|
||||
settleAffectedUpdates(disabledCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, disabledCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = resourcesManager.UpdateResource(ctx, userID, &resourceTypes.NetworkResource{
|
||||
ID: s.resourceID,
|
||||
AccountID: s.accountID,
|
||||
NetworkID: s.networkID,
|
||||
Name: "rs-resource-host",
|
||||
Address: "10.20.30.0/25",
|
||||
GroupIDs: []string{s.resourceGroupID},
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: resource update did not refresh the disabled sibling router's peer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_GroupChange_RouterInOtherNetworkNotAffected(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
second := s.addSecondTopology(t, "groupiso")
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
affected := s.resolveGroupChangeAffected(ctx, []string{s.sourceGroupID})
|
||||
|
||||
assert.Contains(t, affected, s.routerPeerID, "network A's routing peer must be affected")
|
||||
assert.NotContains(t, affected, second.routerPeerID,
|
||||
"a router in an unrelated network must not be affected by a source-group change for another resource")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_PeerChange_RouterInOtherNetworkNotAffected(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
second := s.addSecondTopology(t, "peeriso")
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
affected := s.resolvePeerChangeAffected(ctx, []string{s.sourcePeerID})
|
||||
|
||||
assert.Contains(t, affected, s.routerPeerID, "network A's routing peer must be affected")
|
||||
assert.NotContains(t, affected, second.routerPeerID,
|
||||
"a router in an unrelated network must not be affected by a source-peer change for another resource")
|
||||
}
|
||||
@@ -1,771 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/networks"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// routerScenario captures the topology from the bug report:
|
||||
//
|
||||
// network ── router (routing peer) ── resource (in resourceGroup)
|
||||
// independent peer ──(policy: source -> resource)──> resource
|
||||
//
|
||||
// The routing peer must be refreshed when a policy grants a source peer access
|
||||
// to the resource, because the network map connects the source peer to the
|
||||
// routing peer at compute time (Account.GetPoliciesForNetworkResource +
|
||||
// addNetworksRoutingPeers). The routing peer is NOT a member of the resource
|
||||
// group, so static group/peer resolution alone cannot find it.
|
||||
type routerScenario struct {
|
||||
manager *DefaultAccountManager
|
||||
updateManager *update_channel.PeersUpdateManager
|
||||
accountID string
|
||||
networkID string
|
||||
|
||||
sourcePeerID string // independent peer that the policy grants access from
|
||||
sourceGroupID string // group containing the source peer
|
||||
|
||||
routerPeerID string // peer acting as the routing peer (direct router.Peer)
|
||||
routerGroupPeerID string // peer that is a member of routerPeerGroup
|
||||
routerPeerGroupID string // group used for router.PeerGroups
|
||||
|
||||
resourceID string // network resource
|
||||
resourceGroupID string // group whose member is the resource (no peers)
|
||||
|
||||
unrelatedPeerID string // peer in no relevant entity
|
||||
}
|
||||
|
||||
// setupRouterScenario builds the topology above with the default policy removed
|
||||
// and channels NOT yet created, so callers control exactly when updates can flow.
|
||||
func setupRouterScenario(t *testing.T, directRouterPeer bool) *routerScenario {
|
||||
t.Helper()
|
||||
|
||||
manager, updateManager, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
account, err := createAccount(manager, "router_scenario", userID, "")
|
||||
require.NoError(t, err)
|
||||
accountID := account.Id
|
||||
|
||||
// Remove the default policy so AddPeer/CreateGroup don't schedule unrelated updates.
|
||||
policies, err := manager.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
for _, p := range policies {
|
||||
require.NoError(t, manager.Store.DeletePolicy(ctx, accountID, p.ID))
|
||||
}
|
||||
|
||||
setupKey, err := manager.CreateSetupKey(ctx, accountID, "rs-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
sourcePeer := addPeerToAccount(t, manager, accountID, setupKey.Key)
|
||||
routerPeer := addPeerToAccount(t, manager, accountID, setupKey.Key)
|
||||
routerGroupPeer := addPeerToAccount(t, manager, accountID, setupKey.Key)
|
||||
unrelatedPeer := addPeerToAccount(t, manager, accountID, setupKey.Key)
|
||||
|
||||
const (
|
||||
sourceGroupID = "rs-source-grp"
|
||||
routerPeerGroupID = "rs-router-grp"
|
||||
resourceGroupID = "rs-resource-grp"
|
||||
)
|
||||
|
||||
for _, g := range []*types.Group{
|
||||
{ID: sourceGroupID, Name: "rs-source", Peers: []string{sourcePeer.ID}},
|
||||
{ID: routerPeerGroupID, Name: "rs-router", Peers: []string{routerGroupPeer.ID}},
|
||||
{ID: resourceGroupID, Name: "rs-resource"}, // intentionally peerless; the resource is its only member
|
||||
} {
|
||||
require.NoError(t, manager.CreateGroup(ctx, accountID, userID, g))
|
||||
}
|
||||
|
||||
permissionsManager := permissions.NewManager(manager.Store)
|
||||
groupsManager := groups.NewManager(manager.Store, permissionsManager, manager)
|
||||
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.serviceManager)
|
||||
routersManager := routers.NewManager(manager.Store, permissionsManager, manager)
|
||||
networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager)
|
||||
|
||||
network, err := networksManager.CreateNetwork(ctx, userID, &networkTypes.Network{
|
||||
ID: "rs-network",
|
||||
AccountID: accountID,
|
||||
Name: "rs-network",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
resource, err := resourcesManager.CreateResource(ctx, userID, &resourceTypes.NetworkResource{
|
||||
AccountID: accountID,
|
||||
NetworkID: network.ID,
|
||||
Name: "rs-resource-host",
|
||||
Address: "10.20.30.0/24",
|
||||
GroupIDs: []string{resourceGroupID},
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
router := &routerTypes.NetworkRouter{
|
||||
ID: "rs-router",
|
||||
NetworkID: network.ID,
|
||||
AccountID: accountID,
|
||||
Masquerade: true,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
}
|
||||
if directRouterPeer {
|
||||
router.Peer = routerPeer.ID
|
||||
} else {
|
||||
router.PeerGroups = []string{routerPeerGroupID}
|
||||
}
|
||||
_, err = routersManager.CreateRouter(ctx, userID, router)
|
||||
require.NoError(t, err)
|
||||
|
||||
return &routerScenario{
|
||||
manager: manager,
|
||||
updateManager: updateManager,
|
||||
accountID: accountID,
|
||||
networkID: network.ID,
|
||||
sourcePeerID: sourcePeer.ID,
|
||||
sourceGroupID: sourceGroupID,
|
||||
routerPeerID: routerPeer.ID,
|
||||
routerGroupPeerID: routerGroupPeer.ID,
|
||||
routerPeerGroupID: routerPeerGroupID,
|
||||
resourceID: resource.ID,
|
||||
resourceGroupID: resourceGroupID,
|
||||
unrelatedPeerID: unrelatedPeer.ID,
|
||||
}
|
||||
}
|
||||
|
||||
// peerToResourcePolicy builds a policy granting the source group access to the
|
||||
// resource, referencing the resource by its group in the rule destination.
|
||||
func peerToResourcePolicyByGroup(sourceGroupID, resourceGroupID string) *types.Policy {
|
||||
return &types.Policy{
|
||||
Enabled: true,
|
||||
Name: "peer-to-resource-by-group",
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{sourceGroupID},
|
||||
Destinations: []string{resourceGroupID},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// peerToResourcePolicyByResource builds a policy referencing the resource
|
||||
// directly via DestinationResource rather than its group.
|
||||
func peerToResourcePolicyByResource(sourceGroupID, resourceID string) *types.Policy {
|
||||
return &types.Policy{
|
||||
Enabled: true,
|
||||
Name: "peer-to-resource-by-resource",
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{sourceGroupID},
|
||||
DestinationResource: types.Resource{ID: resourceID, Type: types.ResourceTypeHost},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// resolvePolicyAffected mirrors SavePolicy's resolution: resolve the affected
|
||||
// peers for the given policy.
|
||||
func (s *routerScenario) resolvePolicyAffected(ctx context.Context, policy *types.Policy) []string {
|
||||
change := affectedpeers.Change{Policies: []*types.Policy{policy}}
|
||||
snap, err := affectedpeers.Load(ctx, s.manager.Store, s.accountID, change)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return snap.Expand(ctx, s.accountID, change)
|
||||
}
|
||||
|
||||
func TestAffectedPeers_SourcePeer_DirectRouter(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
|
||||
affected := s.resolvePolicyAffected(ctx, policy)
|
||||
|
||||
assert.Contains(t, affected, s.sourcePeerID, "source peer must be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_RoutingPeer_DirectRouter(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
|
||||
affected := s.resolvePolicyAffected(ctx, policy)
|
||||
|
||||
// BUG: the direct routing peer serves the resource's subnet to the source
|
||||
// peer, so it must be refreshed when the policy is created. The policy path
|
||||
// only resolves the literal rule groups (source group + resource group);
|
||||
// the resource group has no peer members and the router peer is reachable
|
||||
// only through the network, so it is dropped.
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"routing peer (router.Peer) serving the resource must be affected by a policy granting access to it")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_RoutingPeer_RouterPeerGroups(t *testing.T) {
|
||||
s := setupRouterScenario(t, false)
|
||||
ctx := context.Background()
|
||||
|
||||
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
|
||||
affected := s.resolvePolicyAffected(ctx, policy)
|
||||
|
||||
// Router defined via PeerGroups instead of a direct peer.
|
||||
assert.Contains(t, affected, s.routerGroupPeerID,
|
||||
"routing peer (router.PeerGroups member) serving the resource must be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_DestResource_RoutingPeer_DirectRouter(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
policy := peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID)
|
||||
affected := s.resolvePolicyAffected(ctx, policy)
|
||||
|
||||
// When the resource is referenced via DestinationResource, RuleGroups()
|
||||
// returns only the source group and the resource ID is not a peer, so
|
||||
// collectPolicyAffectedGroupsAndPeers yields nothing for the destination at
|
||||
// all. The routing peer is dropped here too.
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"routing peer must be affected when the resource is referenced via DestinationResource")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_DestResource_RoutingPeer_RouterPeerGroups(t *testing.T) {
|
||||
s := setupRouterScenario(t, false)
|
||||
ctx := context.Background()
|
||||
|
||||
policy := peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID)
|
||||
affected := s.resolvePolicyAffected(ctx, policy)
|
||||
|
||||
assert.Contains(t, affected, s.routerGroupPeerID,
|
||||
"routing peer (PeerGroups) must be affected when the resource is referenced via DestinationResource")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_SourceResourcePeer_RoutingPeer(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
// Source expressed as a direct peer (SourceResource), destination as resource group.
|
||||
policy := &types.Policy{
|
||||
Enabled: true,
|
||||
Name: "sourceResource-peer-to-resource",
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
SourceResource: types.Resource{ID: s.sourcePeerID, Type: types.ResourceTypePeer},
|
||||
Destinations: []string{s.resourceGroupID},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
affected := s.resolvePolicyAffected(ctx, policy)
|
||||
|
||||
// The direct source peer IS picked up (collectPolicyAffectedGroupsAndPeers
|
||||
// handles SourceResource peers), but the routing peer is still missing.
|
||||
assert.Contains(t, affected, s.sourcePeerID, "direct source peer must be affected")
|
||||
assert.Contains(t, affected, s.routerPeerID, "routing peer must be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_PolicyToResource_UnrelatedPeerNotAffected(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
|
||||
affected := s.resolvePolicyAffected(ctx, policy)
|
||||
|
||||
// Guard against an over-broad fix: a peer in no relevant entity must never
|
||||
// be pulled in.
|
||||
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_ResourceSideBridgesToRoutingPeer_DirectRouter(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
// A pre-existing policy grants the source group access to the resource.
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Drive an update through the resource manager and assert the routing peer
|
||||
// is among the affected set by observing the channel. This path walks
|
||||
// policies whose destinations reference the resource's groups, folds in the
|
||||
// source groups, and loads the network's routers, so it reaches both the
|
||||
// source peer and the routing peer.
|
||||
permissionsManager := permissions.NewManager(s.manager.Store)
|
||||
groupsManager := groups.NewManager(s.manager.Store, permissionsManager, s.manager)
|
||||
rm := resources.NewManager(s.manager.Store, permissionsManager, groupsManager, s.manager, s.manager.serviceManager)
|
||||
|
||||
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
|
||||
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerPeerID)
|
||||
})
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, srcCh)
|
||||
peerShouldReceiveUpdate(t, routerCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err = rm.UpdateResource(ctx, userID, &resourceTypes.NetworkResource{
|
||||
ID: s.resourceID,
|
||||
AccountID: s.accountID,
|
||||
NetworkID: s.networkID,
|
||||
Name: "rs-resource-host",
|
||||
Address: "10.20.30.0/24",
|
||||
GroupIDs: []string{s.resourceGroupID},
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: resource update did not refresh source peer + routing peer")
|
||||
}
|
||||
}
|
||||
|
||||
// settleAffectedUpdates waits for in-flight async updates to arrive, then drains
|
||||
// every given channel so subsequent assertions start from a clean slate.
|
||||
//
|
||||
// Setup (CreateNetwork/CreateResource/CreateRouter) fires async UpdateAffectedPeers
|
||||
// goroutines; draining first means the assertion only observes updates from the
|
||||
// action under test, not setup stragglers.
|
||||
func settleAffectedUpdates(chans ...<-chan *network_map.UpdateMessage) {
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
for _, ch := range chans {
|
||||
drainPeerUpdates(ch)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_CreatePolicy_RoutingPeer_DirectRouter(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
|
||||
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
|
||||
unrelatedCh := s.updateManager.CreateChannel(ctx, s.unrelatedPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerPeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.unrelatedPeerID)
|
||||
})
|
||||
|
||||
settleAffectedUpdates(srcCh, routerCh, unrelatedCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, srcCh)
|
||||
peerShouldReceiveUpdate(t, routerCh)
|
||||
peerShouldNotReceiveUpdate(t, unrelatedCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: creating peer->resource policy did not refresh the routing peer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_CreatePolicy_RoutingPeer_RouterPeerGroups(t *testing.T) {
|
||||
s := setupRouterScenario(t, false)
|
||||
ctx := context.Background()
|
||||
|
||||
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
|
||||
routerCh := s.updateManager.CreateChannel(ctx, s.routerGroupPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerGroupPeerID)
|
||||
})
|
||||
|
||||
settleAffectedUpdates(srcCh, routerCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, srcCh)
|
||||
peerShouldReceiveUpdate(t, routerCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: routing peer (PeerGroups) not refreshed on policy create")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_DestResource_RoutingPeer(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
|
||||
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerPeerID)
|
||||
})
|
||||
|
||||
settleAffectedUpdates(srcCh, routerCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, srcCh)
|
||||
peerShouldReceiveUpdate(t, routerCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: routing peer not refreshed when policy targets DestinationResource")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_DeletePolicy_RoutingPeer(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
policy, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
|
||||
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerPeerID)
|
||||
})
|
||||
|
||||
settleAffectedUpdates(srcCh, routerCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, srcCh)
|
||||
peerShouldReceiveUpdate(t, routerCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
require.NoError(t, s.manager.DeletePolicy(ctx, s.accountID, policy.ID, userID))
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: deleting peer->resource policy did not refresh the routing peer")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *routerScenario) managers() (resources.Manager, routers.Manager, networks.Manager) {
|
||||
permissionsManager := permissions.NewManager(s.manager.Store)
|
||||
groupsManager := groups.NewManager(s.manager.Store, permissionsManager, s.manager)
|
||||
resourcesManager := resources.NewManager(s.manager.Store, permissionsManager, groupsManager, s.manager, s.manager.serviceManager)
|
||||
routersManager := routers.NewManager(s.manager.Store, permissionsManager, s.manager)
|
||||
networksManager := networks.NewManager(s.manager.Store, permissionsManager, resourcesManager, routersManager, s.manager)
|
||||
return resourcesManager, routersManager, networksManager
|
||||
}
|
||||
|
||||
type secondTopology struct {
|
||||
networkID string
|
||||
resourceID string
|
||||
resourceGroupID string
|
||||
routerPeerID string
|
||||
}
|
||||
|
||||
func (s *routerScenario) addSecondTopology(t *testing.T, suffix string) secondTopology {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
resourcesManager, routersManager, networksManager := s.managers()
|
||||
|
||||
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "rs-key-"+suffix, types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
|
||||
require.NoError(t, err)
|
||||
routerPeer := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
|
||||
|
||||
resourceGroupID := "rs-resource-grp-" + suffix
|
||||
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{
|
||||
ID: resourceGroupID, Name: "rs-resource-" + suffix,
|
||||
}))
|
||||
|
||||
network, err := networksManager.CreateNetwork(ctx, userID, &networkTypes.Network{
|
||||
ID: "rs-network-" + suffix,
|
||||
AccountID: s.accountID,
|
||||
Name: "rs-network-" + suffix,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
resource, err := resourcesManager.CreateResource(ctx, userID, &resourceTypes.NetworkResource{
|
||||
AccountID: s.accountID,
|
||||
NetworkID: network.ID,
|
||||
Name: "rs-resource-host-" + suffix,
|
||||
Address: "10.40.50.0/24",
|
||||
GroupIDs: []string{resourceGroupID},
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = routersManager.CreateRouter(ctx, userID, &routerTypes.NetworkRouter{
|
||||
NetworkID: network.ID,
|
||||
AccountID: s.accountID,
|
||||
Peer: routerPeer.ID,
|
||||
Masquerade: true,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return secondTopology{
|
||||
networkID: network.ID,
|
||||
resourceID: resource.ID,
|
||||
resourceGroupID: resourceGroupID,
|
||||
routerPeerID: routerPeer.ID,
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_UpdatePolicy_BothRoutingPeers(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
second := s.addSecondTopology(t, "b")
|
||||
ctx := context.Background()
|
||||
|
||||
policy, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
|
||||
routerACh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
|
||||
routerBCh := s.updateManager.CreateChannel(ctx, second.routerPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerPeerID)
|
||||
s.updateManager.CloseChannel(ctx, second.routerPeerID)
|
||||
})
|
||||
|
||||
settleAffectedUpdates(srcCh, routerACh, routerBCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, srcCh)
|
||||
peerShouldReceiveUpdate(t, routerACh)
|
||||
peerShouldReceiveUpdate(t, routerBCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
policy.Rules[0].Destinations = []string{second.resourceGroupID}
|
||||
_, err = s.manager.SavePolicy(ctx, s.accountID, userID, policy, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: re-pointing the policy destination did not refresh both routing peers")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_UpdatePolicy_AddSource(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
const secondSourceGroupID = "rs-source-grp-2"
|
||||
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "rs-key-2", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
|
||||
require.NoError(t, err)
|
||||
secondSourcePeer := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
|
||||
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{
|
||||
ID: secondSourceGroupID, Name: "rs-source-2", Peers: []string{secondSourcePeer.ID},
|
||||
}))
|
||||
|
||||
policy, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
newSrcCh := s.updateManager.CreateChannel(ctx, secondSourcePeer.ID)
|
||||
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, secondSourcePeer.ID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerPeerID)
|
||||
})
|
||||
|
||||
settleAffectedUpdates(newSrcCh, routerCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, newSrcCh)
|
||||
peerShouldReceiveUpdate(t, routerCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
policy.Rules[0].Sources = []string{s.sourceGroupID, secondSourceGroupID}
|
||||
_, err = s.manager.SavePolicy(ctx, s.accountID, userID, policy, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: adding a source group did not refresh the new source peer + routing peer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_E2E_DestResource_RouterPeerGroups(t *testing.T) {
|
||||
s := setupRouterScenario(t, false)
|
||||
ctx := context.Background()
|
||||
|
||||
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
|
||||
routerCh := s.updateManager.CreateChannel(ctx, s.routerGroupPeerID)
|
||||
t.Cleanup(func() {
|
||||
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
|
||||
s.updateManager.CloseChannel(ctx, s.routerGroupPeerID)
|
||||
})
|
||||
|
||||
settleAffectedUpdates(srcCh, routerCh)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
peerShouldReceiveUpdate(t, srcCh)
|
||||
peerShouldReceiveUpdate(t, routerCh)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID), true)
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(peerUpdateTimeout):
|
||||
t.Error("timeout: DestinationResource policy with PeerGroups router did not refresh the routing peer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectedPeers_AllRoutingPeers_Network(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
_, routersManager, _ := s.managers()
|
||||
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "rs-key-r2", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
|
||||
require.NoError(t, err)
|
||||
secondRouterPeer := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
|
||||
_, err = routersManager.CreateRouter(ctx, userID, &routerTypes.NetworkRouter{
|
||||
NetworkID: s.networkID,
|
||||
AccountID: s.accountID,
|
||||
Peer: secondRouterPeer.ID,
|
||||
Masquerade: true,
|
||||
Metric: 9998,
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
affected := s.resolvePolicyAffected(ctx, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID))
|
||||
|
||||
assert.Contains(t, affected, s.routerPeerID, "first routing peer must be affected")
|
||||
assert.Contains(t, affected, secondRouterPeer.ID, "second routing peer on the same network must also be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_DisabledRouter(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
routers, err := s.manager.Store.GetNetworkRoutersByNetID(ctx, store.LockingStrengthNone, s.accountID, s.networkID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, routers, 1)
|
||||
routers[0].Enabled = false
|
||||
require.NoError(t, s.manager.Store.UpdateNetworkRouter(ctx, routers[0]))
|
||||
|
||||
affected := s.resolvePolicyAffected(ctx, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID))
|
||||
|
||||
assert.Contains(t, affected, s.sourcePeerID, "source peer must be affected")
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"disabled router's peer must still be affected: Enabled must not gate affected-peers")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_DisabledResource(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
res, err := s.manager.Store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, s.accountID, s.resourceID)
|
||||
require.NoError(t, err)
|
||||
res.Enabled = false
|
||||
require.NoError(t, s.manager.Store.SaveNetworkResource(ctx, res))
|
||||
|
||||
affected := s.resolvePolicyAffected(ctx, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID))
|
||||
|
||||
assert.Contains(t, affected, s.sourcePeerID, "source peer must be affected")
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"disabled resource must still resolve the routing peer: Enabled must not gate affected-peers")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_DisabledRule(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
ctx := context.Background()
|
||||
|
||||
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
|
||||
policy.Rules[0].Enabled = false
|
||||
|
||||
affected := s.resolvePolicyAffected(ctx, policy)
|
||||
|
||||
assert.Contains(t, affected, s.routerPeerID,
|
||||
"disabled rule must still resolve the routing peer: Enabled must not gate affected-peers")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_MultiRule(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
second := s.addSecondTopology(t, "c")
|
||||
ctx := context.Background()
|
||||
|
||||
policy := &types.Policy{
|
||||
Enabled: true,
|
||||
Name: "multi-rule-two-resources",
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{s.sourceGroupID},
|
||||
Destinations: []string{s.resourceGroupID},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{s.sourceGroupID},
|
||||
Destinations: []string{second.resourceGroupID},
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
affected := s.resolvePolicyAffected(ctx, policy)
|
||||
|
||||
assert.Contains(t, affected, s.routerPeerID, "routing peer for resource A must be affected")
|
||||
assert.Contains(t, affected, second.routerPeerID, "routing peer for resource B must be affected")
|
||||
}
|
||||
|
||||
func TestAffectedPeers_RouterOtherNetwork(t *testing.T) {
|
||||
s := setupRouterScenario(t, true)
|
||||
second := s.addSecondTopology(t, "d")
|
||||
ctx := context.Background()
|
||||
|
||||
affected := s.resolvePolicyAffected(ctx, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID))
|
||||
|
||||
assert.Contains(t, affected, s.routerPeerID, "network A's routing peer must be affected")
|
||||
assert.NotContains(t, affected, second.routerPeerID,
|
||||
"a router in an unrelated network must not be affected by a policy that does not target its resource")
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,825 +0,0 @@
|
||||
// Package affectedpeers computes which peers' network maps a change touches, so
|
||||
// only those peers are refreshed instead of the whole account.
|
||||
//
|
||||
// Two phases keep the dependency walk off the write transaction:
|
||||
// - Load: reads the needed collections. Call INSIDE the mutating tx (consistent,
|
||||
// and before a delete/removal severs the old state).
|
||||
// - Snapshot.Expand: in-memory walk, no store access. Run AFTER the tx commits.
|
||||
//
|
||||
// Enabled is never consulted: toggling it is itself an observable change.
|
||||
package affectedpeers
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
// Snapshot is an in-memory view of the collections needed to expand a Change.
|
||||
// Loaded in-tx, walked by Expand after commit. Only the collections the Change
|
||||
// can touch are loaded; the rest stay nil (see Load).
|
||||
type Snapshot struct {
|
||||
policies []*types.Policy
|
||||
routes []*route.Route
|
||||
nsGroups []*nbdns.NameServerGroup
|
||||
dnsSettings *types.DNSSettings
|
||||
routers []*routerTypes.NetworkRouter
|
||||
resources []*resourceTypes.NetworkResource
|
||||
services []*rpservice.Service
|
||||
proxyByCluster map[string][]string
|
||||
groups map[string]*types.Group
|
||||
groupPeers map[string]map[string]struct{} // groupID -> member peer IDs
|
||||
}
|
||||
|
||||
// Load reads the collections a Change requires, inside the caller's tx. It mirrors
|
||||
// Expand's walker preconditions, loading only what the change can touch.
|
||||
func Load(ctx context.Context, s store.Store, accountID string, c Change) (*Snapshot, error) {
|
||||
snap := &Snapshot{}
|
||||
if c.isEmpty() {
|
||||
return snap, nil
|
||||
}
|
||||
|
||||
if err := snap.loadCollections(ctx, s, accountID, c); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := snap.loadGroupIndex(ctx, s, accountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return snap, nil
|
||||
}
|
||||
|
||||
// loadCollections reads the policy/route/nameserver/dns/router/resource/proxy
|
||||
// collections a Change can touch, gated to what the walk needs.
|
||||
func (snap *Snapshot) loadCollections(ctx context.Context, s store.Store, accountID string, c Change) error {
|
||||
hasGroupOrPeerChange := len(c.ChangedGroupIDs) > 0 || len(c.ChangedPeerIDs) > 0 || len(c.Resources) > 0
|
||||
hasNetworkObject := len(c.Routers) > 0 || len(c.Resources) > 0 || len(c.Networks) > 0
|
||||
// the resource<->router bridge can fire for any of these
|
||||
needsRoutersResources := hasGroupOrPeerChange || len(c.PostureCheckIDs) > 0 || len(c.Policies) > 0 || hasNetworkObject
|
||||
|
||||
if needsRoutersResources {
|
||||
if err := snap.loadPolicyRoutersResources(ctx, s, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if hasGroupOrPeerChange {
|
||||
if err := snap.loadRoutesAndProxy(ctx, s, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if len(c.ChangedGroupIDs) > 0 || len(c.ChangedPeerIDs) > 0 {
|
||||
if err := snap.loadDNS(ctx, s, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadPolicyRoutersResources loads the policies plus the routers and resources
|
||||
// the resource<->router bridge walks.
|
||||
func (snap *Snapshot) loadPolicyRoutersResources(ctx context.Context, s store.Store, accountID string) error {
|
||||
var err error
|
||||
if snap.policies, err = s.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
if snap.routers, err = s.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
snap.resources, err = s.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthNone, accountID)
|
||||
return err
|
||||
}
|
||||
|
||||
// loadRoutesAndProxy loads the routes and the embedded-proxy services index.
|
||||
func (snap *Snapshot) loadRoutesAndProxy(ctx context.Context, s store.Store, accountID string) error {
|
||||
var err error
|
||||
if snap.routes, err = s.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
return snap.loadProxyServices(ctx, s, accountID)
|
||||
}
|
||||
|
||||
// loadDNS loads the nameserver groups and account DNS settings.
|
||||
func (snap *Snapshot) loadDNS(ctx context.Context, s store.Store, accountID string) error {
|
||||
var err error
|
||||
if snap.nsGroups, err = s.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
snap.dnsSettings, err = s.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
return err
|
||||
}
|
||||
|
||||
// loadProxyServices loads the embedded-proxy cluster index, and the services only
|
||||
// when the account actually has embedded proxy peers.
|
||||
func (snap *Snapshot) loadProxyServices(ctx context.Context, s store.Store, accountID string) error {
|
||||
var err error
|
||||
if snap.proxyByCluster, err = s.GetEmbeddedProxyPeerIDsByCluster(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(snap.proxyByCluster) == 0 {
|
||||
return nil
|
||||
}
|
||||
snap.services, err = s.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||
return err
|
||||
}
|
||||
|
||||
// loadGroupIndex loads all groups (for group.Resources) and builds the
|
||||
// group->member-peers index. Always needed: the bridge resolves group.Resources
|
||||
// and Expand maps groups to member peers.
|
||||
func (snap *Snapshot) loadGroupIndex(ctx context.Context, s store.Store, accountID string) error {
|
||||
groups, err := s.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
snap.groups = make(map[string]*types.Group, len(groups))
|
||||
snap.groupPeers = make(map[string]map[string]struct{}, len(groups))
|
||||
for _, g := range groups {
|
||||
snap.groups[g.ID] = g
|
||||
members := make(map[string]struct{}, len(g.Peers))
|
||||
for _, pID := range g.Peers {
|
||||
members[pID] = struct{}{}
|
||||
}
|
||||
snap.groupPeers[g.ID] = members
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Change describes what changed in an account.
|
||||
type Change struct {
|
||||
ChangedGroupIDs []string
|
||||
ChangedPeerIDs []string
|
||||
Policies []*types.Policy
|
||||
Routes []*route.Route
|
||||
Routers []*routerTypes.NetworkRouter
|
||||
Resources []*resourceTypes.NetworkResource
|
||||
Networks []*networkTypes.Network
|
||||
PostureCheckIDs []string
|
||||
|
||||
// DistributionGroupIDs are groups whose members are directly affected, with no
|
||||
// dependency walk — the change distributes config to the groups' member peers
|
||||
// only (nameserver groups, DNS DisabledManagementGroups), not through the
|
||||
// policy/route reachability graph. Pass old∪new so both states refresh.
|
||||
DistributionGroupIDs []string
|
||||
|
||||
// RemovedPeersByGroup: peers that left a group, keyed by that group. They are no
|
||||
// longer in the group's member index but still lose its reachability, so they are
|
||||
// folded in — but only when the group is linked (an unlinked group has no map
|
||||
// impact), matching how current members are handled.
|
||||
RemovedPeersByGroup map[string][]string
|
||||
}
|
||||
|
||||
func (c Change) isEmpty() bool {
|
||||
return len(c.ChangedGroupIDs) == 0 &&
|
||||
len(c.ChangedPeerIDs) == 0 &&
|
||||
len(c.Policies) == 0 &&
|
||||
len(c.Routes) == 0 &&
|
||||
len(c.Routers) == 0 &&
|
||||
len(c.Resources) == 0 &&
|
||||
len(c.Networks) == 0 &&
|
||||
len(c.PostureCheckIDs) == 0 &&
|
||||
len(c.DistributionGroupIDs) == 0 &&
|
||||
len(c.RemovedPeersByGroup) == 0
|
||||
}
|
||||
|
||||
// Expand returns the deduplicated affected peer IDs from the preloaded Snapshot,
|
||||
// no store access. Run after the producing tx commits. Logs the full walk at
|
||||
// trace level for diagnosing a miscalculation.
|
||||
func (snap *Snapshot) Expand(ctx context.Context, accountID string, c Change) []string {
|
||||
if c.isEmpty() {
|
||||
return nil
|
||||
}
|
||||
r := newResolver(ctx, snap, accountID, c)
|
||||
log.WithContext(ctx).Tracef("affectedpeers expand start: account=%s changedGroups=%v changedPeers=%v policies=%d routes=%d routers=%d resources=%d networks=%d postureChecks=%v distributionGroups=%v",
|
||||
accountID, c.ChangedGroupIDs, c.ChangedPeerIDs, len(c.Policies), len(c.Routes), len(c.Routers), len(c.Resources), len(c.Networks), c.PostureCheckIDs, c.DistributionGroupIDs)
|
||||
r.walk()
|
||||
return r.expand()
|
||||
}
|
||||
|
||||
// Collect returns the affected group and direct-peer IDs without expanding groups
|
||||
// to members. Test-only introspection; use Resolve otherwise.
|
||||
func Collect(ctx context.Context, s store.Store, accountID string, c Change) (groupIDs []string, directPeerIDs []string) {
|
||||
if c.isEmpty() {
|
||||
return nil, nil
|
||||
}
|
||||
snap, err := Load(ctx, s, accountID, c)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to load snapshot for affected peers collect: %v", err)
|
||||
return nil, nil
|
||||
}
|
||||
r := newResolver(ctx, snap, accountID, c)
|
||||
r.walk()
|
||||
return setToSlice(r.groupSet), setToSlice(r.peerSet)
|
||||
}
|
||||
|
||||
func newResolver(ctx context.Context, snap *Snapshot, accountID string, c Change) *resolver {
|
||||
r := &resolver{
|
||||
ctx: ctx,
|
||||
snap: snap,
|
||||
accountID: accountID,
|
||||
change: c,
|
||||
changedGroupSet: toSet(c.ChangedGroupIDs),
|
||||
changedPeerSet: toSet(c.ChangedPeerIDs),
|
||||
groupSet: make(map[string]struct{}),
|
||||
peerSet: make(map[string]struct{}),
|
||||
networkIDs: make(map[string]struct{}),
|
||||
}
|
||||
// Resolve each changed peer to its groups here so callers pass only ChangedPeerIDs.
|
||||
r.seedChangedGroupsFromPeers()
|
||||
r.matchedPolicies = append(r.matchedPolicies, c.Policies...)
|
||||
return r
|
||||
}
|
||||
|
||||
// seedChangedGroupsFromPeers adds each changed peer's groups to changedGroupSet so
|
||||
// the group-driven walkers fire for memberships, not just direct peer references.
|
||||
func (r *resolver) seedChangedGroupsFromPeers() {
|
||||
if len(r.changedPeerSet) == 0 {
|
||||
return
|
||||
}
|
||||
for groupID, members := range r.snap.groupPeers {
|
||||
for pID := range r.changedPeerSet {
|
||||
if _, ok := members[pID]; ok {
|
||||
r.changedGroupSet[groupID] = struct{}{}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) walk() {
|
||||
r.collectFromExplicitPolicies()
|
||||
r.collectFromExplicitRoutes(r.change.Routes)
|
||||
r.collectFromExplicitRouters(r.change.Routers)
|
||||
r.collectFromExplicitResources(r.change.Resources)
|
||||
r.collectFromExplicitNetworks(r.change.Networks)
|
||||
r.collectFromPostureChecks(r.change.PostureCheckIDs)
|
||||
|
||||
// Distribution groups (nameserver/DNS) affect only their member peers: fold them
|
||||
// straight into groupSet so expand() maps them to members, without the policy/
|
||||
// route walk that changedGroupSet would trigger.
|
||||
addAll(r.groupSet, r.change.DistributionGroupIDs)
|
||||
|
||||
if len(r.changedGroupSet) > 0 || len(r.changedPeerSet) > 0 {
|
||||
r.collectFromPolicies()
|
||||
r.collectFromRoutes()
|
||||
r.collectFromNameServers()
|
||||
r.collectFromDNSSettings()
|
||||
r.collectFromNetworkRouters()
|
||||
r.collectFromProxyServices()
|
||||
}
|
||||
|
||||
r.collectResourceRouterBridge()
|
||||
}
|
||||
|
||||
type resolver struct {
|
||||
ctx context.Context
|
||||
snap *Snapshot
|
||||
accountID string
|
||||
change Change
|
||||
|
||||
changedGroupSet map[string]struct{}
|
||||
changedPeerSet map[string]struct{}
|
||||
|
||||
groupSet map[string]struct{}
|
||||
peerSet map[string]struct{}
|
||||
|
||||
matchedPolicies []*types.Policy
|
||||
networkIDs map[string]struct{}
|
||||
}
|
||||
|
||||
func (r *resolver) policies() []*types.Policy { return r.snap.policies }
|
||||
|
||||
func (r *resolver) networkResources() []*resourceTypes.NetworkResource { return r.snap.resources }
|
||||
|
||||
func (r *resolver) networkRouters() []*routerTypes.NetworkRouter { return r.snap.routers }
|
||||
|
||||
// peerIDsForGroups maps a group set to its member peer IDs via the preloaded index.
|
||||
func (r *resolver) peerIDsForGroups(groupSet map[string]struct{}) []string {
|
||||
seen := make(map[string]struct{})
|
||||
var ids []string
|
||||
for gID := range groupSet {
|
||||
for pID := range r.snap.groupPeers[gID] {
|
||||
if _, ok := seen[pID]; ok {
|
||||
continue
|
||||
}
|
||||
seen[pID] = struct{}{}
|
||||
ids = append(ids, pID)
|
||||
}
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
func (r *resolver) expand() []string {
|
||||
peerIDs := r.peerIDsForGroups(r.groupSet)
|
||||
|
||||
log.WithContext(r.ctx).Tracef("affectedpeers expand: account=%s affectedGroups=%v -> %d group-member peers; direct peers=%v",
|
||||
r.accountID, setToSlice(r.groupSet), len(peerIDs), setToSlice(r.peerSet))
|
||||
|
||||
seen := make(map[string]struct{}, len(peerIDs))
|
||||
for _, id := range peerIDs {
|
||||
seen[id] = struct{}{}
|
||||
}
|
||||
for id := range r.peerSet {
|
||||
if _, ok := seen[id]; !ok {
|
||||
peerIDs = append(peerIDs, id)
|
||||
seen[id] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// Fold in removed peers only when their group is linked (in groupSet).
|
||||
for groupID, removed := range r.change.RemovedPeersByGroup {
|
||||
if _, linked := r.groupSet[groupID]; !linked {
|
||||
continue
|
||||
}
|
||||
for _, id := range removed {
|
||||
if _, ok := seen[id]; !ok {
|
||||
peerIDs = append(peerIDs, id)
|
||||
seen[id] = struct{}{}
|
||||
log.WithContext(r.ctx).Tracef("affectedpeers expand: removed peer %s from linked group %s -> affected", id, groupID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.WithContext(r.ctx).Tracef("affectedpeers expand done: account=%s -> %d affected peers: %v", r.accountID, len(peerIDs), peerIDs)
|
||||
return peerIDs
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromExplicitPolicies() {
|
||||
for _, policy := range r.matchedPolicies {
|
||||
if policy == nil {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromExplicitPolicies: changed policy %s (%s) -> folding rule groups %v + direct peers",
|
||||
policy.ID, policy.Name, policy.RuleGroups())
|
||||
addAll(r.groupSet, policy.RuleGroups())
|
||||
collectPolicyDirectPeers(policy, r.peerSet)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromExplicitRoutes(routes []*route.Route) {
|
||||
for _, rt := range routes {
|
||||
if rt == nil {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromExplicitRoutes: changed route %s -> folding groups=%v peerGroups=%v accessControlGroups=%v peer=%q",
|
||||
rt.ID, rt.Groups, rt.PeerGroups, rt.AccessControlGroups, rt.Peer)
|
||||
addAll(r.groupSet, rt.Groups, rt.PeerGroups, rt.AccessControlGroups)
|
||||
if rt.Peer != "" {
|
||||
r.peerSet[rt.Peer] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// collectFromExplicitRouters folds changed routers' peers and marks their networks
|
||||
// for the bridge. Passing the old router keeps a repointed router's previous peers
|
||||
// affected without a post-commit read.
|
||||
func (r *resolver) collectFromExplicitRouters(routers []*routerTypes.NetworkRouter) {
|
||||
for _, router := range routers {
|
||||
if router == nil {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromExplicitRouters: changed router %s on network %s -> folding peerGroups=%v peer=%q and marking network for source bridge",
|
||||
router.ID, router.NetworkID, router.PeerGroups, router.Peer)
|
||||
addAll(r.groupSet, router.PeerGroups)
|
||||
if router.Peer != "" {
|
||||
r.peerSet[router.Peer] = struct{}{}
|
||||
}
|
||||
if router.NetworkID != "" {
|
||||
r.networkIDs[router.NetworkID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// collectFromExplicitResources marks changed resources' networks for the bridge and
|
||||
// treats their group IDs as changed, so policies targeting the resource via a
|
||||
// now-detached (old) group still refresh.
|
||||
func (r *resolver) collectFromExplicitResources(resources []*resourceTypes.NetworkResource) {
|
||||
for _, resource := range resources {
|
||||
if resource == nil {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromExplicitResources: changed resource %s on network %s -> marking network for bridge and treating groups %v as changed",
|
||||
resource.ID, resource.NetworkID, resource.GroupIDs)
|
||||
addAll(r.changedGroupSet, resource.GroupIDs)
|
||||
if resource.NetworkID != "" {
|
||||
r.networkIDs[resource.NetworkID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// collectFromExplicitNetworks marks changed networks for the bridge. A network has
|
||||
// no groups/peers of its own.
|
||||
func (r *resolver) collectFromExplicitNetworks(networks []*networkTypes.Network) {
|
||||
for _, network := range networks {
|
||||
if network == nil {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromExplicitNetworks: changed network %s -> marking for bridge", network.ID)
|
||||
if network.ID != "" {
|
||||
r.networkIDs[network.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromPostureChecks(postureCheckIDs []string) {
|
||||
if len(postureCheckIDs) == 0 {
|
||||
return
|
||||
}
|
||||
ids := toSet(postureCheckIDs)
|
||||
for _, policy := range r.policies() {
|
||||
if !policyReferencesPostureChecks(policy, ids) {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromPostureChecks: policy %s (%s) references changed posture checks %v -> folding rule groups %v + direct peers",
|
||||
policy.ID, policy.Name, postureCheckIDs, policy.RuleGroups())
|
||||
addAll(r.groupSet, policy.RuleGroups())
|
||||
collectPolicyDirectPeers(policy, r.peerSet)
|
||||
r.matchedPolicies = append(r.matchedPolicies, policy)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromPolicies() {
|
||||
for _, policy := range r.policies() {
|
||||
matchedByGroup := policyReferencesGroups(policy, r.changedGroupSet)
|
||||
matchedByPeer := len(r.changedPeerSet) > 0 && policyReferencesDirectPeers(policy, r.changedPeerSet)
|
||||
if !matchedByGroup && !matchedByPeer {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromPolicies: policy %s (%s) matched (byGroup=%t byPeer=%t) -> folding rule groups %v + direct peers",
|
||||
policy.ID, policy.Name, matchedByGroup, matchedByPeer, policy.RuleGroups())
|
||||
addAll(r.groupSet, policy.RuleGroups())
|
||||
collectPolicyDirectPeers(policy, r.peerSet)
|
||||
r.matchedPolicies = append(r.matchedPolicies, policy)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromRoutes() {
|
||||
for _, rt := range r.snap.routes {
|
||||
matchedByGroup := anyInSet(rt.Groups, r.changedGroupSet) || anyInSet(rt.PeerGroups, r.changedGroupSet) || anyInSet(rt.AccessControlGroups, r.changedGroupSet)
|
||||
matchedByPeer := rt.Peer != "" && len(r.changedPeerSet) > 0 && isInSet(rt.Peer, r.changedPeerSet)
|
||||
if !matchedByGroup && !matchedByPeer {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromRoutes: route %s matched (byGroup=%t byPeer=%t) -> folding groups=%v peerGroups=%v accessControlGroups=%v peer=%q",
|
||||
rt.ID, matchedByGroup, matchedByPeer, rt.Groups, rt.PeerGroups, rt.AccessControlGroups, rt.Peer)
|
||||
addAll(r.groupSet, rt.Groups, rt.PeerGroups, rt.AccessControlGroups)
|
||||
if rt.Peer != "" {
|
||||
r.peerSet[rt.Peer] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromNameServers() {
|
||||
if len(r.changedGroupSet) == 0 {
|
||||
return
|
||||
}
|
||||
for _, ns := range r.snap.nsGroups {
|
||||
if anyInSet(ns.Groups, r.changedGroupSet) {
|
||||
log.WithContext(r.ctx).Tracef("collectFromNameServers: nameserver group %s references a changed group -> folding its groups %v", ns.ID, ns.Groups)
|
||||
addAll(r.groupSet, ns.Groups)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromDNSSettings() {
|
||||
if len(r.changedGroupSet) == 0 || r.snap.dnsSettings == nil {
|
||||
return
|
||||
}
|
||||
for _, gID := range r.snap.dnsSettings.DisabledManagementGroups {
|
||||
if _, ok := r.changedGroupSet[gID]; ok {
|
||||
log.WithContext(r.ctx).Tracef("collectFromDNSSettings: changed group %s is in DisabledManagementGroups -> folding it", gID)
|
||||
r.groupSet[gID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromNetworkRouters() {
|
||||
for _, router := range r.networkRouters() {
|
||||
matchedByGroup := anyInSet(router.PeerGroups, r.changedGroupSet)
|
||||
matchedByPeer := router.Peer != "" && len(r.changedPeerSet) > 0 && isInSet(router.Peer, r.changedPeerSet)
|
||||
if !matchedByGroup && !matchedByPeer {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromNetworkRouters: router %s on network %s matched (byGroup=%t byPeer=%t) -> folding peerGroups=%v peer=%q and marking network for source bridge",
|
||||
router.ID, router.NetworkID, matchedByGroup, matchedByPeer, router.PeerGroups, router.Peer)
|
||||
addAll(r.groupSet, router.PeerGroups)
|
||||
if router.Peer != "" {
|
||||
r.peerSet[router.Peer] = struct{}{}
|
||||
}
|
||||
r.networkIDs[router.NetworkID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) collectFromProxyServices() {
|
||||
if len(r.snap.proxyByCluster) == 0 || len(r.snap.services) == 0 {
|
||||
return
|
||||
}
|
||||
services, proxyByCluster := r.snap.services, r.snap.proxyByCluster
|
||||
|
||||
expanded := r.expandChangedPeersWithGroups()
|
||||
|
||||
for _, svc := range services {
|
||||
if svc == nil {
|
||||
continue
|
||||
}
|
||||
proxyPeers := proxyByCluster[svc.ProxyCluster]
|
||||
if len(proxyPeers) == 0 {
|
||||
continue
|
||||
}
|
||||
matchedByPeer := serviceMatchesChangedPeers(svc, proxyPeers, expanded)
|
||||
matchedByAccessGroup := anyInSet(svc.AccessGroups, r.changedGroupSet)
|
||||
if !matchedByPeer && !matchedByAccessGroup {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("collectFromProxyServices: service %s (cluster=%s) matched (byProxyOrTargetPeer=%t byAccessGroup=%t) -> folding %d proxy peers, peer targets and access groups %v",
|
||||
svc.ID, svc.ProxyCluster, matchedByPeer, matchedByAccessGroup, len(proxyPeers), svc.AccessGroups)
|
||||
for _, pid := range proxyPeers {
|
||||
r.peerSet[pid] = struct{}{}
|
||||
}
|
||||
for _, target := range svc.Targets {
|
||||
if target.TargetType == rpservice.TargetTypePeer && target.TargetId != "" {
|
||||
r.peerSet[target.TargetId] = struct{}{}
|
||||
}
|
||||
}
|
||||
addAll(r.groupSet, svc.AccessGroups)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) expandChangedPeersWithGroups() map[string]struct{} {
|
||||
if len(r.changedGroupSet) == 0 {
|
||||
return r.changedPeerSet
|
||||
}
|
||||
ids := r.peerIDsForGroups(r.changedGroupSet)
|
||||
if len(ids) == 0 {
|
||||
return r.changedPeerSet
|
||||
}
|
||||
merged := make(map[string]struct{}, len(r.changedPeerSet)+len(ids))
|
||||
for id := range r.changedPeerSet {
|
||||
merged[id] = struct{}{}
|
||||
}
|
||||
for _, id := range ids {
|
||||
merged[id] = struct{}{}
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
// collectResourceRouterBridge crosses between source peers and routing peers, which
|
||||
// are reachable only via resource -> network -> router, not through the policy's own
|
||||
// groups: source -> router (targeted resources' networks), then router -> source.
|
||||
func (r *resolver) collectResourceRouterBridge() {
|
||||
r.bridgeSourceToRouters()
|
||||
r.bridgeRoutersToSources()
|
||||
}
|
||||
|
||||
func (r *resolver) bridgeSourceToRouters() {
|
||||
resourceIDs := r.policyDestinationResourceIDs(r.matchedPolicies...)
|
||||
if len(resourceIDs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
networkIDs := r.resourceNetworkIDs(resourceIDs)
|
||||
log.WithContext(r.ctx).Tracef("bridgeSourceToRouters: targeted resources %v -> networks %v (their routers become affected via the router->source pass)",
|
||||
setToSlice(resourceIDs), setToSlice(networkIDs))
|
||||
for id := range networkIDs {
|
||||
r.networkIDs[id] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) bridgeRoutersToSources() {
|
||||
if len(r.networkIDs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
log.WithContext(r.ctx).Tracef("bridgeRoutersToSources: affected networks %v -> folding their routing peers and the source peers of policies targeting their resources",
|
||||
setToSlice(r.networkIDs))
|
||||
|
||||
r.foldRoutersOnNetworks(r.networkIDs)
|
||||
|
||||
resourceIDs := make(map[string]struct{})
|
||||
for _, resource := range r.networkResources() {
|
||||
if _, ok := r.networkIDs[resource.NetworkID]; ok {
|
||||
resourceIDs[resource.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
if len(resourceIDs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for _, policy := range r.policies() {
|
||||
if r.policyTargetsResources(policy, resourceIDs) {
|
||||
log.WithContext(r.ctx).Tracef("bridgeRoutersToSources: policy %s (%s) targets an affected-network resource -> folding its source groups/peers", policy.ID, policy.Name)
|
||||
collectPolicySources(policy, r.groupSet, r.peerSet)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) foldRoutersOnNetworks(networkIDs map[string]struct{}) {
|
||||
for _, router := range r.networkRouters() {
|
||||
if _, ok := networkIDs[router.NetworkID]; !ok {
|
||||
continue
|
||||
}
|
||||
log.WithContext(r.ctx).Tracef("bridgeRoutersToSources: router %s serves affected network %s -> folding peerGroups=%v peer=%q",
|
||||
router.ID, router.NetworkID, router.PeerGroups, router.Peer)
|
||||
addAll(r.groupSet, router.PeerGroups)
|
||||
if router.Peer != "" {
|
||||
r.peerSet[router.Peer] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *resolver) resourceNetworkIDs(resourceIDs map[string]struct{}) map[string]struct{} {
|
||||
networkIDs := make(map[string]struct{})
|
||||
for _, resource := range r.networkResources() {
|
||||
if _, ok := resourceIDs[resource.ID]; ok {
|
||||
networkIDs[resource.NetworkID] = struct{}{}
|
||||
}
|
||||
}
|
||||
return networkIDs
|
||||
}
|
||||
|
||||
func (r *resolver) policyTargetsResources(policy *types.Policy, resourceIDs map[string]struct{}) bool {
|
||||
if policy == nil {
|
||||
return false
|
||||
}
|
||||
destGroupSet := make(map[string]struct{})
|
||||
for _, rule := range policy.Rules {
|
||||
if rule.DestinationResource.Type != types.ResourceTypePeer && isInSet(rule.DestinationResource.ID, resourceIDs) {
|
||||
return true
|
||||
}
|
||||
for _, gID := range rule.Destinations {
|
||||
destGroupSet[gID] = struct{}{}
|
||||
}
|
||||
}
|
||||
if len(destGroupSet) == 0 {
|
||||
return false
|
||||
}
|
||||
for gID := range destGroupSet {
|
||||
group := r.snap.groups[gID]
|
||||
if group == nil {
|
||||
continue
|
||||
}
|
||||
for _, res := range group.Resources {
|
||||
if isInSet(res.ID, resourceIDs) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *resolver) policyDestinationResourceIDs(policies ...*types.Policy) map[string]struct{} {
|
||||
resourceIDs := make(map[string]struct{})
|
||||
destGroupSet := collectPolicyDestinations(resourceIDs, policies...)
|
||||
r.addGroupResourceIDs(destGroupSet, resourceIDs)
|
||||
return resourceIDs
|
||||
}
|
||||
|
||||
// collectPolicyDestinations adds direct destination resource IDs to resourceIDs and
|
||||
// returns the referenced destination group IDs.
|
||||
func collectPolicyDestinations(resourceIDs map[string]struct{}, policies ...*types.Policy) map[string]struct{} {
|
||||
destGroupSet := make(map[string]struct{})
|
||||
for _, policy := range policies {
|
||||
if policy == nil {
|
||||
continue
|
||||
}
|
||||
for _, rule := range policy.Rules {
|
||||
addAll(destGroupSet, rule.Destinations)
|
||||
if rule.DestinationResource.Type != types.ResourceTypePeer && rule.DestinationResource.ID != "" {
|
||||
resourceIDs[rule.DestinationResource.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
return destGroupSet
|
||||
}
|
||||
|
||||
// addGroupResourceIDs folds the resource IDs of the given groups into resourceIDs.
|
||||
func (r *resolver) addGroupResourceIDs(groupIDs map[string]struct{}, resourceIDs map[string]struct{}) {
|
||||
for gID := range groupIDs {
|
||||
group := r.snap.groups[gID]
|
||||
if group == nil {
|
||||
continue
|
||||
}
|
||||
for _, res := range group.Resources {
|
||||
if res.ID != "" {
|
||||
resourceIDs[res.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func collectPolicyDirectPeers(policy *types.Policy, peerSet map[string]struct{}) {
|
||||
for _, rule := range policy.Rules {
|
||||
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
|
||||
peerSet[rule.SourceResource.ID] = struct{}{}
|
||||
}
|
||||
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
|
||||
peerSet[rule.DestinationResource.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func collectPolicySources(policy *types.Policy, groupSet, peerSet map[string]struct{}) {
|
||||
for _, rule := range policy.Rules {
|
||||
addAll(groupSet, rule.Sources)
|
||||
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
|
||||
peerSet[rule.SourceResource.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func policyReferencesGroups(policy *types.Policy, groupSet map[string]struct{}) bool {
|
||||
for _, rule := range policy.Rules {
|
||||
if anyInSet(rule.Sources, groupSet) || anyInSet(rule.Destinations, groupSet) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func policyReferencesDirectPeers(policy *types.Policy, changedSet map[string]struct{}) bool {
|
||||
for _, rule := range policy.Rules {
|
||||
if isDirectPeerInSet(rule.SourceResource, changedSet) || isDirectPeerInSet(rule.DestinationResource, changedSet) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func policyReferencesPostureChecks(policy *types.Policy, ids map[string]struct{}) bool {
|
||||
for _, id := range policy.SourcePostureChecks {
|
||||
if _, ok := ids[id]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isDirectPeerInSet(res types.Resource, set map[string]struct{}) bool {
|
||||
if res.Type != types.ResourceTypePeer || res.ID == "" {
|
||||
return false
|
||||
}
|
||||
_, ok := set[res.ID]
|
||||
return ok
|
||||
}
|
||||
|
||||
func serviceMatchesChangedPeers(svc *rpservice.Service, proxyPeers []string, changedPeers map[string]struct{}) bool {
|
||||
for _, pid := range proxyPeers {
|
||||
if _, ok := changedPeers[pid]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
for _, target := range svc.Targets {
|
||||
if target.TargetType != rpservice.TargetTypePeer || target.TargetId == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := changedPeers[target.TargetId]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func anyInSet(ids []string, set map[string]struct{}) bool {
|
||||
for _, id := range ids {
|
||||
if _, ok := set[id]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isInSet(id string, set map[string]struct{}) bool {
|
||||
_, ok := set[id]
|
||||
return ok
|
||||
}
|
||||
|
||||
func addAll(set map[string]struct{}, slices ...[]string) {
|
||||
for _, s := range slices {
|
||||
for _, id := range s {
|
||||
set[id] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func toSet(ids []string) map[string]struct{} {
|
||||
set := make(map[string]struct{}, len(ids))
|
||||
for _, id := range ids {
|
||||
set[id] = struct{}{}
|
||||
}
|
||||
return set
|
||||
}
|
||||
|
||||
func setToSlice(set map[string]struct{}) []string {
|
||||
s := make([]string, 0, len(set))
|
||||
for id := range set {
|
||||
s = append(s, id)
|
||||
}
|
||||
return s
|
||||
}
|
||||
@@ -1,140 +0,0 @@
|
||||
package affectedpeers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// policyGroupsAndPeers mirrors the explicit-policy extraction (RuleGroups +
|
||||
// direct peers) the resolver folds in, for asserting the pure logic.
|
||||
func policyGroupsAndPeers(policies ...*types.Policy) (groups []string, peers []string) {
|
||||
peerSet := map[string]struct{}{}
|
||||
for _, p := range policies {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
groups = append(groups, p.RuleGroups()...)
|
||||
collectPolicyDirectPeers(p, peerSet)
|
||||
}
|
||||
for id := range peerSet {
|
||||
peers = append(peers, id)
|
||||
}
|
||||
return groups, peers
|
||||
}
|
||||
|
||||
func TestPolicyGroupsAndPeers_Basic(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{Sources: []string{"g1", "g2"}, Destinations: []string{"g3"}}}}
|
||||
groups, peers := policyGroupsAndPeers(policy)
|
||||
assert.ElementsMatch(t, []string{"g1", "g2", "g3"}, groups)
|
||||
assert.Empty(t, peers)
|
||||
}
|
||||
|
||||
func TestPolicyGroupsAndPeers_WithPeerResources(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{
|
||||
Sources: []string{"g1"},
|
||||
SourceResource: types.Resource{ID: "p1", Type: types.ResourceTypePeer},
|
||||
Destinations: []string{"g2"},
|
||||
DestinationResource: types.Resource{ID: "p2", Type: types.ResourceTypePeer},
|
||||
}}}
|
||||
groups, peers := policyGroupsAndPeers(policy)
|
||||
assert.ElementsMatch(t, []string{"g1", "g2"}, groups)
|
||||
assert.ElementsMatch(t, []string{"p1", "p2"}, peers)
|
||||
}
|
||||
|
||||
func TestPolicyGroupsAndPeers_NilPolicy(t *testing.T) {
|
||||
groups, peers := policyGroupsAndPeers(nil)
|
||||
assert.Nil(t, groups)
|
||||
assert.Nil(t, peers)
|
||||
}
|
||||
|
||||
func TestPolicyGroupsAndPeers_MultiplePolicies(t *testing.T) {
|
||||
old := &types.Policy{Rules: []*types.PolicyRule{{Sources: []string{"g1"}, Destinations: []string{"g2"}}}}
|
||||
updated := &types.Policy{Rules: []*types.PolicyRule{{Sources: []string{"g3"}, Destinations: []string{"g4"}}}}
|
||||
groups, _ := policyGroupsAndPeers(updated, old)
|
||||
assert.ElementsMatch(t, []string{"g1", "g2", "g3", "g4"}, groups)
|
||||
}
|
||||
|
||||
func TestPolicyGroupsAndPeers_NonPeerResource(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{
|
||||
Sources: []string{"g1"},
|
||||
SourceResource: types.Resource{ID: "domain-1", Type: types.ResourceTypeDomain},
|
||||
Destinations: []string{"g2"},
|
||||
}}}
|
||||
groups, peers := policyGroupsAndPeers(policy)
|
||||
assert.ElementsMatch(t, []string{"g1", "g2"}, groups)
|
||||
assert.Empty(t, peers, "domain resource type should not produce direct peer IDs")
|
||||
}
|
||||
|
||||
func TestChangeIsEmpty(t *testing.T) {
|
||||
assert.True(t, Change{}.isEmpty())
|
||||
assert.False(t, Change{ChangedGroupIDs: []string{"g"}}.isEmpty())
|
||||
assert.False(t, Change{ChangedPeerIDs: []string{"p"}}.isEmpty())
|
||||
assert.False(t, Change{Policies: []*types.Policy{{}}}.isEmpty())
|
||||
assert.False(t, Change{Resources: []*resourceTypes.NetworkResource{{ID: "r"}}}.isEmpty())
|
||||
assert.False(t, Change{Networks: []*networkTypes.Network{{ID: "n"}}}.isEmpty())
|
||||
assert.False(t, Change{PostureCheckIDs: []string{"pc"}}.isEmpty())
|
||||
}
|
||||
|
||||
func TestPolicyReferencesGroups(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{Sources: []string{"g1", "g2"}, Destinations: []string{"g3"}}}}
|
||||
|
||||
assert.True(t, policyReferencesGroups(policy, map[string]struct{}{"g1": {}}))
|
||||
assert.True(t, policyReferencesGroups(policy, map[string]struct{}{"g3": {}}))
|
||||
assert.False(t, policyReferencesGroups(policy, map[string]struct{}{"g4": {}}))
|
||||
assert.False(t, policyReferencesGroups(policy, map[string]struct{}{}))
|
||||
}
|
||||
|
||||
func TestPolicyReferencesDirectPeers(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{
|
||||
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
|
||||
DestinationResource: types.Resource{Type: types.ResourceTypeHost, ID: "r1"},
|
||||
}}}
|
||||
|
||||
assert.True(t, policyReferencesDirectPeers(policy, map[string]struct{}{"p1": {}}))
|
||||
assert.False(t, policyReferencesDirectPeers(policy, map[string]struct{}{"r1": {}}))
|
||||
assert.False(t, policyReferencesDirectPeers(policy, map[string]struct{}{"p2": {}}))
|
||||
}
|
||||
|
||||
func TestPolicyReferencesPostureChecks(t *testing.T) {
|
||||
policy := &types.Policy{SourcePostureChecks: []string{"pc1", "pc2"}}
|
||||
|
||||
assert.True(t, policyReferencesPostureChecks(policy, map[string]struct{}{"pc1": {}}))
|
||||
assert.False(t, policyReferencesPostureChecks(policy, map[string]struct{}{"pc3": {}}))
|
||||
}
|
||||
|
||||
func TestCollectPolicyDirectPeers(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{
|
||||
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
|
||||
DestinationResource: types.Resource{Type: types.ResourceTypePeer, ID: "p2"},
|
||||
}, {
|
||||
DestinationResource: types.Resource{Type: types.ResourceTypeHost, ID: "r1"},
|
||||
}}}
|
||||
|
||||
peerSet := map[string]struct{}{}
|
||||
collectPolicyDirectPeers(policy, peerSet)
|
||||
|
||||
assert.Contains(t, peerSet, "p1")
|
||||
assert.Contains(t, peerSet, "p2")
|
||||
assert.NotContains(t, peerSet, "r1")
|
||||
}
|
||||
|
||||
func TestCollectPolicySources(t *testing.T) {
|
||||
policy := &types.Policy{Rules: []*types.PolicyRule{{
|
||||
Sources: []string{"g1"},
|
||||
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
|
||||
Destinations: []string{"g2"},
|
||||
}}}
|
||||
|
||||
groupSet := map[string]struct{}{}
|
||||
peerSet := map[string]struct{}{}
|
||||
collectPolicySources(policy, groupSet, peerSet)
|
||||
|
||||
assert.Contains(t, groupSet, "g1")
|
||||
assert.NotContains(t, groupSet, "g2", "destination groups must not be collected as sources")
|
||||
assert.Contains(t, peerSet, "p1")
|
||||
}
|
||||
@@ -12,7 +12,6 @@ const (
|
||||
RoleKey = nbcontext.RoleKey
|
||||
UserIDKey = nbcontext.UserIDKey
|
||||
PeerIDKey = nbcontext.PeerIDKey
|
||||
UserAgentKey = nbcontext.UserAgentKey
|
||||
)
|
||||
|
||||
// RoleFromContext returns the role stored in ctx, or empty string and false if absent.
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@@ -48,9 +47,8 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var updateAccountPeers bool
|
||||
var eventsToStore []func()
|
||||
var snap *affectedpeers.Snapshot
|
||||
var change affectedpeers.Change
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil {
|
||||
@@ -65,6 +63,11 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
||||
addedGroups := util.Difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
|
||||
removedGroups := util.Difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
|
||||
|
||||
updateAccountPeers, err = areDNSSettingChangesAffectPeers(ctx, transaction, accountID, addedGroups, removedGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
@@ -72,11 +75,6 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
||||
return err
|
||||
}
|
||||
|
||||
change = affectedpeers.Change{DistributionGroupIDs: slices.Concat(addedGroups, removedGroups)}
|
||||
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
})
|
||||
if err != nil {
|
||||
@@ -87,7 +85,9 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
am.ExpandAndUpdateAffected(ctx, accountID, snap, change)
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceDNSSettings, Operation: types.UpdateOperationUpdate})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -133,6 +133,20 @@ func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, t
|
||||
return eventsToStore
|
||||
}
|
||||
|
||||
// areDNSSettingChangesAffectPeers checks if the DNS settings changes affect any peers.
|
||||
func areDNSSettingChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, addedGroups, removedGroups []string) (bool, error) {
|
||||
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, accountID, addedGroups)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if hasPeers {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return anyGroupHasPeersOrResources(ctx, transaction, accountID, removedGroups)
|
||||
}
|
||||
|
||||
// validateDNSSettings validates the DNS settings.
|
||||
func validateDNSSettings(ctx context.Context, transaction store.Store, accountID string, settings *types.DNSSettings) error {
|
||||
if len(settings.DisabledManagementGroups) == 0 {
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
@@ -80,8 +79,7 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var snap *affectedpeers.Snapshot
|
||||
change := affectedpeers.Change{ChangedGroupIDs: []string{newGroup.ID}}
|
||||
var updateAccountPeers bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
@@ -93,6 +91,11 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
|
||||
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := transaction.CreateGroup(ctx, newGroup); err != nil {
|
||||
return status.Errorf(status.Internal, "failed to create group: %v", err)
|
||||
}
|
||||
@@ -103,11 +106,6 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
|
||||
}
|
||||
}
|
||||
|
||||
snap, err = affectedpeers.Load(ctx, transaction, accountID, change)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
})
|
||||
if err != nil {
|
||||
@@ -118,7 +116,9 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
am.ExpandAndUpdateAffected(ctx, accountID, snap, change)
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationCreate})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -134,8 +134,7 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var snap *affectedpeers.Snapshot
|
||||
change := affectedpeers.Change{ChangedGroupIDs: []string{newGroup.ID}}
|
||||
var updateAccountPeers bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
@@ -154,7 +153,20 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
||||
|
||||
peersToAdd := util.Difference(newGroup.Peers, oldGroup.Peers)
|
||||
peersToRemove := util.Difference(oldGroup.Peers, newGroup.Peers)
|
||||
if err = syncGroupMembership(ctx, transaction, accountID, newGroup.ID, peersToAdd, peersToRemove); err != nil {
|
||||
|
||||
for _, peerID := range peersToAdd {
|
||||
if err := transaction.AddPeerToGroup(ctx, accountID, peerID, newGroup.ID); err != nil {
|
||||
return status.Errorf(status.Internal, "failed to add peer %s to group %s: %v", peerID, newGroup.ID, err)
|
||||
}
|
||||
}
|
||||
for _, peerID := range peersToRemove {
|
||||
if err := transaction.RemovePeerFromGroup(ctx, peerID, newGroup.ID); err != nil {
|
||||
return status.Errorf(status.Internal, "failed to remove peer %s from group %s: %v", peerID, newGroup.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -166,17 +178,6 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
||||
return err
|
||||
}
|
||||
|
||||
// A membership change does not alter which entities reference the group, so
|
||||
// the dependency walk runs once against the post-change snapshot. The new
|
||||
// members are already in the snapshot's index; the removed members are
|
||||
// carried separately and folded in only when the group is linked.
|
||||
if len(peersToRemove) > 0 {
|
||||
change.RemovedPeersByGroup = map[string][]string{newGroup.ID: peersToRemove}
|
||||
}
|
||||
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
})
|
||||
if err != nil {
|
||||
@@ -187,23 +188,10 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
am.ExpandAndUpdateAffected(ctx, accountID, snap, change)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// syncGroupMembership applies the peer membership delta for a group within a transaction.
|
||||
func syncGroupMembership(ctx context.Context, transaction store.Store, accountID, groupID string, peersToAdd, peersToRemove []string) error {
|
||||
for _, peerID := range peersToAdd {
|
||||
if err := transaction.AddPeerToGroup(ctx, accountID, peerID, groupID); err != nil {
|
||||
return status.Errorf(status.Internal, "failed to add peer %s to group %s: %v", peerID, groupID, err)
|
||||
}
|
||||
}
|
||||
for _, peerID := range peersToRemove {
|
||||
if err := transaction.RemovePeerFromGroup(ctx, peerID, groupID); err != nil {
|
||||
return status.Errorf(status.Internal, "failed to remove peer %s from group %s: %v", peerID, groupID, err)
|
||||
}
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -221,14 +209,11 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var snaps []*affectedpeers.Snapshot
|
||||
var changes []affectedpeers.Change
|
||||
var updateAccountPeers bool
|
||||
|
||||
var globalErr error
|
||||
createdCount := 0
|
||||
groupIDs := make([]string, 0, len(groups))
|
||||
for _, newGroup := range groups {
|
||||
change := affectedpeers.Change{ChangedGroupIDs: []string{newGroup.ID}}
|
||||
var snap *affectedpeers.Snapshot
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
return err
|
||||
@@ -245,31 +230,35 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
|
||||
return err
|
||||
}
|
||||
|
||||
groupIDs = append(groupIDs, newGroup.ID)
|
||||
|
||||
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
snap, err = affectedpeers.Load(ctx, transaction, accountID, change)
|
||||
return err
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update group %s: %v", newGroup.ID, err)
|
||||
if createdCount == 0 {
|
||||
if len(groupIDs) == 1 {
|
||||
return err
|
||||
}
|
||||
globalErr = errors.Join(globalErr, err)
|
||||
// continue updating other groups
|
||||
continue
|
||||
}
|
||||
createdCount++
|
||||
snaps = append(snaps, snap)
|
||||
changes = append(changes, change)
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, storeEvent := range eventsToStore {
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
go am.dispatchAffected(ctx, accountID, snaps, changes)
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationCreate})
|
||||
}
|
||||
|
||||
return globalErr
|
||||
}
|
||||
@@ -288,13 +277,12 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var snaps []*affectedpeers.Snapshot
|
||||
var changes []affectedpeers.Change
|
||||
var updateAccountPeers bool
|
||||
|
||||
var globalErr error
|
||||
groupIDs := make([]string, 0, len(groups))
|
||||
for _, newGroup := range groups {
|
||||
change := affectedpeers.Change{ChangedGroupIDs: []string{newGroup.ID}}
|
||||
events, snap, err := am.updateSingleGroup(ctx, accountID, userID, newGroup, change)
|
||||
events, err := am.updateSingleGroup(ctx, accountID, userID, newGroup)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update group %s: %v", newGroup.ID, err)
|
||||
if len(groups) == 1 {
|
||||
@@ -304,22 +292,27 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
|
||||
continue
|
||||
}
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
snaps = append(snaps, snap)
|
||||
changes = append(changes, change)
|
||||
groupIDs = append(groupIDs, newGroup.ID)
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, storeEvent := range eventsToStore {
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
go am.dispatchAffected(ctx, accountID, snaps, changes)
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
|
||||
}
|
||||
|
||||
return globalErr
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) updateSingleGroup(ctx context.Context, accountID, userID string, newGroup *types.Group, change affectedpeers.Change) ([]func(), *affectedpeers.Snapshot, error) {
|
||||
func (am *DefaultAccountManager) updateSingleGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) ([]func(), error) {
|
||||
var events []func()
|
||||
var snap *affectedpeers.Snapshot
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err := validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
return err
|
||||
@@ -340,12 +333,9 @@ func (am *DefaultAccountManager) updateSingleGroup(ctx context.Context, accountI
|
||||
}
|
||||
|
||||
events = am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
|
||||
|
||||
var err error
|
||||
snap, err = affectedpeers.Load(ctx, transaction, accountID, change)
|
||||
return err
|
||||
return nil
|
||||
})
|
||||
return events, snap, err
|
||||
return events, err
|
||||
}
|
||||
|
||||
// prepareGroupEvents prepares a list of event functions to be stored.
|
||||
@@ -448,8 +438,6 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
|
||||
var allErrors error
|
||||
var groupIDsToDelete []string
|
||||
var deletedGroups []*types.Group
|
||||
var snap *affectedpeers.Snapshot
|
||||
var change affectedpeers.Change
|
||||
|
||||
extraSettings, err := am.settingsManager.GetExtraSettings(ctx, accountID)
|
||||
if err != nil {
|
||||
@@ -457,23 +445,26 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
|
||||
}
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
deletedGroups, allErrors = collectDeletableGroups(ctx, transaction, accountID, userID, groupIDs, extraSettings.FlowGroups)
|
||||
for _, group := range deletedGroups {
|
||||
groupIDsToDelete = append(groupIDsToDelete, group.ID)
|
||||
for _, groupID := range groupIDs {
|
||||
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
|
||||
if err != nil {
|
||||
allErrors = errors.Join(allErrors, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if err = validateDeleteGroup(ctx, transaction, group, userID, extraSettings.FlowGroups); err != nil {
|
||||
allErrors = errors.Join(allErrors, err)
|
||||
continue
|
||||
}
|
||||
|
||||
groupIDsToDelete = append(groupIDsToDelete, groupID)
|
||||
deletedGroups = append(deletedGroups, group)
|
||||
}
|
||||
|
||||
if len(groupIDsToDelete) == 0 {
|
||||
return allErrors
|
||||
}
|
||||
|
||||
// Delete: compute affected peers from the PRE-delete state. The groups,
|
||||
// their members and the entities referencing them still exist, so a plain
|
||||
// Load+Expand captures everyone — no removed-peer folding needed.
|
||||
change = affectedpeers.Change{ChangedGroupIDs: groupIDsToDelete}
|
||||
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.DeleteGroups(ctx, accountID, groupIDsToDelete); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -492,47 +483,25 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
|
||||
am.StoreEvent(ctx, userID, group.ID, accountID, activity.GroupDeleted, group.EventMeta())
|
||||
}
|
||||
|
||||
am.ExpandAndUpdateAffected(ctx, accountID, snap, change)
|
||||
|
||||
return allErrors
|
||||
}
|
||||
|
||||
// collectDeletableGroups loads and validates each group for deletion, returning
|
||||
// the groups that may be deleted and the joined validation errors for the rest.
|
||||
func collectDeletableGroups(ctx context.Context, transaction store.Store, accountID, userID string, groupIDs, flowGroups []string) ([]*types.Group, error) {
|
||||
var deletable []*types.Group
|
||||
var allErrors error
|
||||
for _, groupID := range groupIDs {
|
||||
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
|
||||
if err != nil {
|
||||
allErrors = errors.Join(allErrors, err)
|
||||
continue
|
||||
}
|
||||
if err = validateDeleteGroup(ctx, transaction, group, userID, flowGroups); err != nil {
|
||||
allErrors = errors.Join(allErrors, err)
|
||||
continue
|
||||
}
|
||||
deletable = append(deletable, group)
|
||||
}
|
||||
return deletable, allErrors
|
||||
}
|
||||
|
||||
// GroupAddPeer appends peer to the group
|
||||
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
var snap *affectedpeers.Snapshot
|
||||
change := affectedpeers.Change{ChangedGroupIDs: []string{groupID}}
|
||||
var updateAccountPeers bool
|
||||
var err error
|
||||
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err := transaction.AddPeerToGroup(ctx, accountID, peerID, groupID); err != nil {
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, []string{groupID}); err != nil {
|
||||
if err = transaction.AddPeerToGroup(ctx, accountID, peerID, groupID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var err error
|
||||
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
|
||||
if err = am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, []string{groupID}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -542,7 +511,9 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
|
||||
return err
|
||||
}
|
||||
|
||||
am.ExpandAndUpdateAffected(ctx, accountID, snap, change)
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -550,9 +521,8 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
|
||||
// GroupAddResource appends resource to the group
|
||||
func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID, groupID string, resource types.Resource) error {
|
||||
var group *types.Group
|
||||
var snap *affectedpeers.Snapshot
|
||||
var updateAccountPeers bool
|
||||
var err error
|
||||
change := affectedpeers.Change{ChangedGroupIDs: []string{groupID}}
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID)
|
||||
@@ -564,11 +534,12 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
|
||||
return nil
|
||||
}
|
||||
|
||||
if err = transaction.UpdateGroup(ctx, group); err != nil {
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
|
||||
if err = transaction.UpdateGroup(ctx, group); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -578,32 +549,29 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
|
||||
return err
|
||||
}
|
||||
|
||||
am.ExpandAndUpdateAffected(ctx, accountID, snap, change)
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GroupDeletePeer removes peer from the group
|
||||
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
var snap *affectedpeers.Snapshot
|
||||
change := affectedpeers.Change{
|
||||
ChangedGroupIDs: []string{groupID},
|
||||
RemovedPeersByGroup: map[string][]string{groupID: {peerID}},
|
||||
}
|
||||
var updateAccountPeers bool
|
||||
var err error
|
||||
|
||||
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err := transaction.RemovePeerFromGroup(ctx, peerID, groupID); err != nil {
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, []string{groupID}); err != nil {
|
||||
if err = transaction.RemovePeerFromGroup(ctx, peerID, groupID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// The removed peer is carried in change.RemovedPeersByGroup and folded in
|
||||
// only when the group is linked, so loading post-removal is correct.
|
||||
var err error
|
||||
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
|
||||
if err = am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, []string{groupID}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -613,7 +581,9 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
||||
return err
|
||||
}
|
||||
|
||||
am.ExpandAndUpdateAffected(ctx, accountID, snap, change)
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -621,9 +591,8 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
||||
// GroupDeleteResource removes resource from the group
|
||||
func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accountID, groupID string, resource types.Resource) error {
|
||||
var group *types.Group
|
||||
var snap *affectedpeers.Snapshot
|
||||
var updateAccountPeers bool
|
||||
var err error
|
||||
change := affectedpeers.Change{ChangedGroupIDs: []string{groupID}}
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID)
|
||||
@@ -635,9 +604,8 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun
|
||||
return nil
|
||||
}
|
||||
|
||||
// Load before persisting the removal, so the snapshot still maps the group
|
||||
// to the resource and the bridge can reach its routing peers.
|
||||
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -651,7 +619,9 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun
|
||||
return err
|
||||
}
|
||||
|
||||
am.ExpandAndUpdateAffected(ctx, accountID, snap, change)
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -862,103 +832,49 @@ func isGroupLinkedToNetworkRouter(ctx context.Context, transaction store.Store,
|
||||
}
|
||||
|
||||
// areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers.
|
||||
// It fetches each collection once and checks all groupIDs against them in memory.
|
||||
func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
|
||||
if len(groupIDs) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
groupSet := make(map[string]struct{}, len(groupIDs))
|
||||
for _, id := range groupIDs {
|
||||
groupSet[id] = struct{}{}
|
||||
}
|
||||
|
||||
if affected, err := dnsSettingsReferenceGroups(ctx, transaction, accountID, groupSet); affected || err != nil {
|
||||
return affected, err
|
||||
}
|
||||
if affected, err := nameServersReferenceGroups(ctx, transaction, accountID, groupSet); affected || err != nil {
|
||||
return affected, err
|
||||
}
|
||||
if affected, err := policiesReferenceGroups(ctx, transaction, accountID, groupSet); affected || err != nil {
|
||||
return affected, err
|
||||
}
|
||||
if affected, err := routesReferenceGroups(ctx, transaction, accountID, groupSet); affected || err != nil {
|
||||
return affected, err
|
||||
}
|
||||
if affected, err := networkRoutersReferenceGroups(ctx, transaction, accountID, groupSet); affected || err != nil {
|
||||
return affected, err
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func dnsSettingsReferenceGroups(ctx context.Context, transaction store.Store, accountID string, groupSet map[string]struct{}) (bool, error) {
|
||||
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return anyInSet(dnsSettings.DisabledManagementGroups, groupSet), nil
|
||||
}
|
||||
|
||||
func nameServersReferenceGroups(ctx context.Context, transaction store.Store, accountID string, groupSet map[string]struct{}) (bool, error) {
|
||||
nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
for _, ns := range nameServerGroups {
|
||||
if anyInSet(ns.Groups, groupSet) {
|
||||
for _, groupID := range groupIDs {
|
||||
if slices.Contains(dnsSettings.DisabledManagementGroups, groupID) {
|
||||
return true, nil
|
||||
}
|
||||
if linked, _ := isGroupLinkedToDns(ctx, transaction, accountID, groupID); linked {
|
||||
return true, nil
|
||||
}
|
||||
if linked, _ := isGroupLinkedToPolicy(ctx, transaction, accountID, groupID); linked {
|
||||
return true, nil
|
||||
}
|
||||
if linked, _ := isGroupLinkedToRoute(ctx, transaction, accountID, groupID); linked {
|
||||
return true, nil
|
||||
}
|
||||
if linked, _ := isGroupLinkedToNetworkRouter(ctx, transaction, accountID, groupID); linked {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func policiesReferenceGroups(ctx context.Context, transaction store.Store, accountID string, groupSet map[string]struct{}) (bool, error) {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
// anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources.
|
||||
func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, groupIDs)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
for _, policy := range policies {
|
||||
for _, rule := range policy.Rules {
|
||||
if anyInSet(rule.Sources, groupSet) || anyInSet(rule.Destinations, groupSet) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func routesReferenceGroups(ctx context.Context, transaction store.Store, accountID string, groupSet map[string]struct{}) (bool, error) {
|
||||
routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
for _, r := range routes {
|
||||
if anyInSet(r.Groups, groupSet) || anyInSet(r.PeerGroups, groupSet) || anyInSet(r.AccessControlGroups, groupSet) {
|
||||
for _, group := range groups {
|
||||
if group.HasPeers() || group.HasResources() {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func networkRoutersReferenceGroups(ctx context.Context, transaction store.Store, accountID string, groupSet map[string]struct{}) (bool, error) {
|
||||
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
for _, router := range routers {
|
||||
if anyInSet(router.PeerGroups, groupSet) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func anyInSet(ids []string, set map[string]struct{}) bool {
|
||||
for _, id := range ids {
|
||||
if _, ok := set[id]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/affectedpeers"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
@@ -39,7 +38,7 @@ type MockAccountManager struct {
|
||||
GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
|
||||
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error
|
||||
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error
|
||||
MarkPeerDisconnectedFunc func(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error
|
||||
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
|
||||
@@ -133,7 +132,6 @@ type MockAccountManager struct {
|
||||
|
||||
AllowSyncFunc func(string, uint64) bool
|
||||
UpdateAccountPeersFunc func(ctx context.Context, accountID string, reason types.UpdateReason)
|
||||
ExpandAndUpdateAffectedFunc func(ctx context.Context, accountID string, snap *affectedpeers.Snapshot, change affectedpeers.Change)
|
||||
BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string, reason types.UpdateReason)
|
||||
RecalculateNetworkMapCacheFunc func(ctx context.Context, accountId string) error
|
||||
|
||||
@@ -211,12 +209,6 @@ func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID
|
||||
}
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) ExpandAndUpdateAffected(ctx context.Context, accountID string, snap *affectedpeers.Snapshot, change affectedpeers.Change) {
|
||||
if am.ExpandAndUpdateAffectedFunc != nil {
|
||||
am.ExpandAndUpdateAffectedFunc(ctx, accountID, snap, change)
|
||||
}
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) {
|
||||
if am.BufferUpdateAccountPeersFunc != nil {
|
||||
am.BufferUpdateAccountPeersFunc(ctx, accountID, reason)
|
||||
@@ -345,9 +337,9 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userAuth
|
||||
}
|
||||
|
||||
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
|
||||
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
|
||||
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error {
|
||||
if am.MarkPeerConnectedFunc != nil {
|
||||
return am.MarkPeerConnectedFunc(ctx, peerKey, realIP, accountID, sessionStartedAt, nmap)
|
||||
return am.MarkPeerConnectedFunc(ctx, peerKey, realIP, accountID, sessionStartedAt)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user