mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-13 03:19:54 +00:00
Compare commits
12 Commits
guard-tun-
...
profile-id
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2e7b5f7192 | ||
|
|
0ffb9e8535 | ||
|
|
edc60550eb | ||
|
|
7b0f5f42f3 | ||
|
|
23c82b32ff | ||
|
|
f92bc2d325 | ||
|
|
3bb44e72db | ||
|
|
f98fe1e9ec | ||
|
|
c6f003fd18 | ||
|
|
6bdbbcad36 | ||
|
|
df0717dc16 | ||
|
|
33b1cc5449 |
@@ -6,7 +6,6 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -24,6 +23,7 @@ const (
|
||||
|
||||
// Profile represents a profile for gomobile
|
||||
type Profile struct {
|
||||
ID string
|
||||
Name string
|
||||
IsActive bool
|
||||
}
|
||||
@@ -53,10 +53,10 @@ func (p *ProfileArray) Get(i int) *Profile {
|
||||
├── state.json ← Default profile state
|
||||
├── active_profile.json ← Active profile tracker (JSON with Name + Username)
|
||||
└── profiles/ ← Subdirectory for non-default profiles
|
||||
├── work.json ← Work profile config
|
||||
├── work.state.json ← Work profile state
|
||||
├── personal.json ← Personal profile config
|
||||
└── personal.state.json ← Personal profile state
|
||||
├── work.json ← Legacy work profile config
|
||||
├── work.state.json ← Legacy work profile state
|
||||
├── 4c5f5c8198c3989cffb5b5394f5a7ae0.json ← ID profile config
|
||||
├── 4c5f5c8198c3989cffb5b5394f5a7ae0.state.json ← ID profile state
|
||||
*/
|
||||
|
||||
// ProfileManager manages profiles for Android
|
||||
@@ -99,6 +99,7 @@ func (pm *ProfileManager) ListProfiles() (*ProfileArray, error) {
|
||||
var profiles []*Profile
|
||||
for _, p := range internalProfiles {
|
||||
profiles = append(profiles, &Profile{
|
||||
ID: p.ID.String(),
|
||||
Name: p.Name,
|
||||
IsActive: p.IsActive,
|
||||
})
|
||||
@@ -108,55 +109,65 @@ func (pm *ProfileManager) ListProfiles() (*ProfileArray, error) {
|
||||
}
|
||||
|
||||
// GetActiveProfile returns the currently active profile name
|
||||
func (pm *ProfileManager) GetActiveProfile() (string, error) {
|
||||
func (pm *ProfileManager) GetActiveProfile() (*Profile, error) {
|
||||
// Use ServiceManager to stay consistent with ListProfiles
|
||||
// ServiceManager uses active_profile.json
|
||||
activeState, err := pm.serviceMgr.GetActiveProfileState()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get active profile: %w", err)
|
||||
return nil, fmt.Errorf("failed to get active profile: %w", err)
|
||||
}
|
||||
return activeState.Name, nil
|
||||
|
||||
// ActiveProfileState only stores the ID (and username), not the display
|
||||
// name. Resolve the ID to the full profile so callers get the real Name.
|
||||
prof, err := pm.serviceMgr.ResolveProfile(activeState.ID.String(), androidUsername)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to resolve active profile %q: %w", activeState.ID, err)
|
||||
}
|
||||
return &Profile{ID: prof.ID.String(), Name: prof.Name, IsActive: true}, nil
|
||||
}
|
||||
|
||||
// SwitchProfile switches to a different profile
|
||||
func (pm *ProfileManager) SwitchProfile(profileName string) error {
|
||||
func (pm *ProfileManager) SwitchProfile(id string) error {
|
||||
// Use ServiceManager to stay consistent with ListProfiles
|
||||
// ServiceManager uses active_profile.json
|
||||
err := pm.serviceMgr.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
Name: profileName,
|
||||
ID: profilemanager.ID(id),
|
||||
Username: androidUsername,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to switch profile: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("switched to profile: %s", profileName)
|
||||
log.Infof("switched to profile: %s", id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddProfile creates a new profile
|
||||
func (pm *ProfileManager) AddProfile(profileName string) error {
|
||||
// Use ServiceManager (creates profile in profiles/ directory)
|
||||
if err := pm.serviceMgr.AddProfile(profileName, androidUsername); err != nil {
|
||||
profile, err := pm.serviceMgr.AddProfile(profileName, androidUsername)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add profile: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("created new profile: %s", profileName)
|
||||
log.Infof("created new profile: %s", profile.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// LogoutProfile logs out from a profile (clears authentication)
|
||||
func (pm *ProfileManager) LogoutProfile(profileName string) error {
|
||||
profileName = sanitizeProfileName(profileName)
|
||||
|
||||
configPath, err := pm.getProfileConfigPath(profileName)
|
||||
func (pm *ProfileManager) LogoutProfile(id string) error {
|
||||
configPath, err := pm.getProfileConfigPath(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !profilemanager.IsValidProfileFilenameStem(profilemanager.ID(id)) {
|
||||
return fmt.Errorf("id '%s' is not valid", id)
|
||||
}
|
||||
|
||||
// Check if profile exists
|
||||
if _, err := os.Stat(configPath); os.IsNotExist(err) {
|
||||
return fmt.Errorf("profile '%s' does not exist", profileName)
|
||||
return fmt.Errorf("profile '%s' does not exist", id)
|
||||
}
|
||||
|
||||
// Read current config using internal profilemanager
|
||||
@@ -174,53 +185,56 @@ func (pm *ProfileManager) LogoutProfile(profileName string) error {
|
||||
return fmt.Errorf("failed to save config: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("logged out from profile: %s", profileName)
|
||||
log.Infof("logged out from profile: %s", id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveProfile deletes a profile
|
||||
func (pm *ProfileManager) RemoveProfile(profileName string) error {
|
||||
func (pm *ProfileManager) RemoveProfile(id string) error {
|
||||
// Use ServiceManager (removes profile from profiles/ directory)
|
||||
if err := pm.serviceMgr.RemoveProfile(profileName, androidUsername); err != nil {
|
||||
if err := pm.serviceMgr.RemoveProfile(profilemanager.ID(id), androidUsername); err != nil {
|
||||
return fmt.Errorf("failed to remove profile: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("removed profile: %s", profileName)
|
||||
log.Infof("removed profile: %s", id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// getProfileConfigPath returns the config file path for a profile
|
||||
// This is needed for Android-specific path handling (netbird.cfg for default profile)
|
||||
func (pm *ProfileManager) getProfileConfigPath(profileName string) (string, error) {
|
||||
if profileName == "" || profileName == profilemanager.DefaultProfileName {
|
||||
func (pm *ProfileManager) getProfileConfigPath(id string) (string, error) {
|
||||
if id == "" || id == profilemanager.DefaultProfileName {
|
||||
if !profilemanager.IsValidProfileFilenameStem(profilemanager.ID(id)) {
|
||||
return "", fmt.Errorf("id %q is not valid", id)
|
||||
}
|
||||
// Android uses netbird.cfg for default profile instead of default.json
|
||||
// Default profile is stored in root configDir, not in profiles/
|
||||
return filepath.Join(pm.configDir, defaultConfigFilename), nil
|
||||
}
|
||||
|
||||
// Non-default profiles are stored in profiles subdirectory
|
||||
// This matches the Java Preferences.java expectation
|
||||
profileName = sanitizeProfileName(profileName)
|
||||
profilesDir := filepath.Join(pm.configDir, profilesSubdir)
|
||||
return filepath.Join(profilesDir, profileName+".json"), nil
|
||||
return filepath.Join(profilesDir, id+".json"), nil
|
||||
}
|
||||
|
||||
// GetConfigPath returns the config file path for a given profile
|
||||
// GetConfigPath returns the config file path for a given profile id
|
||||
// Java should call this instead of constructing paths with Preferences.configFile()
|
||||
func (pm *ProfileManager) GetConfigPath(profileName string) (string, error) {
|
||||
return pm.getProfileConfigPath(profileName)
|
||||
func (pm *ProfileManager) GetConfigPath(id string) (string, error) {
|
||||
return pm.getProfileConfigPath(id)
|
||||
}
|
||||
|
||||
// GetStateFilePath returns the state file path for a given profile
|
||||
// Java should call this instead of constructing paths with Preferences.stateFile()
|
||||
func (pm *ProfileManager) GetStateFilePath(profileName string) (string, error) {
|
||||
if profileName == "" || profileName == profilemanager.DefaultProfileName {
|
||||
func (pm *ProfileManager) GetStateFilePath(id string) (string, error) {
|
||||
if id == "" || id == profilemanager.DefaultProfileName {
|
||||
return filepath.Join(pm.configDir, "state.json"), nil
|
||||
}
|
||||
|
||||
profileName = sanitizeProfileName(profileName)
|
||||
if !profilemanager.IsValidProfileFilenameStem(profilemanager.ID(id)) {
|
||||
return "", fmt.Errorf("id %q is not valid", id)
|
||||
}
|
||||
|
||||
profilesDir := filepath.Join(pm.configDir, profilesSubdir)
|
||||
return filepath.Join(profilesDir, profileName+".state.json"), nil
|
||||
return filepath.Join(profilesDir, id+".state.json"), nil
|
||||
}
|
||||
|
||||
// GetActiveConfigPath returns the config file path for the currently active profile
|
||||
@@ -230,7 +244,7 @@ func (pm *ProfileManager) GetActiveConfigPath() (string, error) {
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get active profile: %w", err)
|
||||
}
|
||||
return pm.GetConfigPath(activeProfile)
|
||||
return pm.GetConfigPath(activeProfile.ID)
|
||||
}
|
||||
|
||||
// GetActiveStateFilePath returns the state file path for the currently active profile
|
||||
@@ -240,18 +254,5 @@ func (pm *ProfileManager) GetActiveStateFilePath() (string, error) {
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get active profile: %w", err)
|
||||
}
|
||||
return pm.GetStateFilePath(activeProfile)
|
||||
}
|
||||
|
||||
// sanitizeProfileName removes invalid characters from profile name
|
||||
func sanitizeProfileName(name string) string {
|
||||
// Keep only alphanumeric, underscore, and hyphen
|
||||
var result strings.Builder
|
||||
for _, r := range name {
|
||||
if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') ||
|
||||
(r >= '0' && r <= '9') || r == '_' || r == '-' {
|
||||
result.WriteRune(r)
|
||||
}
|
||||
}
|
||||
return result.String()
|
||||
return pm.GetStateFilePath(activeProfile.ID)
|
||||
}
|
||||
|
||||
@@ -3,14 +3,12 @@ package cmd
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os/user"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
@@ -87,73 +85,6 @@ var persistenceCmd = &cobra.Command{
|
||||
RunE: setSyncResponsePersistence,
|
||||
}
|
||||
|
||||
var debugConfigCmd = &cobra.Command{
|
||||
Use: "config",
|
||||
Example: " netbird debug config",
|
||||
Short: "Dump the effective configuration",
|
||||
Long: "Prints the daemon's resolved configuration (after applying defaults, file, env, CLI input, and MDM policy overrides) as JSON. Includes the list of MDM-managed fields.",
|
||||
RunE: debugConfigDump,
|
||||
}
|
||||
|
||||
// debugConfigDump implements `netbird debug config`. It resolves the
|
||||
// active profile, queries the daemon for the effective configuration
|
||||
// via GetConfig, and prints the resulting GetConfigResponse as JSON
|
||||
// (via protojson with EmitUnpopulated=true so the output is stable
|
||||
// across runs and includes zero-valued fields).
|
||||
//
|
||||
// Useful for verifying MDM enforcement end-to-end: the response's
|
||||
// mDMManagedFields array is the single source of truth for "which
|
||||
// fields is the daemon currently enforcing from the MDM source", and
|
||||
// every config field side-by-side with that list confirms the merge
|
||||
// result. Secrets in the response (e.g. PreSharedKey) are already
|
||||
// redacted by the daemon-side handler.
|
||||
func debugConfigDump(cmd *cobra.Command, _ []string) error {
|
||||
pm := profilemanager.NewProfileManager()
|
||||
activeProf, err := pm.GetActiveProfile()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get active profile: %v", err)
|
||||
}
|
||||
currUser, err := user.Current()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get current user: %v", err)
|
||||
}
|
||||
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf(errCloseConnection, err)
|
||||
}
|
||||
}()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
resp, err := client.GetConfig(cmd.Context(), &proto.GetConfigRequest{
|
||||
ProfileName: activeProf.Name,
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get config: %v", status.Convert(err).Message())
|
||||
}
|
||||
|
||||
// Use protojson so well-known fields render correctly; emit defaults so
|
||||
// the operator sees every field even when zero/empty.
|
||||
m := protojson.MarshalOptions{Multiline: true, Indent: " ", EmitUnpopulated: true}
|
||||
out, err := m.Marshal(resp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal config: %w", err)
|
||||
}
|
||||
cmd.Println(string(out))
|
||||
return nil
|
||||
}
|
||||
|
||||
// debugBundle requests the daemon to create a debug bundle and prints
|
||||
// the resulting local file path and, if uploaded, the uploaded file
|
||||
// key. It uses the package flags (anonymize, system info, log file
|
||||
// count, CLI version, optional upload URL) to configure the bundle
|
||||
// request. Returns an error if the RPC fails or if the daemon reports
|
||||
// an upload failure reason.
|
||||
func debugBundle(cmd *cobra.Command, _ []string) error {
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,301 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
KubernetesDNSSuffix = "netbird-kubeapi-proxy"
|
||||
)
|
||||
|
||||
var kubernetesCmd = &cobra.Command{
|
||||
Use: "kubernetes",
|
||||
Short: "Kubernetes cluster commands.",
|
||||
Long: "Kubernetes cluster commands.",
|
||||
}
|
||||
|
||||
var kubernetesListCmd = &cobra.Command{
|
||||
Use: "list",
|
||||
RunE: kubernetesList,
|
||||
Short: "List Kubernetes clusters.",
|
||||
Long: "List Kubernetes clusters by discovering NetBird peers running netbird-kubeapi-proxy.",
|
||||
}
|
||||
|
||||
var kubernetesWriteKubeconfigCmd = &cobra.Command{
|
||||
Use: "write-kubeconfig",
|
||||
RunE: kubernetesWriteKubeconfig,
|
||||
Args: cobra.ExactArgs(1),
|
||||
Short: "Write kubeconfig for a Kubernetes cluster.",
|
||||
Long: "Updates kubeconfig in place to allow token-less access to the Kubernetes cluster through NetBird.",
|
||||
}
|
||||
|
||||
func init() {
|
||||
kubernetesWriteKubeconfigCmd.Flags().String("kubeconfig", "", "path to kubeconfig file")
|
||||
}
|
||||
|
||||
func kubernetesList(cmd *cobra.Command, _ []string) error {
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
statusResp, err := client.Status(cmd.Context(), &proto.StatusRequest{GetFullPeerStatus: true})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
kcs, err := getKubernetesClusters(cmd.Context(), statusResp.FullStatus.Peers, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(kcs) == 0 {
|
||||
cmd.Println("No Kubernetes clusters available.")
|
||||
return nil
|
||||
}
|
||||
cmd.Println("Available Kubernetes clusters:")
|
||||
for _, k := range kcs {
|
||||
cmd.Printf("\n - Name: %s\n FQDN: %s\n Version: %s\n", k.name, k.url.Host, k.version)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func kubernetesWriteKubeconfig(cmd *cobra.Command, args []string) error {
|
||||
kubeconfigPath, err := resolveKubeconfigPath(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn, err := getClient(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
statusResp, err := client.Status(cmd.Context(), &proto.StatusRequest{GetFullPeerStatus: true})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
clusterName := args[0]
|
||||
kcs, err := getKubernetesClusters(cmd.Context(), statusResp.FullStatus.Peers, clusterName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(kcs) == 0 {
|
||||
return fmt.Errorf("kubernetes cluster named %s not found", clusterName)
|
||||
}
|
||||
if len(kcs) > 1 {
|
||||
return fmt.Errorf("too many Kubernetes clusters returned")
|
||||
}
|
||||
err = writeKubeconfig(kubeconfigPath, kcs[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type kubernetesCluster struct {
|
||||
name string
|
||||
url *url.URL
|
||||
version string
|
||||
}
|
||||
|
||||
func getKubernetesClusters(ctx context.Context, peers []*proto.PeerState, nameFilter string) ([]kubernetesCluster, error) {
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
transport.TLSClientConfig = &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
httpClient := &http.Client{
|
||||
Transport: transport,
|
||||
}
|
||||
resolver := net.Resolver{
|
||||
// Required so both DNS records are returned.
|
||||
// https://github.com/golang/go/issues/17093
|
||||
PreferGo: true,
|
||||
}
|
||||
|
||||
kcs := []kubernetesCluster{}
|
||||
attempted := map[string]struct{}{}
|
||||
for _, peer := range peers {
|
||||
fqdns, err := resolver.LookupAddr(ctx, peer.IP)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, fqdn := range fqdns {
|
||||
if _, ok := attempted[fqdn]; ok {
|
||||
continue
|
||||
}
|
||||
attempted[fqdn] = struct{}{}
|
||||
comps := strings.Split(fqdn, ".")
|
||||
if len(comps) < 2 {
|
||||
continue
|
||||
}
|
||||
if comps[1] != KubernetesDNSSuffix {
|
||||
continue
|
||||
}
|
||||
if nameFilter != "" && nameFilter != comps[0] {
|
||||
continue
|
||||
}
|
||||
clusterURL, clusterVersion, err := fingerprintClusters(ctx, httpClient, fqdn)
|
||||
if err != nil {
|
||||
log.Debugf("could not fingerprint Kubernetes cluster %s %q", fqdn, err)
|
||||
continue
|
||||
}
|
||||
kc := kubernetesCluster{
|
||||
name: comps[0],
|
||||
url: clusterURL,
|
||||
version: clusterVersion,
|
||||
}
|
||||
if nameFilter != "" {
|
||||
return []kubernetesCluster{kc}, nil
|
||||
}
|
||||
kcs = append(kcs, kc)
|
||||
}
|
||||
}
|
||||
return kcs, nil
|
||||
}
|
||||
|
||||
func fingerprintClusters(ctx context.Context, httpClient *http.Client, fqdn string) (*url.URL, string, error) {
|
||||
clusterURL, err := url.Parse("https://" + fqdn)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
versionURL, err := clusterURL.Parse("/version")
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, versionURL.String(), nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, "", fmt.Errorf("expected %d response but got %s", http.StatusOK, resp.Status)
|
||||
}
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
versionData := map[string]string{}
|
||||
err = json.Unmarshal(b, &versionData)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
version, ok := versionData["gitVersion"]
|
||||
if !ok {
|
||||
return nil, "", errors.New("no version found in response")
|
||||
}
|
||||
return clusterURL, version, nil
|
||||
}
|
||||
|
||||
func resolveKubeconfigPath(cmd *cobra.Command) (string, error) {
|
||||
if cmd.Flags().Changed("kubeconfig") {
|
||||
path, err := cmd.Flags().GetString("kubeconfig")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return path, nil
|
||||
}
|
||||
if env := os.Getenv("KUBECONFIG"); env != "" {
|
||||
return env, nil
|
||||
}
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("could not determine home directory: %w", err)
|
||||
}
|
||||
return filepath.Join(home, ".kube", "config"), nil
|
||||
}
|
||||
|
||||
func writeKubeconfig(kubeconfigPath string, kc kubernetesCluster) error {
|
||||
b, err := os.ReadFile(kubeconfigPath)
|
||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return err
|
||||
}
|
||||
var cfg map[string]any
|
||||
if err := yaml.Unmarshal(b, &cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
if cfg == nil {
|
||||
cfg = map[string]any{
|
||||
"apiVersion": "v1",
|
||||
"kind": "Config",
|
||||
}
|
||||
}
|
||||
|
||||
cfg["clusters"] = appendWithName(cfg["clusters"], map[string]any{
|
||||
"name": kc.name,
|
||||
"cluster": map[string]any{
|
||||
"server": kc.url.String(),
|
||||
"insecure-skip-tls-verify": true,
|
||||
},
|
||||
})
|
||||
cfg["users"] = appendWithName(cfg["users"], map[string]any{
|
||||
"name": "netbird",
|
||||
"user": map[string]any{
|
||||
"token": "none",
|
||||
},
|
||||
})
|
||||
cfg["contexts"] = appendWithName(cfg["contexts"], map[string]any{
|
||||
"name": kc.name,
|
||||
"context": map[string]any{
|
||||
"cluster": kc.name,
|
||||
"user": "netbird",
|
||||
"namespace": "default",
|
||||
},
|
||||
})
|
||||
cfg["current-context"] = kc.name
|
||||
|
||||
out, err := yaml.Marshal(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.WriteFile(kubeconfigPath, out, 0o600); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func appendWithName(data any, add map[string]any) any {
|
||||
if data == nil {
|
||||
return []any{add}
|
||||
}
|
||||
v, ok := data.([]any)
|
||||
if !ok {
|
||||
return []any{add}
|
||||
}
|
||||
i := slices.IndexFunc(v, func(item any) bool {
|
||||
m, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return m["name"] == add["name"]
|
||||
})
|
||||
if i == -1 {
|
||||
return append(v, add)
|
||||
}
|
||||
v[i] = add
|
||||
return v
|
||||
}
|
||||
@@ -1,120 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFingerprintClusters(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
//nolint: errcheck
|
||||
w.Write([]byte(`{"gitVersion": "foobar"}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
clusterURL, clusterVersion, err := fingerprintClusters(t.Context(), srv.Client(), srv.Listener.Addr().String())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, srv.URL, clusterURL.String())
|
||||
require.Equal(t, "foobar", clusterVersion)
|
||||
}
|
||||
|
||||
func TestResolveKubeconfigPath(t *testing.T) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
t.Fatalf("could not determine home directory: %v", err)
|
||||
}
|
||||
defaultPath := filepath.Join(home, ".kube", "config")
|
||||
path, err := resolveKubeconfigPath(&cobra.Command{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, defaultPath, path)
|
||||
|
||||
flagPath := "flag-path"
|
||||
cmd := &cobra.Command{}
|
||||
cmd.Flags().String("kubeconfig", "", "")
|
||||
err = cmd.Flags().Set("kubeconfig", flagPath)
|
||||
require.NoError(t, err)
|
||||
path, err = resolveKubeconfigPath(cmd)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, flagPath, path)
|
||||
|
||||
envPath := "env-path"
|
||||
t.Setenv("KUBECONFIG", envPath)
|
||||
path, err = resolveKubeconfigPath(&cobra.Command{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, envPath, path)
|
||||
}
|
||||
|
||||
func TestWriteKubeconfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
existing string
|
||||
}{
|
||||
{
|
||||
name: "empty file",
|
||||
},
|
||||
{
|
||||
name: "existing content",
|
||||
existing: `apiVersion: v1
|
||||
clusters:
|
||||
- cluster:
|
||||
insecure-skip-tls-verify: true
|
||||
server: https://foobar.com
|
||||
name: foo
|
||||
current-context: test
|
||||
kind: Config
|
||||
users: []
|
||||
`,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
kubeconfigPath := filepath.Join(t.TempDir(), "config")
|
||||
err := os.WriteFile(kubeconfigPath, []byte(tt.existing), 0o644)
|
||||
require.NoError(t, err)
|
||||
|
||||
kc := kubernetesCluster{
|
||||
name: "foo",
|
||||
url: &url.URL{Scheme: "https", Host: "example.com"},
|
||||
}
|
||||
err = writeKubeconfig(kubeconfigPath, kc)
|
||||
require.NoError(t, err)
|
||||
|
||||
b, err := os.ReadFile(kubeconfigPath)
|
||||
require.NoError(t, err)
|
||||
expected := `apiVersion: v1
|
||||
clusters:
|
||||
- cluster:
|
||||
insecure-skip-tls-verify: true
|
||||
server: https://example.com
|
||||
name: foo
|
||||
contexts:
|
||||
- context:
|
||||
cluster: foo
|
||||
namespace: default
|
||||
user: netbird
|
||||
name: foo
|
||||
current-context: foo
|
||||
kind: Config
|
||||
users:
|
||||
- name: netbird
|
||||
user:
|
||||
token: none
|
||||
`
|
||||
require.Equal(t, expected, string(b))
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
@@ -96,17 +96,19 @@ func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey str
|
||||
dnsLabelsReq = dnsLabelsValidated.ToSafeStringList()
|
||||
}
|
||||
|
||||
handle := activeProf.ID.String()
|
||||
|
||||
loginRequest := proto.LoginRequest{
|
||||
SetupKey: providedSetupKey,
|
||||
ManagementUrl: managementURL,
|
||||
IsUnixDesktopClient: isUnixRunningDesktop(),
|
||||
Hostname: hostName,
|
||||
DnsLabels: dnsLabelsReq,
|
||||
ProfileName: &activeProf.Name,
|
||||
ProfileName: &handle,
|
||||
Username: &username,
|
||||
}
|
||||
|
||||
profileState, err := pm.GetProfileState(activeProf.Name)
|
||||
profileState, err := pm.GetProfileState(activeProf.ID)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get profile state for login hint: %v", err)
|
||||
} else if profileState.Email != "" {
|
||||
@@ -170,14 +172,13 @@ func getActiveProfile(ctx context.Context, pm *profilemanager.ProfileManager, pr
|
||||
return activeProf, nil
|
||||
}
|
||||
|
||||
func switchProfileOnDaemon(ctx context.Context, pm *profilemanager.ProfileManager, profileName string, username string) error {
|
||||
err := switchProfile(context.Background(), profileName, username)
|
||||
func switchProfileOnDaemon(ctx context.Context, pm *profilemanager.ProfileManager, handle string, username string) error {
|
||||
resolvedID, err := switchProfile(ctx, handle, username)
|
||||
if err != nil {
|
||||
return fmt.Errorf("switch profile on daemon: %v", err)
|
||||
}
|
||||
|
||||
err = pm.SwitchProfile(profileName)
|
||||
if err != nil {
|
||||
if err := pm.SwitchProfile(resolvedID); err != nil {
|
||||
return fmt.Errorf("switch profile: %v", err)
|
||||
}
|
||||
|
||||
@@ -205,11 +206,15 @@ func switchProfileOnDaemon(ctx context.Context, pm *profilemanager.ProfileManage
|
||||
return nil
|
||||
}
|
||||
|
||||
func switchProfile(ctx context.Context, profileName string, username string) error {
|
||||
// switchProfile asks the daemon to switch to the profile identified by
|
||||
// handle (a name, ID, or unique ID prefix). Returns the resolved profile
|
||||
// ID so the caller can update the local active-profile state without
|
||||
// re-resolving the handle.
|
||||
func switchProfile(ctx context.Context, handle string, username string) (profilemanager.ID, error) {
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
//nolint
|
||||
return fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
return "", fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
"\nnetbird service install \nnetbird service start\n", err)
|
||||
}
|
||||
@@ -217,15 +222,15 @@ func switchProfile(ctx context.Context, profileName string, username string) err
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
_, err = client.SwitchProfile(ctx, &proto.SwitchProfileRequest{
|
||||
ProfileName: &profileName,
|
||||
resp, err := client.SwitchProfile(ctx, &proto.SwitchProfileRequest{
|
||||
ProfileName: &handle,
|
||||
Username: &username,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("switch profile failed: %v", err)
|
||||
return "", fmt.Errorf("switch profile failed: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
return profilemanager.ID(resp.Id), nil
|
||||
}
|
||||
|
||||
func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string, activeProf *profilemanager.Profile) error {
|
||||
@@ -249,7 +254,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
|
||||
return fmt.Errorf("read config file %s: %v", configFilePath, err)
|
||||
}
|
||||
|
||||
err = foregroundLogin(ctx, cmd, config, setupKey, activeProf.Name)
|
||||
err = foregroundLogin(ctx, cmd, config, setupKey, activeProf.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("foreground login failed: %v", err)
|
||||
}
|
||||
@@ -277,7 +282,7 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo
|
||||
return nil
|
||||
}
|
||||
|
||||
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error {
|
||||
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey string, profileID profilemanager.ID) error {
|
||||
authClient, err := auth.NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create auth client: %v", err)
|
||||
@@ -291,7 +296,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
|
||||
|
||||
jwtToken := ""
|
||||
if setupKey == "" && needsLogin {
|
||||
tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config, profileName)
|
||||
tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config, profileID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("interactive sso login failed: %v", err)
|
||||
}
|
||||
@@ -306,10 +311,10 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
|
||||
return nil
|
||||
}
|
||||
|
||||
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, profileName string) (*auth.TokenInfo, error) {
|
||||
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, profileID profilemanager.ID) (*auth.TokenInfo, error) {
|
||||
hint := ""
|
||||
pm := profilemanager.NewProfileManager()
|
||||
profileState, err := pm.GetProfileState(profileName)
|
||||
profileState, err := pm.GetProfileState(profileID)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get profile state for login hint: %v", err)
|
||||
} else if profileState.Email != "" {
|
||||
|
||||
@@ -27,7 +27,7 @@ func TestLogin(t *testing.T) {
|
||||
profilemanager.ActiveProfileStatePath = tempDir + "/active_profile.json"
|
||||
sm := profilemanager.ServiceManager{}
|
||||
err = sm.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
Name: "default",
|
||||
ID: "default",
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@@ -2,11 +2,16 @@ package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os/user"
|
||||
"strings"
|
||||
"text/tabwriter"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
@@ -14,6 +19,8 @@ import (
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var profileListShowID bool
|
||||
|
||||
var profileCmd = &cobra.Command{
|
||||
Use: "profile",
|
||||
Short: "Manage NetBird client profiles",
|
||||
@@ -31,27 +38,40 @@ var profileListCmd = &cobra.Command{
|
||||
var profileAddCmd = &cobra.Command{
|
||||
Use: "add <profile_name>",
|
||||
Short: "Add a new profile",
|
||||
Long: `Add a new profile to the NetBird client. The profile name must be unique.`,
|
||||
Long: `Add a new profile. Profile name is free-form, a unique ID is generated for the on-disk config file.`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: addProfileFunc,
|
||||
}
|
||||
|
||||
var profileRenameCmd = &cobra.Command{
|
||||
Use: "rename <profile> <new_profile_name>",
|
||||
Short: "Renames an existing profile",
|
||||
Long: `Renames an existing profile (by a name, ID, or unique ID prefix). Profile name is free-form.`,
|
||||
Args: cobra.ExactArgs(2),
|
||||
RunE: renameProfileFunc,
|
||||
}
|
||||
|
||||
var profileRemoveCmd = &cobra.Command{
|
||||
Use: "remove <profile_name>",
|
||||
Short: "Remove a profile",
|
||||
Long: `Remove a profile from the NetBird client. The profile must not be inactive.`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: removeProfileFunc,
|
||||
Use: "remove <profile>",
|
||||
Short: "Remove a profile",
|
||||
Long: `Remove a profile by name, ID, or unique ID prefix.`,
|
||||
Aliases: []string{"rm"},
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: removeProfileFunc,
|
||||
}
|
||||
|
||||
var profileSelectCmd = &cobra.Command{
|
||||
Use: "select <profile_name>",
|
||||
Use: "select <profile>",
|
||||
Short: "Select a profile",
|
||||
Long: `Make the specified profile active. This will switch the client to use the selected profile's configuration.`,
|
||||
Long: `Make the specified profile active. Accepts a name, ID, or unique ID prefix.`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: selectProfileFunc,
|
||||
}
|
||||
|
||||
func init() {
|
||||
profileListCmd.Flags().BoolVar(&profileListShowID, "show-id", false, "show the profile ID column")
|
||||
}
|
||||
|
||||
func setupCmd(cmd *cobra.Command) error {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
SetFlagsFromEnvVars(cmd)
|
||||
@@ -65,6 +85,7 @@ func setupCmd(cmd *cobra.Command) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func listProfilesFunc(cmd *cobra.Command, _ []string) error {
|
||||
if err := setupCmd(cmd); err != nil {
|
||||
return err
|
||||
@@ -83,25 +104,33 @@ func listProfilesFunc(cmd *cobra.Command, _ []string) error {
|
||||
|
||||
daemonClient := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
profiles, err := daemonClient.ListProfiles(cmd.Context(), &proto.ListProfilesRequest{
|
||||
resp, err := daemonClient.ListProfiles(cmd.Context(), &proto.ListProfilesRequest{
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// list profiles, add a tick if the profile is active
|
||||
cmd.Println("Found", len(profiles.Profiles), "profiles:")
|
||||
for _, profile := range profiles.Profiles {
|
||||
// use a cross to indicate the passive profiles
|
||||
activeMarker := "✗"
|
||||
if profile.IsActive {
|
||||
activeMarker = "✓"
|
||||
}
|
||||
cmd.Println(activeMarker, profile.Name)
|
||||
tw := tabwriter.NewWriter(cmd.OutOrStdout(), 0, 0, 2, ' ', 0)
|
||||
if profileListShowID {
|
||||
fmt.Fprintln(tw, "ID\tNAME\tACTIVE")
|
||||
} else {
|
||||
fmt.Fprintln(tw, "NAME\tACTIVE")
|
||||
}
|
||||
|
||||
return nil
|
||||
for _, profile := range resp.Profiles {
|
||||
marker := ""
|
||||
if profile.IsActive {
|
||||
marker = "✓"
|
||||
}
|
||||
name := profilemanager.StripCtrlChars(profile.Name)
|
||||
id := profilemanager.ID(profile.Id)
|
||||
if profileListShowID {
|
||||
fmt.Fprintf(tw, "%s\t%s\t%s\n", id.ShortID(), name, marker)
|
||||
} else {
|
||||
fmt.Fprintf(tw, "%s\t%s\n", name, marker)
|
||||
}
|
||||
}
|
||||
return tw.Flush()
|
||||
}
|
||||
|
||||
func addProfileFunc(cmd *cobra.Command, args []string) error {
|
||||
@@ -121,21 +150,82 @@ func addProfileFunc(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
daemonClient := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
profileName := args[0]
|
||||
|
||||
_, err = daemonClient.AddProfile(cmd.Context(), &proto.AddProfileRequest{
|
||||
resp, err := daemonClient.AddProfile(cmd.Context(), &proto.AddProfileRequest{
|
||||
ProfileName: profileName,
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("add profile request: %w", err)
|
||||
}
|
||||
|
||||
dupCount, _ := countProfilesWithName(cmd.Context(), daemonClient, currUser.Username, profileName)
|
||||
if dupCount > 1 {
|
||||
cmd.Printf("Warning: %d other profile(s) already use the name %q.\n", dupCount-1, profileName)
|
||||
cmd.Println("Use `netbird profile list --show-id` to disambiguate later.")
|
||||
}
|
||||
|
||||
id := profilemanager.ID(resp.Id)
|
||||
cmd.Printf("Profile added: %s %s\n", id.ShortID(), profilemanager.StripCtrlChars(profileName))
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
func renameProfileFunc(cmd *cobra.Command, args []string) error {
|
||||
if err := setupCmd(cmd); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Println("Profile added successfully:", profileName)
|
||||
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to service CLI interface: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
currUser, err := user.Current()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get current user: %w", err)
|
||||
}
|
||||
|
||||
daemonClient := proto.NewDaemonServiceClient(conn)
|
||||
handle := args[0]
|
||||
newProfilename := args[1]
|
||||
|
||||
resp, err := daemonClient.RenameProfile(cmd.Context(), &proto.RenameProfileRequest{
|
||||
Handle: handle,
|
||||
Username: currUser.Username,
|
||||
NewProfileName: newProfilename,
|
||||
})
|
||||
if err != nil {
|
||||
return wrapAmbiguityError(err, handle)
|
||||
}
|
||||
|
||||
dupCount, _ := countProfilesWithName(cmd.Context(), daemonClient, currUser.Username, newProfilename)
|
||||
if dupCount > 1 {
|
||||
cmd.Printf("Warning: %d other profile(s) already use the name %q.\n", dupCount-1, newProfilename)
|
||||
cmd.Println("Use `netbird profile list --show-id` to disambiguate later.")
|
||||
}
|
||||
|
||||
cmd.Printf("Profile renamed from %s to %s\n", profilemanager.StripCtrlChars(resp.OldProfileName), profilemanager.StripCtrlChars(newProfilename))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func countProfilesWithName(ctx context.Context, c proto.DaemonServiceClient, username, name string) (int, error) {
|
||||
resp, err := c.ListProfiles(ctx, &proto.ListProfilesRequest{Username: username})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
n := 0
|
||||
for _, p := range resp.Profiles {
|
||||
if p.Name == name {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func removeProfileFunc(cmd *cobra.Command, args []string) error {
|
||||
if err := setupCmd(cmd); err != nil {
|
||||
return err
|
||||
@@ -153,18 +243,17 @@ func removeProfileFunc(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
daemonClient := proto.NewDaemonServiceClient(conn)
|
||||
handle := args[0]
|
||||
|
||||
profileName := args[0]
|
||||
|
||||
_, err = daemonClient.RemoveProfile(cmd.Context(), &proto.RemoveProfileRequest{
|
||||
ProfileName: profileName,
|
||||
resp, err := daemonClient.RemoveProfile(cmd.Context(), &proto.RemoveProfileRequest{
|
||||
ProfileName: handle,
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
return wrapAmbiguityError(err, handle)
|
||||
}
|
||||
|
||||
cmd.Println("Profile removed successfully:", profileName)
|
||||
cmd.Printf("Profile removed: %s\n", resp.Id)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -174,7 +263,7 @@ func selectProfileFunc(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
profileManager := profilemanager.NewProfileManager()
|
||||
profileName := args[0]
|
||||
handle := args[0]
|
||||
|
||||
currUser, err := user.Current()
|
||||
if err != nil {
|
||||
@@ -191,32 +280,15 @@ func selectProfileFunc(cmd *cobra.Command, args []string) error {
|
||||
|
||||
daemonClient := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
profiles, err := daemonClient.ListProfiles(ctx, &proto.ListProfilesRequest{
|
||||
Username: currUser.Username,
|
||||
switchResp, err := daemonClient.SwitchProfile(ctx, &proto.SwitchProfileRequest{
|
||||
ProfileName: &handle,
|
||||
Username: &currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("list profiles: %w", err)
|
||||
return wrapAmbiguityError(err, handle)
|
||||
}
|
||||
|
||||
var profileExists bool
|
||||
|
||||
for _, profile := range profiles.Profiles {
|
||||
if profile.Name == profileName {
|
||||
profileExists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !profileExists {
|
||||
return fmt.Errorf("profile %s does not exist", profileName)
|
||||
}
|
||||
|
||||
if err := switchProfile(cmd.Context(), profileName, currUser.Username); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = profileManager.SwitchProfile(profileName)
|
||||
if err != nil {
|
||||
if err := profileManager.SwitchProfile(profilemanager.ID(switchResp.Id)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -231,6 +303,30 @@ func selectProfileFunc(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
}
|
||||
|
||||
cmd.Println("Profile switched successfully to:", profileName)
|
||||
id := profilemanager.ID(switchResp.Id)
|
||||
cmd.Printf("Profile switched to: %s\n", id.ShortID())
|
||||
return nil
|
||||
}
|
||||
|
||||
// wrapAmbiguityError turns the daemon's gRPC InvalidArgument errors
|
||||
// (which carry the resolver's message verbatim) into CLI-friendly text
|
||||
// that points the user at --show-id.
|
||||
func wrapAmbiguityError(err error, handle string) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
st, ok := gstatus.FromError(err)
|
||||
if !ok {
|
||||
return err
|
||||
}
|
||||
switch st.Code() {
|
||||
case codes.InvalidArgument:
|
||||
msg := st.Message()
|
||||
if strings.Contains(msg, "ambiguous") {
|
||||
return errors.New(msg + "\nRun `netbird profile list --show-id` to see IDs, then select by ID prefix:\n netbird profile select|remove <id-prefix>")
|
||||
}
|
||||
case codes.NotFound:
|
||||
return fmt.Errorf("profile %q not found", handle)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -95,9 +95,7 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
// Execute runs the appropriate Cobra command for the CLI.
|
||||
// If the process is the update binary it delegates to updateCmd; otherwise it runs the root command.
|
||||
// It returns any error produced during command execution.
|
||||
// Execute executes the root command.
|
||||
func Execute() error {
|
||||
if isUpdateBinary() {
|
||||
return updateCmd.Execute()
|
||||
@@ -105,16 +103,6 @@ func Execute() error {
|
||||
return rootCmd.Execute()
|
||||
}
|
||||
|
||||
// init initialises package-level defaults and configures the root
|
||||
// Cobra command tree. Sets platform-specific config / log directory
|
||||
// paths (including legacy Wiretrustee fallbacks) and a default daemon
|
||||
// address; registers persistent CLI flags (daemon address,
|
||||
// management / admin URLs, logging, setup key (file and inline,
|
||||
// mutually exclusive), preshared key, hostname, anonymise, config
|
||||
// path); attaches top-level and nested subcommands to the root
|
||||
// command; and registers `up`-specific persistent flags (external IP
|
||||
// maps, custom DNS resolver address, Rosenpass options, auto-connect
|
||||
// disabling, lazy connection).
|
||||
func init() {
|
||||
defaultConfigPathDir = "/etc/netbird/"
|
||||
defaultLogFileDir = "/var/log/netbird/"
|
||||
@@ -180,16 +168,11 @@ func init() {
|
||||
logCmd.AddCommand(logLevelCmd)
|
||||
debugCmd.AddCommand(forCmd)
|
||||
debugCmd.AddCommand(persistenceCmd)
|
||||
debugCmd.AddCommand(debugConfigCmd)
|
||||
|
||||
// kubernetes commands
|
||||
rootCmd.AddCommand(kubernetesCmd)
|
||||
kubernetesCmd.AddCommand(kubernetesListCmd)
|
||||
kubernetesCmd.AddCommand(kubernetesWriteKubeconfigCmd)
|
||||
|
||||
// profile commands
|
||||
profileCmd.AddCommand(profileListCmd)
|
||||
profileCmd.AddCommand(profileAddCmd)
|
||||
profileCmd.AddCommand(profileRenameCmd)
|
||||
profileCmd.AddCommand(profileRemoveCmd)
|
||||
profileCmd.AddCommand(profileSelectCmd)
|
||||
|
||||
|
||||
@@ -128,13 +128,12 @@ func upFunc(cmd *cobra.Command, args []string) error {
|
||||
var profileSwitched bool
|
||||
// switch profile if provided
|
||||
if profileName != "" {
|
||||
err = switchProfile(cmd.Context(), profileName, username.Username)
|
||||
resolvedID, err := switchProfile(cmd.Context(), profileName, username.Username)
|
||||
if err != nil {
|
||||
return fmt.Errorf("switch profile: %v", err)
|
||||
}
|
||||
|
||||
err = pm.SwitchProfile(profileName)
|
||||
if err != nil {
|
||||
if err := pm.SwitchProfile(resolvedID); err != nil {
|
||||
return fmt.Errorf("switch profile: %v", err)
|
||||
}
|
||||
|
||||
@@ -190,7 +189,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
|
||||
|
||||
_, _ = profilemanager.UpdateOldManagementURL(ctx, config, configFilePath)
|
||||
|
||||
err = foregroundLogin(ctx, cmd, config, providedSetupKey, activeProf.Name)
|
||||
err = foregroundLogin(ctx, cmd, config, providedSetupKey, activeProf.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("foreground login failed: %v", err)
|
||||
}
|
||||
@@ -261,10 +260,10 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager
|
||||
}
|
||||
|
||||
// set the new config
|
||||
req := setupSetConfigReq(customDNSAddressConverted, cmd, activeProf.Name, username.Username)
|
||||
req := setupSetConfigReq(customDNSAddressConverted, cmd, activeProf.ID.String(), username.Username)
|
||||
if _, err := client.SetConfig(ctx, req); err != nil {
|
||||
if st, ok := gstatus.FromError(err); ok && st.Code() == codes.Unavailable {
|
||||
log.Warnf("setConfig method is not available in the daemon")
|
||||
log.Warnf("setConfig method is not available in the daemon: %s", st.Message())
|
||||
} else {
|
||||
return fmt.Errorf("call service setConfig method: %v", err)
|
||||
}
|
||||
@@ -289,10 +288,11 @@ func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServ
|
||||
return fmt.Errorf("setup login request: %v", err)
|
||||
}
|
||||
|
||||
loginRequest.ProfileName = &activeProf.Name
|
||||
profileID := activeProf.ID.String()
|
||||
loginRequest.ProfileName = &profileID
|
||||
loginRequest.Username = &username
|
||||
|
||||
profileState, err := pm.GetProfileState(activeProf.Name)
|
||||
profileState, err := pm.GetProfileState(activeProf.ID)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get profile state for login hint: %v", err)
|
||||
} else if profileState.Email != "" {
|
||||
@@ -329,7 +329,7 @@ func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServ
|
||||
}
|
||||
|
||||
if _, err := client.Up(ctx, &proto.UpRequest{
|
||||
ProfileName: &activeProf.Name,
|
||||
ProfileName: &profileID,
|
||||
Username: &username,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("call service up method: %v", err)
|
||||
|
||||
@@ -29,14 +29,14 @@ func TestUpDaemon(t *testing.T) {
|
||||
}
|
||||
|
||||
sm := profilemanager.ServiceManager{}
|
||||
err = sm.AddProfile("test1", currUser.Username)
|
||||
created, err := sm.AddProfile("test1", currUser.Username)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to add profile: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = sm.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
Name: "test1",
|
||||
ID: created.ID,
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@@ -279,10 +279,6 @@ func (c *Client) Start(startCtx context.Context) error {
|
||||
|
||||
select {
|
||||
case <-startCtx.Done():
|
||||
// Cancel the client context before stopping: Engine.Start blocks on the
|
||||
// signal stream while holding the engine mutex and only unblocks on
|
||||
// cancellation. Stopping first would deadlock on that mutex.
|
||||
cancel()
|
||||
if stopErr := client.Stop(); stopErr != nil {
|
||||
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
|
||||
}
|
||||
@@ -446,8 +442,8 @@ func (c *Client) Expose(ctx context.Context, req ExposeRequest) (*ExposeSession,
|
||||
|
||||
// IdentityForIP looks up a remote peer by its tunnel IP using the
|
||||
// embedded client's status recorder. Returns the peer's WireGuard public
|
||||
// key and FQDN. ok=false means the IP doesn't belong to an active peer
|
||||
// — offline roster peers are treated as unknown, same as foreign IPs.
|
||||
// key and FQDN. ok=false means the IP isn't in this client's peer
|
||||
// roster — callers should treat that as "unknown peer".
|
||||
func (c *Client) IdentityForIP(ip netip.Addr) (pubKey, fqdn string, ok bool) {
|
||||
if !ip.IsValid() || c.recorder == nil {
|
||||
return "", "", false
|
||||
|
||||
@@ -1,168 +0,0 @@
|
||||
package embed
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
mgmt "github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
const testSetupKey = "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
|
||||
|
||||
// TestClientStartTimeoutRollback reproduces a deadlock between Engine.Start and
|
||||
// Engine.Stop. The signal endpoint accepts gRPC connections but never serves the
|
||||
// SignalExchange service, so Engine.Start parks in WaitStreamConnected while
|
||||
// holding the engine mutex. When the Start context expires, the rollback path
|
||||
// calls ConnectClient.Stop, which must not block forever acquiring that mutex.
|
||||
func TestClientStartTimeoutRollback(t *testing.T) {
|
||||
signalAddr := startBlackholeSignal(t)
|
||||
mgmAddr := startManagement(t, signalAddr)
|
||||
|
||||
wgPort := 0
|
||||
client, err := New(Options{
|
||||
DeviceName: "embed-rollback-test",
|
||||
SetupKey: testSetupKey,
|
||||
ManagementURL: "http://" + mgmAddr,
|
||||
WireguardPort: &wgPort,
|
||||
})
|
||||
require.NoError(t, err, "embed client creation must succeed")
|
||||
|
||||
startCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
startErr := make(chan error, 1)
|
||||
go func() {
|
||||
startErr <- client.Start(startCtx)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-startErr:
|
||||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
case <-time.After(60 * time.Second):
|
||||
t.Fatal("client.Start did not return after its context expired: Engine.Stop deadlocked against Engine.Start waiting for the signal stream")
|
||||
}
|
||||
}
|
||||
|
||||
// startBlackholeSignal starts a gRPC server without the SignalExchange service
|
||||
// registered. Connections succeed, but the signal stream can never be
|
||||
// established, which keeps Engine.Start parked in WaitStreamConnected.
|
||||
func startBlackholeSignal(t *testing.T) string {
|
||||
t.Helper()
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
s := grpc.NewServer()
|
||||
go func() {
|
||||
if err := s.Serve(lis); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
t.Cleanup(s.Stop)
|
||||
|
||||
return lis.Addr().String()
|
||||
}
|
||||
|
||||
func startManagement(t *testing.T, signalAddr string) string {
|
||||
t.Helper()
|
||||
|
||||
cfg := &config.Config{
|
||||
Stuns: []*config.Host{},
|
||||
TURNConfig: &config.TURNConfig{},
|
||||
Relay: &config.Relay{
|
||||
Addresses: []string{"127.0.0.1:1234"},
|
||||
CredentialsTTL: util.Duration{Duration: time.Hour},
|
||||
Secret: "222222222222222222",
|
||||
},
|
||||
Signal: &config.Host{
|
||||
Proto: "http",
|
||||
URI: signalAddr,
|
||||
},
|
||||
Datadir: t.TempDir(),
|
||||
HttpConfig: nil,
|
||||
}
|
||||
|
||||
lis, err := net.Listen("tcp", "localhost:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
s := grpc.NewServer()
|
||||
|
||||
testStore, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", cfg.Datadir)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
|
||||
permissionsManager := permissions.NewManager(testStore)
|
||||
peersManager := peers.NewManager(testStore, permissionsManager)
|
||||
jobManager := job.NewJobManager(nil, testStore, peersManager)
|
||||
|
||||
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
iv, err := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
|
||||
require.NoError(t, err)
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
settingsMockManager.EXPECT().
|
||||
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(&types.Settings{}, nil).
|
||||
AnyTimes()
|
||||
settingsMockManager.EXPECT().
|
||||
GetExtraSettings(gomock.Any(), gomock.Any()).
|
||||
Return(&types.ExtraSettings{}, nil).
|
||||
AnyTimes()
|
||||
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := mgmt.NewAccountRequestBuffer(context.Background(), testStore)
|
||||
networkMapController := controller.NewController(context.Background(), testStore, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(testStore, peersManager), cfg)
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), cfg, testStore, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
|
||||
require.NoError(t, err)
|
||||
|
||||
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, cfg.TURNConfig, cfg.Relay, settingsMockManager, groupsManager)
|
||||
require.NoError(t, err)
|
||||
|
||||
mgmtServer, err := nbgrpc.NewServer(cfg, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil, nil)
|
||||
require.NoError(t, err)
|
||||
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
|
||||
|
||||
go func() {
|
||||
if err := s.Serve(lis); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
t.Cleanup(s.Stop)
|
||||
|
||||
return lis.Addr().String()
|
||||
}
|
||||
@@ -3,7 +3,6 @@ package iptables
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"maps"
|
||||
"net"
|
||||
"slices"
|
||||
|
||||
@@ -422,17 +421,12 @@ func (m *aclManager) updateState() {
|
||||
currentState.Lock()
|
||||
defer currentState.Unlock()
|
||||
|
||||
// Clone the maps so the persisted state holds a private snapshot. The
|
||||
// live maps keep being mutated by subsequent rule operations while the
|
||||
// state manager marshals the state from its periodic-save goroutine.
|
||||
// Sharing them by reference races the two and aborts the process with a
|
||||
// concurrent map iteration and write.
|
||||
if m.v6 {
|
||||
currentState.ACLEntries6 = maps.Clone(m.entries)
|
||||
currentState.ACLIPsetStore6 = m.ipsetStore.clone()
|
||||
currentState.ACLEntries6 = m.entries
|
||||
currentState.ACLIPsetStore6 = m.ipsetStore
|
||||
} else {
|
||||
currentState.ACLEntries = maps.Clone(m.entries)
|
||||
currentState.ACLIPsetStore = m.ipsetStore.clone()
|
||||
currentState.ACLEntries = m.entries
|
||||
currentState.ACLIPsetStore = m.ipsetStore
|
||||
}
|
||||
|
||||
if err := m.stateManager.UpdateState(currentState); err != nil {
|
||||
|
||||
@@ -4,7 +4,6 @@ package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"maps"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -750,17 +749,11 @@ func (r *router) updateState() {
|
||||
currentState.Lock()
|
||||
defer currentState.Unlock()
|
||||
|
||||
// Clone the rule map so the persisted state holds a private snapshot. The
|
||||
// live map keeps being mutated by subsequent rule operations while the
|
||||
// state manager marshals the state from its periodic-save goroutine.
|
||||
// Sharing it by reference races the two and aborts the process with a
|
||||
// concurrent map iteration and write. The ipset counter guards itself
|
||||
// during marshaling, so it can be shared directly.
|
||||
if r.v6 {
|
||||
currentState.RouteRules6 = maps.Clone(r.rules)
|
||||
currentState.RouteRules6 = r.rules
|
||||
currentState.RouteIPsetCounter6 = r.ipsetCounter
|
||||
} else {
|
||||
currentState.RouteRules = maps.Clone(r.rules)
|
||||
currentState.RouteRules = r.rules
|
||||
currentState.RouteIPsetCounter = r.ipsetCounter
|
||||
}
|
||||
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"maps"
|
||||
)
|
||||
import "encoding/json"
|
||||
|
||||
type ipList struct {
|
||||
ips map[string]struct{}
|
||||
@@ -22,14 +19,6 @@ func (s *ipList) addIP(ip string) {
|
||||
s.ips[ip] = struct{}{}
|
||||
}
|
||||
|
||||
// clone returns a deep copy of the ipList with its own ips map.
|
||||
func (s *ipList) clone() *ipList {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
return &ipList{ips: maps.Clone(s.ips)}
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler
|
||||
func (s *ipList) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
@@ -66,19 +55,6 @@ func newIpsetStore() *ipsetStore {
|
||||
}
|
||||
}
|
||||
|
||||
// clone returns a deep copy of the ipsetStore with its own ipsets map and
|
||||
// independent ipList entries.
|
||||
func (s *ipsetStore) clone() *ipsetStore {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := &ipsetStore{ipsets: make(map[string]*ipList, len(s.ipsets))}
|
||||
for name, list := range s.ipsets {
|
||||
cloned.ipsets[name] = list.clone()
|
||||
}
|
||||
return cloned
|
||||
}
|
||||
|
||||
func (s *ipsetStore) ipset(ipsetName string) (*ipList, bool) {
|
||||
r, ok := s.ipsets[ipsetName]
|
||||
return r, ok
|
||||
|
||||
@@ -1,13 +1,10 @@
|
||||
package device
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"runtime/debug"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
@@ -44,13 +41,10 @@ type PacketCapture interface {
|
||||
type FilteredDevice struct {
|
||||
tun.Device
|
||||
|
||||
filter PacketFilter
|
||||
capture atomic.Pointer[PacketCapture]
|
||||
// panicHandler is invoked after a panic in the underlying device is
|
||||
// recovered in Read or Write.
|
||||
panicHandler atomic.Pointer[func()]
|
||||
mutex sync.RWMutex
|
||||
closeOnce sync.Once
|
||||
filter PacketFilter
|
||||
capture atomic.Pointer[PacketCapture]
|
||||
mutex sync.RWMutex
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
// newDeviceFilter constructor function
|
||||
@@ -76,7 +70,7 @@ func (d *FilteredDevice) Close() error {
|
||||
|
||||
// Read wraps read method with filtering feature
|
||||
func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||
if n, err = d.deviceRead(bufs, sizes, offset); err != nil {
|
||||
if n, err = d.Device.Read(bufs, sizes, offset); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
@@ -118,7 +112,7 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
|
||||
d.mutex.RUnlock()
|
||||
|
||||
if filter == nil {
|
||||
return d.deviceWrite(bufs, offset)
|
||||
return d.Device.Write(bufs, offset)
|
||||
}
|
||||
|
||||
filteredBufs := make([][]byte, 0, len(bufs))
|
||||
@@ -131,44 +125,9 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
|
||||
}
|
||||
}
|
||||
|
||||
n, err := d.deviceWrite(filteredBufs, offset)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
return n + dropped, nil
|
||||
}
|
||||
|
||||
// deviceRead calls the underlying device Read, recovering from panics in the
|
||||
// wintun read path and converting them into errors.
|
||||
func (d *FilteredDevice) deviceRead(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||
defer d.recoverFromPanic("read", &n, &err)
|
||||
return d.Device.Read(bufs, sizes, offset)
|
||||
}
|
||||
|
||||
// deviceWrite calls the underlying device Write, recovering from panics in the
|
||||
// wintun write path and converting them into errors.
|
||||
func (d *FilteredDevice) deviceWrite(bufs [][]byte, offset int) (n int, err error) {
|
||||
defer d.recoverFromPanic("write", &n, &err)
|
||||
return d.Device.Write(bufs, offset)
|
||||
}
|
||||
|
||||
// recoverFromPanic converts a panic in the underlying device into a regular
|
||||
// error and invokes the registered panic handler. The wintun read path is
|
||||
// known to panic on zero-length packets that third-party filter drivers can
|
||||
// place in the ring.
|
||||
func (d *FilteredDevice) recoverFromPanic(op string, n *int, err *error) {
|
||||
r := recover()
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
|
||||
log.Errorf("recovered panic in tun device %s: %v\n%s", op, r, debug.Stack())
|
||||
*n = 0
|
||||
*err = fmt.Errorf("tun device %s panic: %v", op, r)
|
||||
|
||||
if handler := d.panicHandler.Load(); handler != nil {
|
||||
(*handler)()
|
||||
}
|
||||
n, err := d.Device.Write(filteredBufs, offset)
|
||||
n += dropped
|
||||
return n, err
|
||||
}
|
||||
|
||||
// SetFilter sets packet filter to device
|
||||
@@ -178,17 +137,6 @@ func (d *FilteredDevice) SetFilter(filter PacketFilter) {
|
||||
d.mutex.Unlock()
|
||||
}
|
||||
|
||||
// SetPanicHandler registers a handler invoked after a recovered panic in Read
|
||||
// or Write. The device is unusable after such a panic; the handler should
|
||||
// trigger recreation of the interface. Pass nil to remove.
|
||||
func (d *FilteredDevice) SetPanicHandler(handler func()) {
|
||||
if handler == nil {
|
||||
d.panicHandler.Store(nil)
|
||||
return
|
||||
}
|
||||
d.panicHandler.Store(&handler)
|
||||
}
|
||||
|
||||
// SetCapture sets or clears the packet capture sink. Pass nil to disable.
|
||||
// Uses atomic store so the hot path (Read/Write) is a single pointer load
|
||||
// with no locking overhead when capture is off.
|
||||
|
||||
@@ -221,60 +221,3 @@ func TestDeviceWrapperRead(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDeviceWrapperReadPanic(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
tun := mocks.NewMockDevice(ctrl)
|
||||
tun.EXPECT().Read(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(bufs [][]byte, sizes []int, offset int) (int, error) {
|
||||
// Reproduce the wintun zero-length packet panic (index out of range).
|
||||
packet := make([]byte, 0)
|
||||
return int(packet[0]), nil
|
||||
})
|
||||
|
||||
wrapped := newDeviceFilter(tun)
|
||||
|
||||
handlerCalled := false
|
||||
wrapped.SetPanicHandler(func() { handlerCalled = true })
|
||||
|
||||
n, err := wrapped.Read([][]byte{{}}, []int{0}, 0)
|
||||
if err == nil {
|
||||
t.Errorf("expected error from recovered panic, got nil")
|
||||
}
|
||||
if n != 0 {
|
||||
t.Errorf("expected n=0, got %d", n)
|
||||
}
|
||||
if !handlerCalled {
|
||||
t.Errorf("expected panic handler to be called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceWrapperWritePanic(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
tun := mocks.NewMockDevice(ctrl)
|
||||
tun.EXPECT().Write(gomock.Any(), gomock.Any()).
|
||||
DoAndReturn(func(bufs [][]byte, offset int) (int, error) {
|
||||
packet := make([]byte, 0)
|
||||
return int(packet[0]), nil
|
||||
})
|
||||
|
||||
wrapped := newDeviceFilter(tun)
|
||||
|
||||
handlerCalled := false
|
||||
wrapped.SetPanicHandler(func() { handlerCalled = true })
|
||||
|
||||
n, err := wrapped.Write([][]byte{{0x45, 0x00}}, 0)
|
||||
if err == nil {
|
||||
t.Errorf("expected error from recovered panic, got nil")
|
||||
}
|
||||
if n != 0 {
|
||||
t.Errorf("expected n=0, got %d", n)
|
||||
}
|
||||
if !handlerCalled {
|
||||
t.Errorf("expected panic handler to be called")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -516,14 +516,6 @@ func (g *BundleGenerator) addConfig() error {
|
||||
}
|
||||
}
|
||||
|
||||
// Surface the set of MDM-enforced keys so a support engineer reading
|
||||
// the bundle can tell which field values are user-set vs MDM-overridden.
|
||||
// Same semantics as the mDMManagedFields list returned by the
|
||||
// GetConfig RPC consumed by `netbird debug config`.
|
||||
if managed := g.internalConfig.Policy().ManagedKeys(); len(managed) > 0 {
|
||||
configContent.WriteString(fmt.Sprintf("MDMManagedFields: %v\n", managed))
|
||||
}
|
||||
|
||||
configReader := strings.NewReader(configContent.String())
|
||||
if err := g.addFileToZip(configReader, "config.txt"); err != nil {
|
||||
return fmt.Errorf("add config file to zip: %w", err)
|
||||
|
||||
@@ -843,7 +843,7 @@ func TestAddConfig_AllFieldsCovered(t *testing.T) {
|
||||
"PreSharedKey": "sensitive: WireGuard pre-shared key",
|
||||
"SSHKey": "sensitive: SSH private key",
|
||||
"ClientCertKeyPair": "non-config: parsed cert pair, not serialized",
|
||||
"policy": "non-config: in-memory MDM policy snapshot, surfaced via Config.Policy() / GetConfigResponse.MDMManagedFields",
|
||||
"Name": "non-config: profile name is not needed for debug purposes",
|
||||
}
|
||||
|
||||
mURL, _ := url.Parse("https://api.example.com:443")
|
||||
|
||||
@@ -482,7 +482,7 @@ func (d *Resolver) logDNSError(logger *log.Entry, hostname string, qtype uint16,
|
||||
// completely when every proxy peer is offline (the upstream may still
|
||||
// be reachable some other way, or the peerstore may be stale).
|
||||
func (d *Resolver) filterDisconnectedPeerAnswers(logger *log.Entry, question dns.Question, records []dns.RR) []dns.RR {
|
||||
if len(records) < 2 {
|
||||
if len(records) == 0 {
|
||||
return records
|
||||
}
|
||||
d.mu.RLock()
|
||||
|
||||
@@ -2738,17 +2738,6 @@ func TestLocalResolver_FilterDisconnectedPeerAnswers(t *testing.T) {
|
||||
connByIP: nil,
|
||||
wantInOrder: []string{"100.64.0.10", "100.64.0.11"},
|
||||
},
|
||||
{
|
||||
// A single answer is never filtered: dropping it would only
|
||||
// trigger the empty-answer escape hatch, so the fast path
|
||||
// returns it untouched.
|
||||
name: "single disconnected answer passes through",
|
||||
records: []nbdns.SimpleRecord{disconnectedRec},
|
||||
connByIP: map[string]ipState{
|
||||
"100.64.0.11": {known: true, connected: false},
|
||||
},
|
||||
wantInOrder: []string{"100.64.0.11"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
|
||||
@@ -777,24 +777,13 @@ func (s *DefaultServer) applyHostConfig() {
|
||||
// context is released rather than leaked until GC.
|
||||
func (s *DefaultServer) registerFallback() {
|
||||
originalNameservers := s.hostManager.getOriginalNameservers()
|
||||
|
||||
serverIP := s.service.RuntimeIP()
|
||||
var servers []netip.AddrPort
|
||||
for _, ns := range originalNameservers {
|
||||
if ns == serverIP {
|
||||
log.Debugf("skipping original nameserver %s as it is the same as the server IP %s", ns, serverIP)
|
||||
continue
|
||||
}
|
||||
servers = append(servers, netip.AddrPortFrom(ns, DefaultPort))
|
||||
}
|
||||
|
||||
if len(servers) == 0 {
|
||||
if len(originalNameservers) == 0 {
|
||||
log.Debugf("no fallback upstreams to register; clearing PriorityFallback handler")
|
||||
s.clearFallback()
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("registering original nameservers %v as upstream handlers with priority %d", servers, PriorityFallback)
|
||||
log.Infof("registering original nameservers %v as upstream handlers with priority %d", originalNameservers, PriorityFallback)
|
||||
|
||||
handler, err := newUpstreamResolver(
|
||||
s.ctx,
|
||||
@@ -808,6 +797,11 @@ func (s *DefaultServer) registerFallback() {
|
||||
return
|
||||
}
|
||||
handler.selectedRoutes = s.selectedRoutes
|
||||
|
||||
var servers []netip.AddrPort
|
||||
for _, ns := range originalNameservers {
|
||||
servers = append(servers, netip.AddrPortFrom(ns, DefaultPort))
|
||||
}
|
||||
handler.addRace(servers)
|
||||
|
||||
prev := s.fallbackHandler
|
||||
|
||||
@@ -240,7 +240,7 @@ type Engine struct {
|
||||
syncStore syncstore.Store
|
||||
syncStoreDir string
|
||||
|
||||
flowManager nftypes.FlowManager
|
||||
flowManager nftypes.FlowManager
|
||||
|
||||
// auto-update
|
||||
updateManager *updater.Manager
|
||||
@@ -531,10 +531,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
return fmt.Errorf("create wg interface: %w", err)
|
||||
}
|
||||
|
||||
if filteredDevice := e.wgInterface.GetDevice(); filteredDevice != nil {
|
||||
filteredDevice.SetPanicHandler(e.triggerClientRestart)
|
||||
}
|
||||
|
||||
if err := e.createFirewall(); err != nil {
|
||||
e.close()
|
||||
return err
|
||||
@@ -884,25 +880,62 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate)
|
||||
}
|
||||
|
||||
if err := e.updateNetbirdConfig(update.GetNetbirdConfig()); err != nil {
|
||||
return err
|
||||
}
|
||||
if update.GetNetbirdConfig() != nil {
|
||||
wCfg := update.GetNetbirdConfig()
|
||||
err := e.updateTURNs(wCfg.GetTurns())
|
||||
if err != nil {
|
||||
return fmt.Errorf("update TURNs: %w", err)
|
||||
}
|
||||
|
||||
// Posture checks are bound to the network map presence:
|
||||
// NetworkMap != nil, checks present -> apply the received checks
|
||||
// NetworkMap != nil, checks nil -> posture checks were removed, clear them
|
||||
// NetworkMap == nil -> config-only update (e.g. relay token rotation),
|
||||
// leave the previously applied checks untouched
|
||||
nm := update.GetNetworkMap()
|
||||
if nm == nil {
|
||||
return nil
|
||||
err = e.updateSTUNs(wCfg.GetStuns())
|
||||
if err != nil {
|
||||
return fmt.Errorf("update STUNs: %w", err)
|
||||
}
|
||||
|
||||
var stunTurn []*stun.URI
|
||||
stunTurn = append(stunTurn, e.STUNs...)
|
||||
stunTurn = append(stunTurn, e.TURNs...)
|
||||
e.stunTurn.Store(stunTurn)
|
||||
|
||||
err = e.handleRelayUpdate(wCfg.GetRelay())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = e.handleFlowUpdate(wCfg.GetFlow())
|
||||
if err != nil {
|
||||
return fmt.Errorf("handle the flow configuration: %w", err)
|
||||
}
|
||||
|
||||
if err := e.PopulateNetbirdConfig(wCfg, nil); err != nil {
|
||||
log.Warnf("Failed to update DNS server config: %v", err)
|
||||
}
|
||||
|
||||
// todo update signal
|
||||
}
|
||||
|
||||
if err := e.updateChecksIfNew(update.Checks); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
e.persistSyncResponse(update)
|
||||
nm := update.GetNetworkMap()
|
||||
if nm == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Persist sync response under the dedicated lock (syncRespMux), not under syncMsgMux.
|
||||
// A non-nil syncStore is what marks persistence as enabled. Hold the lock for
|
||||
// the whole Set so the store cannot be cleared (disabled / engine close)
|
||||
// mid-call and have this write resurrect a file that was just removed.
|
||||
e.syncRespMux.RLock()
|
||||
if e.syncStore != nil {
|
||||
if err := e.syncStore.Set(update); err != nil {
|
||||
log.Errorf("failed to persist sync response: %v", err)
|
||||
} else {
|
||||
log.Debugf("sync response persisted with serial %d", nm.GetSerial())
|
||||
}
|
||||
}
|
||||
e.syncRespMux.RUnlock()
|
||||
|
||||
// only apply new changes and ignore old ones
|
||||
if err := e.updateNetworkMap(nm); err != nil {
|
||||
@@ -914,64 +947,6 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateNetbirdConfig applies the management-provided NetBird configuration:
|
||||
// STUN/TURN and relay servers, flow logging and DNS settings. A nil config is a no-op,
|
||||
// which is the case for sync updates carrying only a network map.
|
||||
func (e *Engine) updateNetbirdConfig(wCfg *mgmProto.NetbirdConfig) error {
|
||||
if wCfg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := e.updateTURNs(wCfg.GetTurns()); err != nil {
|
||||
return fmt.Errorf("update TURNs: %w", err)
|
||||
}
|
||||
|
||||
if err := e.updateSTUNs(wCfg.GetStuns()); err != nil {
|
||||
return fmt.Errorf("update STUNs: %w", err)
|
||||
}
|
||||
|
||||
var stunTurn []*stun.URI
|
||||
stunTurn = append(stunTurn, e.STUNs...)
|
||||
stunTurn = append(stunTurn, e.TURNs...)
|
||||
e.stunTurn.Store(stunTurn)
|
||||
|
||||
if err := e.handleRelayUpdate(wCfg.GetRelay()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := e.handleFlowUpdate(wCfg.GetFlow()); err != nil {
|
||||
return fmt.Errorf("handle the flow configuration: %w", err)
|
||||
}
|
||||
|
||||
if err := e.PopulateNetbirdConfig(wCfg, nil); err != nil {
|
||||
log.Warnf("Failed to update DNS server config: %v", err)
|
||||
}
|
||||
|
||||
// todo update signal
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// persistSyncResponse stores the full sync response so it can be restored on the next
|
||||
// startup. Persistence is enabled only when syncStore is set. The dedicated syncRespMux
|
||||
// (not syncMsgMux) is held for the whole Set so the store cannot be cleared (disabled /
|
||||
// engine close) mid-call and have this write resurrect a file that was just removed.
|
||||
func (e *Engine) persistSyncResponse(update *mgmProto.SyncResponse) {
|
||||
e.syncRespMux.RLock()
|
||||
defer e.syncRespMux.RUnlock()
|
||||
|
||||
if e.syncStore == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := e.syncStore.Set(update); err != nil {
|
||||
log.Errorf("failed to persist sync response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("sync response persisted with serial %d", update.GetNetworkMap().GetSerial())
|
||||
}
|
||||
|
||||
func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error {
|
||||
if update != nil {
|
||||
// when we receive token we expect valid address list too
|
||||
|
||||
@@ -26,6 +26,7 @@ type connStatusInputs struct {
|
||||
iceInProgress bool // a negotiation is currently in flight
|
||||
}
|
||||
|
||||
|
||||
// ConnStatus describe the status of a peer's connection
|
||||
type ConnStatus int32
|
||||
|
||||
|
||||
@@ -193,7 +193,6 @@ func (s *StatusChangeSubscription) Events() chan map[string]RouterState {
|
||||
type Status struct {
|
||||
mux sync.RWMutex
|
||||
peers map[string]State
|
||||
ipToKey map[string]string
|
||||
changeNotify map[string]map[string]*StatusChangeSubscription // map[peerID]map[subscriptionID]*StatusChangeSubscription
|
||||
signalState bool
|
||||
signalError error
|
||||
@@ -232,7 +231,6 @@ type Status struct {
|
||||
func NewRecorder(mgmAddress string) *Status {
|
||||
return &Status{
|
||||
peers: make(map[string]State),
|
||||
ipToKey: make(map[string]string),
|
||||
changeNotify: make(map[string]map[string]*StatusChangeSubscription),
|
||||
eventStreams: make(map[string]chan *proto.SystemEvent),
|
||||
eventQueue: NewEventQueue(eventQueueSize),
|
||||
@@ -284,12 +282,6 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string, ip string, ipv6 string)
|
||||
Mux: new(sync.RWMutex),
|
||||
}
|
||||
d.peerListChangedForNotification = true
|
||||
if ipv6 != "" {
|
||||
d.ipToKey[ipv6] = peerPubKey
|
||||
}
|
||||
if ip != "" {
|
||||
d.ipToKey[ip] = peerPubKey
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -319,22 +311,28 @@ func (d *Status) PeerByIP(ip string) (string, bool) {
|
||||
|
||||
// PeerStateByIP returns the full peer State for the given tunnel IP.
|
||||
// Matches against either the IPv4 (State.IP) or IPv6 (State.IPv6) tunnel
|
||||
// address so dual-stack peers are reachable on either family. Only
|
||||
// active peers are matched; peers moved into the offline slice by
|
||||
// ReplaceOfflinePeers are intentionally treated as unknown.
|
||||
// address so dual-stack peers are reachable on either family. Searches
|
||||
// both d.peers and d.offlinePeers — peers that have been moved into
|
||||
// the offline slice by ReplaceOfflinePeers are still part of the
|
||||
// account's roster and callers (DNS filter, embed.Client.IdentityForIP)
|
||||
// need to recognise them rather than treating them as unknown. Returns
|
||||
// the zero State and false when no peer matches or the input is empty.
|
||||
func (d *Status) PeerStateByIP(ip string) (State, bool) {
|
||||
if ip == "" {
|
||||
return State{}, false
|
||||
}
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
key, ok := d.ipToKey[ip]
|
||||
if !ok {
|
||||
return State{}, false
|
||||
|
||||
for _, state := range d.peers {
|
||||
if (state.IP != "" && state.IP == ip) || (state.IPv6 != "" && state.IPv6 == ip) {
|
||||
return state, true
|
||||
}
|
||||
}
|
||||
state, ok := d.peers[key]
|
||||
if ok {
|
||||
return state, true
|
||||
for _, state := range d.offlinePeers {
|
||||
if (state.IP != "" && state.IP == ip) || (state.IPv6 != "" && state.IPv6 == ip) {
|
||||
return state, true
|
||||
}
|
||||
}
|
||||
return State{}, false
|
||||
}
|
||||
@@ -344,18 +342,12 @@ func (d *Status) RemovePeer(peerPubKey string) error {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
|
||||
p, ok := d.peers[peerPubKey]
|
||||
_, ok := d.peers[peerPubKey]
|
||||
if !ok {
|
||||
return errors.New("no peer with to remove")
|
||||
}
|
||||
|
||||
delete(d.peers, peerPubKey)
|
||||
if mappedKey, exists := d.ipToKey[p.IP]; exists && mappedKey == peerPubKey {
|
||||
delete(d.ipToKey, p.IP)
|
||||
}
|
||||
if mappedKey, exists := d.ipToKey[p.IPv6]; exists && mappedKey == peerPubKey {
|
||||
delete(d.ipToKey, p.IPv6)
|
||||
}
|
||||
d.peerListChangedForNotification = true
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -90,11 +90,12 @@ func TestStatus_PeerStateByIP_MatchesIPv6(t *testing.T) {
|
||||
req.Equal("pk-1", state.PubKey, "matching state must carry the right pub key")
|
||||
}
|
||||
|
||||
// TestStatus_PeerStateByIP_IgnoresOfflinePeers documents that peers
|
||||
// moved into the offline slice via ReplaceOfflinePeers are intentionally
|
||||
// not resolvable by IP: only active peers can carry traffic, so callers
|
||||
// (DNS filter, embed.Client.IdentityForIP) treat them as unknown.
|
||||
func TestStatus_PeerStateByIP_IgnoresOfflinePeers(t *testing.T) {
|
||||
// TestStatus_PeerStateByIP_MatchesOfflinePeers covers peers that have
|
||||
// been moved into the offline slice via ReplaceOfflinePeers. Callers
|
||||
// (DNS filter, embed.Client.IdentityForIP) need to treat them as known
|
||||
// rather than unknown — otherwise authentication / DNS filtering treats
|
||||
// known-but-offline peers as foreign IPs.
|
||||
func TestStatus_PeerStateByIP_MatchesOfflinePeers(t *testing.T) {
|
||||
status := NewRecorder("https://mgm")
|
||||
req := require.New(t)
|
||||
|
||||
@@ -102,31 +103,13 @@ func TestStatus_PeerStateByIP_IgnoresOfflinePeers(t *testing.T) {
|
||||
{PubKey: "pk-offline", FQDN: "offline.netbird", IP: "100.64.0.20", IPv6: "fd00::20"},
|
||||
})
|
||||
|
||||
_, ok := status.PeerStateByIP("100.64.0.20")
|
||||
req.False(ok, "offline peer must not resolve by IPv4 tunnel address")
|
||||
state, ok := status.PeerStateByIP("100.64.0.20")
|
||||
req.True(ok, "offline peer must resolve by IPv4 tunnel address")
|
||||
req.Equal("pk-offline", state.PubKey, "matching state must carry the offline peer's pub key")
|
||||
|
||||
_, ok = status.PeerStateByIP("fd00::20")
|
||||
req.False(ok, "offline peer must not resolve by IPv6 tunnel address")
|
||||
}
|
||||
|
||||
// TestStatus_PeerStateByIP_RemovedPeer verifies RemovePeer drops the
|
||||
// IP index entries for both address families.
|
||||
func TestStatus_PeerStateByIP_RemovedPeer(t *testing.T) {
|
||||
status := NewRecorder("https://mgm")
|
||||
req := require.New(t)
|
||||
|
||||
req.NoError(status.AddPeer("pk-1", "peer-1.netbird", "100.64.0.10", "fd00::1"))
|
||||
|
||||
_, ok := status.PeerStateByIP("100.64.0.10")
|
||||
req.True(ok, "active peer must resolve before removal")
|
||||
|
||||
req.NoError(status.RemovePeer("pk-1"))
|
||||
|
||||
_, ok = status.PeerStateByIP("100.64.0.10")
|
||||
req.False(ok, "removed peer must not resolve by IPv4 tunnel address")
|
||||
|
||||
_, ok = status.PeerStateByIP("fd00::1")
|
||||
req.False(ok, "removed peer must not resolve by IPv6 tunnel address")
|
||||
state, ok = status.PeerStateByIP("fd00::20")
|
||||
req.True(ok, "offline peer must resolve by IPv6 tunnel address")
|
||||
req.Equal("pk-offline", state.PubKey, "IPv6 match must carry the offline peer's pub key")
|
||||
}
|
||||
|
||||
func TestStatus_UpdatePeerFQDN(t *testing.T) {
|
||||
|
||||
@@ -22,7 +22,6 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||
"github.com/netbirdio/netbird/client/mdm"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
@@ -58,10 +57,6 @@ var DefaultInterfaceBlacklist = []string{
|
||||
"Tailscale", "tailscale", "docker", "veth", "br-", "lo",
|
||||
}
|
||||
|
||||
// loadMDMPolicy is the package-level indirection used by apply() to read the
|
||||
// active MDM policy. Tests override this to inject a fake policy.
|
||||
var loadMDMPolicy = mdm.LoadPolicy
|
||||
|
||||
// ConfigInput carries configuration changes to the client
|
||||
type ConfigInput struct {
|
||||
ManagementURL string
|
||||
@@ -108,6 +103,10 @@ type ConfigInput struct {
|
||||
|
||||
// Config Configuration type
|
||||
type Config struct {
|
||||
// Name is the human-readable profile name shown in CLI/UI listings.
|
||||
// It is independent of the profile's on-disk filename (which is the ID).
|
||||
Name string
|
||||
|
||||
// Wireguard private key of local peer
|
||||
PrivateKey string
|
||||
PreSharedKey string
|
||||
@@ -179,23 +178,6 @@ type Config struct {
|
||||
LazyConnectionEnabled bool
|
||||
|
||||
MTU uint16
|
||||
|
||||
// policy is the MDM policy that produced the currently-set values for
|
||||
// any MDM-enforced fields. Set by applyMDMPolicy at the tail of apply()
|
||||
// and reset on every apply() invocation. Never persisted to disk.
|
||||
// Callers query enforcement state via Policy() and the mdm.Policy API
|
||||
// (HasKey, ManagedKeys, IsEmpty).
|
||||
policy *mdm.Policy `json:"-"`
|
||||
}
|
||||
|
||||
// Policy returns the MDM policy applied to this Config. Returns a non-nil
|
||||
// empty Policy when MDM enforcement is inactive; callers can always invoke
|
||||
// HasKey / ManagedKeys / IsEmpty without a nil check.
|
||||
func (config *Config) Policy() *mdm.Policy {
|
||||
if config == nil || config.policy == nil {
|
||||
return mdm.NewPolicy(nil)
|
||||
}
|
||||
return config.policy
|
||||
}
|
||||
|
||||
var ConfigDirOverride string
|
||||
@@ -270,6 +252,16 @@ func createNewConfig(input ConfigInput) (*Config, error) {
|
||||
}
|
||||
|
||||
func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
if config.Name != "" {
|
||||
sanitized, err := sanitizeDisplayName(config.Name)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("invalid profile name: %w", err)
|
||||
}
|
||||
if sanitized != config.Name {
|
||||
config.Name = sanitized
|
||||
updated = true
|
||||
}
|
||||
}
|
||||
if config.ManagementURL == nil {
|
||||
log.Infof("using default Management URL %s", DefaultManagementURL)
|
||||
config.ManagementURL, err = parseURL("Management URL", DefaultManagementURL)
|
||||
@@ -634,93 +626,10 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
// MDM is the last override layer: any key present in the policy
|
||||
// supersedes defaults, on-disk config, env vars and CLI input.
|
||||
config.applyMDMPolicy(loadMDMPolicy())
|
||||
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
// applyMDMPolicy overlays MDM-supplied values on top of the resolved Config.
|
||||
// The provided Policy is also stored on the Config so callers can later query
|
||||
// which fields are enforced. Invalid values (e.g. malformed URLs) are logged
|
||||
// and skipped to avoid bricking the client; the field keeps its previous
|
||||
// resolved value but is still marked as managed (Policy.HasKey returns true
|
||||
// for the key, so per-field rejection of user writes still applies).
|
||||
func (config *Config) applyMDMPolicy(policy *mdm.Policy) {
|
||||
config.policy = policy
|
||||
if policy.IsEmpty() {
|
||||
return
|
||||
}
|
||||
|
||||
// Helper: log the application of a single MDM-managed key. Values for
|
||||
// keys in mdm.SecretKeys are redacted.
|
||||
logApplied := func(key string, displayValue any) {
|
||||
if _, secret := mdm.SecretKeys[key]; secret {
|
||||
log.Infof("MDM override %s = ********** (secret)", key)
|
||||
return
|
||||
}
|
||||
log.Infof("MDM override %s = %v", key, displayValue)
|
||||
}
|
||||
|
||||
if v, ok := policy.GetString(mdm.KeyManagementURL); ok {
|
||||
if u, err := parseURL("Management URL", v); err != nil {
|
||||
log.Warnf("MDM management URL %q invalid: %v; keeping previous value", v, err)
|
||||
} else {
|
||||
config.ManagementURL = u
|
||||
logApplied(mdm.KeyManagementURL, u.String())
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := policy.GetString(mdm.KeyPreSharedKey); ok {
|
||||
// Defensive: refuse the redaction mask in case it round-tripped
|
||||
// through a manifest by mistake.
|
||||
if !isPreSharedKeyHidden(&v) {
|
||||
config.PreSharedKey = v
|
||||
logApplied(mdm.KeyPreSharedKey, "")
|
||||
}
|
||||
}
|
||||
|
||||
// applyBool collapses the per-key "read + set + log" boilerplate
|
||||
// for every plain bool MDM key into a single helper. Keeps the
|
||||
// outer function's cognitive complexity below SonarCube's
|
||||
// threshold; functional behaviour is identical to the inlined
|
||||
// branches it replaces.
|
||||
applyBool := func(key string, setter func(bool)) {
|
||||
v, ok := policy.GetBool(key)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
setter(v)
|
||||
logApplied(key, v)
|
||||
}
|
||||
|
||||
applyBool(mdm.KeyAllowServerSSH, func(v bool) { bv := v; config.ServerSSHAllowed = &bv })
|
||||
applyBool(mdm.KeyDisableClientRoutes, func(v bool) { config.DisableClientRoutes = v })
|
||||
applyBool(mdm.KeyDisableServerRoutes, func(v bool) { config.DisableServerRoutes = v })
|
||||
applyBool(mdm.KeyBlockInbound, func(v bool) { config.BlockInbound = v })
|
||||
applyBool(mdm.KeyDisableAutoConnect, func(v bool) { config.DisableAutoConnect = v })
|
||||
applyBool(mdm.KeyRosenpassEnabled, func(v bool) { config.RosenpassEnabled = v })
|
||||
applyBool(mdm.KeyRosenpassPermissive, func(v bool) { config.RosenpassPermissive = v })
|
||||
|
||||
if v, ok := policy.GetInt(mdm.KeyWireguardPort); ok {
|
||||
// REG_DWORD is 32-bit; UDP port range is 1-65535. Clamp at the
|
||||
// upper bound and reject obviously-invalid values to avoid the
|
||||
// engine binding to an unusable port if the admin pushes garbage.
|
||||
if v >= 1 && v <= 65535 {
|
||||
config.WgPort = int(v)
|
||||
logApplied(mdm.KeyWireguardPort, v)
|
||||
} else {
|
||||
log.Warnf("MDM wireguard port %d out of range [1,65535]; keeping previous value", v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// parseURL parses and validates the URL for the named service. The URL
|
||||
// must use the http or https scheme; if no port is present, ":443" is
|
||||
// appended for https or ":80" for http. The serviceName parameter is
|
||||
// used to contextualise error messages. On success returns the parsed
|
||||
// *url.URL; on failure returns a non-nil error.
|
||||
// parseURL parses and validates a service URL
|
||||
func parseURL(serviceName, serviceURL string) (*url.URL, error) {
|
||||
parsedMgmtURL, err := url.ParseRequestURI(serviceURL)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,152 +0,0 @@
|
||||
package profilemanager
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/mdm"
|
||||
)
|
||||
|
||||
// withMDMPolicy temporarily overrides the package-level loadMDMPolicy hook so
|
||||
// apply() observes the supplied Policy. The original loader is restored at
|
||||
// test cleanup.
|
||||
func withMDMPolicy(t *testing.T, policy *mdm.Policy) {
|
||||
t.Helper()
|
||||
prev := loadMDMPolicy
|
||||
loadMDMPolicy = func() *mdm.Policy { return policy }
|
||||
t.Cleanup(func() { loadMDMPolicy = prev })
|
||||
}
|
||||
|
||||
func TestApply_MDMEmpty_NoEnforcement(t *testing.T) {
|
||||
withMDMPolicy(t, mdm.NewPolicy(nil))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
assert.True(t, cfg.Policy().IsEmpty(), "no MDM source ⇒ empty Policy")
|
||||
assert.False(t, cfg.Policy().HasKey(mdm.KeyManagementURL))
|
||||
assert.Empty(t, cfg.Policy().ManagedKeys())
|
||||
|
||||
// Default management URL still resolves.
|
||||
assert.Equal(t, DefaultManagementURL, cfg.ManagementURL.String())
|
||||
}
|
||||
|
||||
func TestApply_MDMOnly_OverridesDefaults(t *testing.T) {
|
||||
const mdmURL = "https://corp.mdm.example.com:443"
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: mdmURL,
|
||||
mdm.KeyDisableClientRoutes: true,
|
||||
mdm.KeyBlockInbound: true,
|
||||
}))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
assert.Equal(t, mdmURL, cfg.ManagementURL.String())
|
||||
assert.True(t, cfg.DisableClientRoutes)
|
||||
assert.True(t, cfg.BlockInbound)
|
||||
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyManagementURL))
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyDisableClientRoutes))
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyBlockInbound))
|
||||
assert.False(t, cfg.Policy().HasKey(mdm.KeyAllowServerSSH))
|
||||
}
|
||||
|
||||
func TestApply_MDMBeatsCLIInput(t *testing.T) {
|
||||
const mdmURL = "https://mdm.example.com:443"
|
||||
const cliURL = "https://cli.example.com:443"
|
||||
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: mdmURL,
|
||||
}))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
ManagementURL: cliURL,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
// MDM wins over CLI-supplied management URL.
|
||||
assert.Equal(t, mdmURL, cfg.ManagementURL.String())
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyManagementURL))
|
||||
}
|
||||
|
||||
func TestApply_MDMInvalidURL_KeepsPreviousValue(t *testing.T) {
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: "not-a-url",
|
||||
}))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
// Invalid MDM URL is logged and skipped: default URL stays in place
|
||||
// to keep the client functional.
|
||||
assert.Equal(t, DefaultManagementURL, cfg.ManagementURL.String())
|
||||
|
||||
// But the key is still considered MDM-managed (admin intent is to
|
||||
// enforce, daemon rejects user writes to this field — phase-1 scaffolding
|
||||
// reflects this by keeping Policy.HasKey true even on parse failure).
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyManagementURL))
|
||||
}
|
||||
|
||||
func TestApply_MDMBoolKeysOverrideOnDiskValue(t *testing.T) {
|
||||
tmp := filepath.Join(t.TempDir(), "config.json")
|
||||
|
||||
// Seed without MDM.
|
||||
withMDMPolicy(t, mdm.NewPolicy(nil))
|
||||
_, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: tmp,
|
||||
DisableClientRoutes: boolPtr(false),
|
||||
RosenpassEnabled: boolPtr(false),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now enable MDM enforcement for these keys.
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyDisableClientRoutes: true,
|
||||
mdm.KeyRosenpassEnabled: true,
|
||||
}))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{ConfigPath: tmp})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
assert.True(t, cfg.DisableClientRoutes, "MDM override should flip on-disk false to true")
|
||||
assert.True(t, cfg.RosenpassEnabled)
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyDisableClientRoutes))
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyRosenpassEnabled))
|
||||
}
|
||||
|
||||
func TestApply_MDMPreSharedKeyRedactionSentinelRejected(t *testing.T) {
|
||||
const maskSentinel = "**********"
|
||||
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyPreSharedKey: maskSentinel,
|
||||
}))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
// Mask sentinel must not be persisted as the actual PSK.
|
||||
assert.NotEqual(t, maskSentinel, cfg.PreSharedKey)
|
||||
// Key still marked managed so user writes are still rejected.
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyPreSharedKey))
|
||||
}
|
||||
|
||||
func boolPtr(b bool) *bool { return &b }
|
||||
118
client/internal/profilemanager/id.go
Normal file
118
client/internal/profilemanager/id.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package profilemanager
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
const (
|
||||
// profileIDByteLen is the number of random bytes generated for a new
|
||||
// profile ID. The resulting hex string is twice this length.
|
||||
profileIDByteLen = 16
|
||||
|
||||
// shortIDLen is the number of leading characters of an ID we render in
|
||||
// list output. Profiles per device are few, so 8 chars is collision-safe
|
||||
// in practice and easy to type as a prefix.
|
||||
shortIDLen = 8
|
||||
|
||||
// maxProfileNameLen caps the human-readable profile name to keep table
|
||||
// output legible and prevent denial-of-service via huge JSON fields.
|
||||
maxProfileNameLen = 128
|
||||
|
||||
// maxProfileIDLen bounds the on-disk filename we'll accept. New
|
||||
// IDs are 32 hex chars, legacy stems are sanitized profile names. The
|
||||
// cap is generous enough to cover both without permitting absurdly
|
||||
// long filenames.
|
||||
maxProfileIDLen = 64
|
||||
)
|
||||
|
||||
type ID string
|
||||
|
||||
// generateProfileID returns a new random hex ID for a profile file.
|
||||
func generateProfileID() (ID, error) {
|
||||
buf := make([]byte, profileIDByteLen)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return "", fmt.Errorf("read random bytes: %w", err)
|
||||
}
|
||||
return ID(hex.EncodeToString(buf)), nil
|
||||
}
|
||||
|
||||
// IsValidProfileFilenameStem reports whether id is safe to use as the stem
|
||||
// of a profile JSON filename.
|
||||
func IsValidProfileFilenameStem(id ID) bool {
|
||||
s := id.String()
|
||||
if s == "" || len(s) > maxProfileIDLen {
|
||||
return false
|
||||
}
|
||||
if s == defaultProfileName {
|
||||
return true
|
||||
}
|
||||
if strings.ContainsAny(s, `/\`) || strings.Contains(s, "..") {
|
||||
return false
|
||||
}
|
||||
// filepath.Base catches any leftover separators on platforms with
|
||||
// exotic path conventions.
|
||||
if filepath.Base(s) != s {
|
||||
return false
|
||||
}
|
||||
for _, r := range s {
|
||||
if !(unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' || r == '-') {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// sanitizeDisplayName normalizes a user-supplied profile display name for
|
||||
// storage. It strips ASCII control characters, rejects invalid UTF-8, and
|
||||
// caps the length. Emojis, spaces, punctuation, and non-ASCII letters are
|
||||
// preserved. Returns an error if nothing usable remains.
|
||||
func sanitizeDisplayName(name string) (string, error) {
|
||||
if !utf8.ValidString(name) {
|
||||
return "", fmt.Errorf("name is not valid UTF-8")
|
||||
}
|
||||
name = StripCtrlChars(name)
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" {
|
||||
return "", fmt.Errorf("name is empty after sanitization")
|
||||
}
|
||||
if utf8.RuneCountInString(name) > maxProfileNameLen {
|
||||
return "", fmt.Errorf("name exceeds %d characters", maxProfileNameLen)
|
||||
}
|
||||
return name, nil
|
||||
}
|
||||
|
||||
// StripCtrlChars control characters from a name before printing it.
|
||||
func StripCtrlChars(name string) string {
|
||||
var b strings.Builder
|
||||
b.Grow(len(name))
|
||||
for _, r := range name {
|
||||
// Skip C0 controls and DEL, plus C1 controls (0x80–0x9F).
|
||||
if r < 0x20 || r == 0x7F || (r >= 0x80 && r <= 0x9F) {
|
||||
continue
|
||||
}
|
||||
b.WriteRune(r)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// ShortID truncates an ID for display.
|
||||
func (id ID) ShortID() string {
|
||||
if id == DefaultProfileName {
|
||||
return DefaultProfileName
|
||||
}
|
||||
runes := []rune(id)
|
||||
if len(runes) <= shortIDLen {
|
||||
return id.String()
|
||||
}
|
||||
return string(runes[:shortIDLen])
|
||||
}
|
||||
|
||||
func (id ID) String() string {
|
||||
return string(id)
|
||||
}
|
||||
@@ -19,19 +19,41 @@ const (
|
||||
)
|
||||
|
||||
type Profile struct {
|
||||
Name string
|
||||
// ID is the on-disk filename stem (without .json). For new profiles
|
||||
// it is a 32-char hex string; legacy profiles created before the
|
||||
// ID-keyed layout keep their original name as their ID. The reserved
|
||||
// value "default" identifies the special default profile.
|
||||
ID ID
|
||||
// Name is the human-readable display name. Falls back to ID when the
|
||||
// underlying JSON has no "name" field set.
|
||||
Name string
|
||||
// Path is the absolute path to the profile JSON. Populated by the
|
||||
// loader so callers do not have to reconstruct it from ID + dir.
|
||||
Path string
|
||||
IsActive bool
|
||||
}
|
||||
|
||||
func (p *Profile) FilePath() (string, error) {
|
||||
if p.Name == "" {
|
||||
return "", fmt.Errorf("active profile name is empty")
|
||||
if p.Path != "" {
|
||||
return p.Path, nil
|
||||
}
|
||||
|
||||
if p.Name == defaultProfileName {
|
||||
id := p.ID
|
||||
if id == "" {
|
||||
id = ID(p.Name)
|
||||
}
|
||||
if id == "" {
|
||||
return "", fmt.Errorf("profile ID is empty")
|
||||
}
|
||||
|
||||
if id == defaultProfileName {
|
||||
return DefaultConfigPath, nil
|
||||
}
|
||||
|
||||
if !IsValidProfileFilenameStem(id) {
|
||||
return "", fmt.Errorf("invalid profile ID: %q", id)
|
||||
}
|
||||
|
||||
username, err := user.Current()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get current user: %w", err)
|
||||
@@ -42,10 +64,13 @@ func (p *Profile) FilePath() (string, error) {
|
||||
return "", fmt.Errorf("failed to get config directory for user %s: %w", username.Username, err)
|
||||
}
|
||||
|
||||
return filepath.Join(configDir, p.Name+".json"), nil
|
||||
return filepath.Join(configDir, id.String()+".json"), nil
|
||||
}
|
||||
|
||||
func (p *Profile) IsDefault() bool {
|
||||
if p.ID != "" {
|
||||
return p.ID == defaultProfileName
|
||||
}
|
||||
return p.Name == defaultProfileName
|
||||
}
|
||||
|
||||
@@ -57,18 +82,24 @@ func NewProfileManager() *ProfileManager {
|
||||
return &ProfileManager{}
|
||||
}
|
||||
|
||||
// GetActiveProfile returns the active profile as recorded in the local
|
||||
// user state file. Only ID is populated.
|
||||
func (pm *ProfileManager) GetActiveProfile() (*Profile, error) {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
prof := pm.getActiveProfileState()
|
||||
return &Profile{Name: prof}, nil
|
||||
id := pm.getActiveProfileState()
|
||||
return &Profile{ID: id}, nil
|
||||
}
|
||||
|
||||
func (pm *ProfileManager) SwitchProfile(profileName string) error {
|
||||
profileName = sanitizeProfileName(profileName)
|
||||
// SwitchProfile records the given profile ID as active in the local user
|
||||
// state file.
|
||||
func (pm *ProfileManager) SwitchProfile(id ID) error {
|
||||
if id != defaultProfileName && !IsValidProfileFilenameStem(id) {
|
||||
return fmt.Errorf("invalid profile ID: %q", id)
|
||||
}
|
||||
|
||||
if err := pm.setActiveProfileState(profileName); err != nil {
|
||||
if err := pm.setActiveProfileState(id); err != nil {
|
||||
return fmt.Errorf("failed to switch profile: %w", err)
|
||||
}
|
||||
return nil
|
||||
@@ -85,7 +116,7 @@ func sanitizeProfileName(name string) string {
|
||||
}, name)
|
||||
}
|
||||
|
||||
func (pm *ProfileManager) getActiveProfileState() string {
|
||||
func (pm *ProfileManager) getActiveProfileState() ID {
|
||||
|
||||
configDir, err := getConfigDir()
|
||||
if err != nil {
|
||||
@@ -113,10 +144,10 @@ func (pm *ProfileManager) getActiveProfileState() string {
|
||||
return defaultProfileName
|
||||
}
|
||||
|
||||
return profileName
|
||||
return ID(profileName)
|
||||
}
|
||||
|
||||
func (pm *ProfileManager) setActiveProfileState(profileName string) error {
|
||||
func (pm *ProfileManager) setActiveProfileState(id ID) error {
|
||||
|
||||
configDir, err := getConfigDir()
|
||||
if err != nil {
|
||||
@@ -125,7 +156,7 @@ func (pm *ProfileManager) setActiveProfileState(profileName string) error {
|
||||
|
||||
statePath := filepath.Join(configDir, activeProfileStateFilename)
|
||||
|
||||
err = os.WriteFile(statePath, []byte(profileName), 0600)
|
||||
err = os.WriteFile(statePath, []byte(id), 0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write active profile state: %w", err)
|
||||
}
|
||||
@@ -142,7 +173,7 @@ func GetLoginHint() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
profileState, err := pm.GetProfileState(activeProf.Name)
|
||||
profileState, err := pm.GetProfileState(activeProf.ID)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get profile state for login hint: %v", err)
|
||||
return ""
|
||||
|
||||
@@ -50,14 +50,14 @@ func TestServiceManager_CreateAndGetDefaultProfile(t *testing.T) {
|
||||
|
||||
state, err := sm.GetActiveProfileState()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, state.Name, defaultProfileName) // No active profile state yet
|
||||
assert.Equal(t, defaultProfileName, state.ID.String()) // No active profile state yet
|
||||
|
||||
err = sm.SetActiveProfileStateToDefault()
|
||||
assert.NoError(t, err)
|
||||
|
||||
active, err := sm.GetActiveProfileState()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "default", active.Name)
|
||||
assert.Equal(t, "default", active.ID.String())
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -92,14 +92,14 @@ func TestServiceManager_SetActiveProfileState(t *testing.T) {
|
||||
currUser, err := user.Current()
|
||||
assert.NoError(t, err)
|
||||
sm := &ServiceManager{}
|
||||
state := &ActiveProfileState{Name: "foo", Username: currUser.Username}
|
||||
state := &ActiveProfileState{ID: "foo", Username: currUser.Username}
|
||||
err = sm.SetActiveProfileState(state)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Should error on nil or incomplete state
|
||||
err = sm.SetActiveProfileState(nil)
|
||||
assert.Error(t, err)
|
||||
err = sm.SetActiveProfileState(&ActiveProfileState{Name: "", Username: ""})
|
||||
err = sm.SetActiveProfileState(&ActiveProfileState{ID: "", Username: ""})
|
||||
assert.Error(t, err)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -2,6 +2,7 @@ package profilemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -23,12 +24,43 @@ var (
|
||||
DefaultConfigPathDir = ""
|
||||
DefaultConfigPath = ""
|
||||
ActiveProfileStatePath = ""
|
||||
)
|
||||
|
||||
var (
|
||||
ErrorOldDefaultConfigNotFound = errors.New("old default config not found")
|
||||
)
|
||||
|
||||
// ErrAmbiguousHandle is returned when a profile handle (ID prefix or name)
|
||||
// matches more than one profile. Callers can render Candidates to help the
|
||||
// user disambiguate.
|
||||
type ErrAmbiguousHandle struct {
|
||||
Handle string
|
||||
Candidates []Profile
|
||||
Kind AmbiguityKind
|
||||
}
|
||||
|
||||
// AmbiguityKind describes which matcher produced the ambiguity, so callers
|
||||
// can tailor the error message.
|
||||
type AmbiguityKind int
|
||||
|
||||
const (
|
||||
AmbiguityKindIDPrefix AmbiguityKind = iota
|
||||
AmbiguityKindName
|
||||
)
|
||||
|
||||
// profileMeta is the minimal slice of a profile JSON we need, so we avoid
|
||||
// reading all fields
|
||||
type profileMeta struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
func (e *ErrAmbiguousHandle) Error() string {
|
||||
switch e.Kind {
|
||||
case AmbiguityKindIDPrefix:
|
||||
return fmt.Sprintf("ID prefix %q is ambiguous (matches %d profiles)", e.Handle, len(e.Candidates))
|
||||
default:
|
||||
return fmt.Sprintf("name %q is ambiguous (%d profiles share this name)", e.Handle, len(e.Candidates))
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
||||
DefaultConfigPathDir = "/var/lib/netbird/"
|
||||
@@ -54,25 +86,34 @@ func init() {
|
||||
}
|
||||
|
||||
type ActiveProfileState struct {
|
||||
Name string `json:"name"`
|
||||
// ID is the on-disk filename stem of the active profile. The JSON tag stays
|
||||
// as "name" for backwards compatibility with active state files written
|
||||
// before the ID-based config files. Legacy values were profile names, which
|
||||
// were also the legacy filename stems, so they still resolve to the correct
|
||||
// file on disk.
|
||||
ID ID `json:"name"`
|
||||
Username string `json:"username"`
|
||||
}
|
||||
|
||||
func (a *ActiveProfileState) FilePath() (string, error) {
|
||||
if a.Name == "" {
|
||||
return "", fmt.Errorf("active profile name is empty")
|
||||
if a.ID == "" {
|
||||
return "", fmt.Errorf("active profile ID is empty")
|
||||
}
|
||||
|
||||
if a.Name == defaultProfileName {
|
||||
if a.ID == defaultProfileName {
|
||||
return DefaultConfigPath, nil
|
||||
}
|
||||
|
||||
if !IsValidProfileFilenameStem(a.ID) {
|
||||
return "", fmt.Errorf("invalid profile ID: %q", a.ID)
|
||||
}
|
||||
|
||||
configDir, err := getConfigDirForUser(a.Username)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get config directory for user %s: %w", a.Username, err)
|
||||
}
|
||||
|
||||
return filepath.Join(configDir, a.Name+".json"), nil
|
||||
return filepath.Join(configDir, a.ID.String()+".json"), nil
|
||||
}
|
||||
|
||||
type ServiceManager struct {
|
||||
@@ -178,7 +219,7 @@ func (s *ServiceManager) GetActiveProfileState() (*ActiveProfileState, error) {
|
||||
return nil, fmt.Errorf("failed to set active profile to default: %w", err)
|
||||
}
|
||||
return &ActiveProfileState{
|
||||
Name: "default",
|
||||
ID: defaultProfileName,
|
||||
Username: "",
|
||||
}, nil
|
||||
} else {
|
||||
@@ -186,12 +227,12 @@ func (s *ServiceManager) GetActiveProfileState() (*ActiveProfileState, error) {
|
||||
}
|
||||
}
|
||||
|
||||
if activeProfile.Name == "" {
|
||||
if activeProfile.ID == "" {
|
||||
if err := s.SetActiveProfileStateToDefault(); err != nil {
|
||||
return nil, fmt.Errorf("failed to set active profile to default: %w", err)
|
||||
}
|
||||
return &ActiveProfileState{
|
||||
Name: "default",
|
||||
ID: defaultProfileName,
|
||||
Username: "",
|
||||
}, nil
|
||||
}
|
||||
@@ -216,25 +257,29 @@ func (s *ServiceManager) setDefaultActiveState() error {
|
||||
}
|
||||
|
||||
func (s *ServiceManager) SetActiveProfileState(a *ActiveProfileState) error {
|
||||
if a == nil || a.Name == "" {
|
||||
if a == nil || a.ID == "" {
|
||||
return errors.New("invalid active profile state")
|
||||
}
|
||||
|
||||
if a.Name != defaultProfileName && a.Username == "" {
|
||||
return fmt.Errorf("username must be set for non-default profiles, got: %s", a.Name)
|
||||
if a.ID != defaultProfileName && a.Username == "" {
|
||||
return fmt.Errorf("username must be set for non-default profiles, got: %s", a.ID)
|
||||
}
|
||||
|
||||
if a.ID != defaultProfileName && !IsValidProfileFilenameStem(a.ID) {
|
||||
return fmt.Errorf("invalid profile ID: %q", a.ID)
|
||||
}
|
||||
|
||||
if err := util.WriteJsonWithRestrictedPermission(context.Background(), ActiveProfileStatePath, a); err != nil {
|
||||
return fmt.Errorf("failed to write active profile state: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("active profile set to %s for %s", a.Name, a.Username)
|
||||
log.Infof("active profile set to %s for %s", a.ID, a.Username)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ServiceManager) SetActiveProfileStateToDefault() error {
|
||||
return s.SetActiveProfileState(&ActiveProfileState{
|
||||
Name: "default",
|
||||
ID: defaultProfileName,
|
||||
Username: "",
|
||||
})
|
||||
}
|
||||
@@ -243,57 +288,117 @@ func (s *ServiceManager) DefaultProfilePath() string {
|
||||
return DefaultConfigPath
|
||||
}
|
||||
|
||||
func (s *ServiceManager) AddProfile(profileName, username string) error {
|
||||
// AddProfile creates a new profile with a generated ID. The user-supplied
|
||||
// displayName is stored inside the JSON's name field, the on-disk filename
|
||||
// uses the generated ID.
|
||||
//
|
||||
// The returned Profile carries the freshly-generated ID so callers can
|
||||
// show it to the user (and so the gRPC AddProfileResponse can include
|
||||
// it).
|
||||
func (s *ServiceManager) AddProfile(displayName, username string) (*Profile, error) {
|
||||
configDir, err := s.getConfigDir(username)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get config directory: %w", err)
|
||||
return nil, fmt.Errorf("failed to get config directory: %w", err)
|
||||
}
|
||||
|
||||
profileName = sanitizeProfileName(profileName)
|
||||
|
||||
if profileName == defaultProfileName {
|
||||
return fmt.Errorf("cannot create profile with reserved name: %s", defaultProfileName)
|
||||
}
|
||||
|
||||
profPath := filepath.Join(configDir, profileName+".json")
|
||||
profileExists, err := fileExists(profPath)
|
||||
displayName, err = sanitizeDisplayName(displayName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if profile exists: %w", err)
|
||||
}
|
||||
if profileExists {
|
||||
return ErrProfileAlreadyExists
|
||||
return nil, fmt.Errorf("invalid profile name: %w", err)
|
||||
}
|
||||
|
||||
id, err := generateProfileID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate profile id: %w", err)
|
||||
}
|
||||
|
||||
profPath := filepath.Join(configDir, id.String()+".json")
|
||||
cfg, err := createNewConfig(ConfigInput{ConfigPath: profPath})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create new config: %w", err)
|
||||
return nil, fmt.Errorf("failed to create new config: %w", err)
|
||||
}
|
||||
cfg.Name = displayName
|
||||
|
||||
if err := util.WriteJson(context.Background(), profPath, cfg); err != nil {
|
||||
return nil, fmt.Errorf("failed to write profile config: %w", err)
|
||||
}
|
||||
|
||||
err = util.WriteJson(context.Background(), profPath, cfg)
|
||||
return &Profile{
|
||||
ID: id,
|
||||
Name: displayName,
|
||||
Path: profPath,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *ServiceManager) RenameProfile(id ID, username string, newName string) error {
|
||||
displayName, err := sanitizeDisplayName(newName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write profile config: %w", err)
|
||||
return fmt.Errorf("invalid profile name: %w", err)
|
||||
}
|
||||
|
||||
if !IsValidProfileFilenameStem(id) {
|
||||
return fmt.Errorf("invalid profile ID: %q", id)
|
||||
}
|
||||
|
||||
profiles, err := s.loadAllProfiles(username)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load profiles: %w", err)
|
||||
}
|
||||
|
||||
var target *Profile
|
||||
for i := range profiles {
|
||||
if profiles[i].ID == id {
|
||||
target = &profiles[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
if target == nil {
|
||||
return ErrProfileNotFound
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(target.Path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var cfg Config
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
cfg.Name = displayName
|
||||
|
||||
if err := util.WriteJson(context.Background(), target.Path, cfg); err != nil {
|
||||
return fmt.Errorf("failed to write profile name: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ServiceManager) RemoveProfile(profileName, username string) error {
|
||||
configDir, err := s.getConfigDir(username)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get config directory: %w", err)
|
||||
// RemoveProfile deletes the profile identified by id. Callers must have
|
||||
// already resolved any user-supplied handle to a concrete ID via
|
||||
// ResolveProfile.
|
||||
func (s *ServiceManager) RemoveProfile(id ID, username string) error {
|
||||
if id == defaultProfileName {
|
||||
defaultName := readProfileName(DefaultConfigPath)
|
||||
if defaultName == "" {
|
||||
defaultName = defaultProfileName
|
||||
}
|
||||
return fmt.Errorf("cannot remove default profile with name: %s", defaultName)
|
||||
}
|
||||
if !IsValidProfileFilenameStem(id) {
|
||||
return fmt.Errorf("invalid profile ID: %q", id)
|
||||
}
|
||||
|
||||
profileName = sanitizeProfileName(profileName)
|
||||
|
||||
if profileName == defaultProfileName {
|
||||
return fmt.Errorf("cannot remove profile with reserved name: %s", defaultProfileName)
|
||||
}
|
||||
profPath := filepath.Join(configDir, profileName+".json")
|
||||
profileExists, err := fileExists(profPath)
|
||||
profiles, err := s.loadAllProfiles(username)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if profile exists: %w", err)
|
||||
return fmt.Errorf("load profiles: %w", err)
|
||||
}
|
||||
if !profileExists {
|
||||
|
||||
var target *Profile
|
||||
for i := range profiles {
|
||||
if profiles[i].ID == id {
|
||||
target = &profiles[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
if target == nil {
|
||||
return ErrProfileNotFound
|
||||
}
|
||||
|
||||
@@ -301,57 +406,26 @@ func (s *ServiceManager) RemoveProfile(profileName, username string) error {
|
||||
if err != nil && !errors.Is(err, ErrNoActiveProfile) {
|
||||
return fmt.Errorf("failed to get active profile: %w", err)
|
||||
}
|
||||
|
||||
if activeProf != nil && activeProf.Name == profileName {
|
||||
return fmt.Errorf("cannot remove active profile: %s", profileName)
|
||||
if activeProf != nil && activeProf.ID == id {
|
||||
return fmt.Errorf("cannot remove active profile: %s", id)
|
||||
}
|
||||
|
||||
err = util.RemoveJson(profPath)
|
||||
if err != nil {
|
||||
if err := util.RemoveJson(target.Path); err != nil {
|
||||
return fmt.Errorf("failed to remove profile config: %w", err)
|
||||
}
|
||||
|
||||
stateFile := filepath.Join(filepath.Dir(target.Path), id.String()+".state.json")
|
||||
if err := os.Remove(stateFile); err != nil && !os.IsNotExist(err) {
|
||||
log.Warnf("failed to remove profile state file %s: %v", stateFile, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListProfiles returns every profile for the given user, including the
|
||||
// default profile, with IsActive flags set.
|
||||
func (s *ServiceManager) ListProfiles(username string) ([]Profile, error) {
|
||||
configDir, err := s.getConfigDir(username)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get config directory: %w", err)
|
||||
}
|
||||
|
||||
files, err := util.ListFiles(configDir, "*.json")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list profile files: %w", err)
|
||||
}
|
||||
|
||||
var filtered []string
|
||||
for _, file := range files {
|
||||
if strings.HasSuffix(file, "state.json") {
|
||||
continue // skip state files
|
||||
}
|
||||
filtered = append(filtered, file)
|
||||
}
|
||||
sort.Strings(filtered)
|
||||
|
||||
var activeProfName string
|
||||
activeProf, err := s.GetActiveProfileState()
|
||||
if err == nil {
|
||||
activeProfName = activeProf.Name
|
||||
}
|
||||
|
||||
var profiles []Profile
|
||||
// add default profile always
|
||||
profiles = append(profiles, Profile{Name: defaultProfileName, IsActive: activeProfName == "" || activeProfName == defaultProfileName})
|
||||
for _, file := range filtered {
|
||||
profileName := strings.TrimSuffix(filepath.Base(file), ".json")
|
||||
var isActive bool
|
||||
if activeProfName != "" && activeProfName == profileName {
|
||||
isActive = true
|
||||
}
|
||||
profiles = append(profiles, Profile{Name: profileName, IsActive: isActive})
|
||||
}
|
||||
|
||||
return profiles, nil
|
||||
return s.loadAllProfiles(username)
|
||||
}
|
||||
|
||||
// GetStatePath returns the path to the state file based on the operating system
|
||||
@@ -369,7 +443,12 @@ func (s *ServiceManager) GetStatePath() string {
|
||||
return defaultStatePath
|
||||
}
|
||||
|
||||
if activeProf.Name == defaultProfileName {
|
||||
if activeProf.ID == defaultProfileName {
|
||||
return defaultStatePath
|
||||
}
|
||||
|
||||
if !IsValidProfileFilenameStem(activeProf.ID) {
|
||||
log.Warnf("invalid active profile ID %q, using default state path", activeProf.ID)
|
||||
return defaultStatePath
|
||||
}
|
||||
|
||||
@@ -379,7 +458,7 @@ func (s *ServiceManager) GetStatePath() string {
|
||||
return defaultStatePath
|
||||
}
|
||||
|
||||
return filepath.Join(configDir, activeProf.Name+".state.json")
|
||||
return filepath.Join(configDir, activeProf.ID.String()+".state.json")
|
||||
}
|
||||
|
||||
// getConfigDir returns the profiles directory, using profilesDir if set, otherwise getConfigDirForUser
|
||||
@@ -390,3 +469,169 @@ func (s *ServiceManager) getConfigDir(username string) (string, error) {
|
||||
|
||||
return getConfigDirForUser(username)
|
||||
}
|
||||
|
||||
// loadAllProfiles returns every profile visible to the daemon for the
|
||||
// given user, including the default profile. The returned slice is sorted
|
||||
// by ID for a stable display order.
|
||||
//
|
||||
// Each Profile is fully populated: ID is the filename stem, Name comes
|
||||
// from the JSON's "name" field (falling back to the filename stem when absent)
|
||||
// and Path is built from a basename read off disk.
|
||||
func (s *ServiceManager) loadAllProfiles(username string) ([]Profile, error) {
|
||||
activeID, activeIsDefault := s.activeProfileID()
|
||||
defaultName := readProfileName(DefaultConfigPath)
|
||||
if defaultName == "" {
|
||||
defaultName = defaultProfileName
|
||||
}
|
||||
|
||||
profiles := []Profile{{
|
||||
ID: defaultProfileName,
|
||||
Name: defaultName,
|
||||
Path: DefaultConfigPath,
|
||||
IsActive: activeIsDefault,
|
||||
}}
|
||||
|
||||
configDir, err := s.getConfigDir(username)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get config directory: %w", err)
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(configDir)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return profiles, nil
|
||||
}
|
||||
return nil, fmt.Errorf("read profile directory: %w", err)
|
||||
}
|
||||
|
||||
var fileProfiles []Profile
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
base := entry.Name()
|
||||
if !strings.HasSuffix(base, ".json") {
|
||||
continue
|
||||
}
|
||||
if strings.HasSuffix(base, ".state.json") {
|
||||
continue
|
||||
}
|
||||
stem := ID(strings.TrimSuffix(base, ".json"))
|
||||
if stem == defaultProfileName {
|
||||
// default lives at the top-level config dir, not under /<user>
|
||||
continue
|
||||
}
|
||||
if !IsValidProfileFilenameStem(ID(stem)) {
|
||||
continue
|
||||
}
|
||||
path := filepath.Join(configDir, base)
|
||||
name := readProfileName(path)
|
||||
if name == "" {
|
||||
name = stem.String()
|
||||
}
|
||||
fileProfiles = append(fileProfiles, Profile{
|
||||
ID: stem,
|
||||
Name: name,
|
||||
Path: path,
|
||||
IsActive: stem == ID(activeID),
|
||||
})
|
||||
}
|
||||
|
||||
sort.Slice(fileProfiles, func(i, j int) bool {
|
||||
if fileProfiles[i].Name != fileProfiles[j].Name {
|
||||
return fileProfiles[i].Name < fileProfiles[j].Name
|
||||
}
|
||||
// Sort tie-break on ID so duplicate names always render in the same order.
|
||||
return fileProfiles[i].ID < fileProfiles[j].ID
|
||||
})
|
||||
profiles = append(profiles, fileProfiles...)
|
||||
return profiles, nil
|
||||
}
|
||||
|
||||
// readProfileName parses just the "name" field from the profile Json.
|
||||
func readProfileName(path string) string {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
var meta profileMeta
|
||||
if err := json.Unmarshal(data, &meta); err != nil {
|
||||
return ""
|
||||
}
|
||||
return meta.Name
|
||||
}
|
||||
|
||||
// activeProfileID returns the currently-active profile's ID. The second
|
||||
// return value is true when the active profile is the default one.
|
||||
func (s *ServiceManager) activeProfileID() (ID, bool) {
|
||||
state, err := s.GetActiveProfileState()
|
||||
if err != nil || state == nil {
|
||||
return defaultProfileName, true
|
||||
}
|
||||
if state.ID == "" || state.ID == defaultProfileName {
|
||||
return defaultProfileName, true
|
||||
}
|
||||
return state.ID, false
|
||||
}
|
||||
|
||||
// ResolveProfile turns a user-supplied handle into a Profile. Resolution
|
||||
// precedence is: exact ID match, then unique ID prefix, then unique exact
|
||||
// name. Ambiguous matches return *ErrAmbiguousHandle so callers can
|
||||
// surface the candidates.
|
||||
func (s *ServiceManager) ResolveProfile(handle, username string) (*Profile, error) {
|
||||
if handle == "" {
|
||||
return nil, fmt.Errorf("profile handle is empty")
|
||||
}
|
||||
|
||||
profiles, err := s.loadAllProfiles(username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := range profiles {
|
||||
if profiles[i].ID == ID(handle) {
|
||||
return &profiles[i], nil
|
||||
}
|
||||
}
|
||||
|
||||
// ID prefix match. Skip the default profile so `select d` does not
|
||||
// accidentally pick it via prefix.
|
||||
var prefixMatches []Profile
|
||||
for i := range profiles {
|
||||
if profiles[i].ID == defaultProfileName {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(profiles[i].ID.String(), handle) {
|
||||
prefixMatches = append(prefixMatches, profiles[i])
|
||||
}
|
||||
}
|
||||
if len(prefixMatches) == 1 {
|
||||
return &prefixMatches[0], nil
|
||||
}
|
||||
if len(prefixMatches) > 1 {
|
||||
return nil, &ErrAmbiguousHandle{
|
||||
Handle: handle,
|
||||
Candidates: prefixMatches,
|
||||
Kind: AmbiguityKindIDPrefix,
|
||||
}
|
||||
}
|
||||
|
||||
var nameMatches []Profile
|
||||
for i := range profiles {
|
||||
if profiles[i].Name == handle {
|
||||
nameMatches = append(nameMatches, profiles[i])
|
||||
}
|
||||
}
|
||||
if len(nameMatches) == 1 {
|
||||
return &nameMatches[0], nil
|
||||
}
|
||||
if len(nameMatches) > 1 {
|
||||
return nil, &ErrAmbiguousHandle{
|
||||
Handle: handle,
|
||||
Candidates: nameMatches,
|
||||
Kind: AmbiguityKindName,
|
||||
}
|
||||
}
|
||||
|
||||
return nil, ErrProfileNotFound
|
||||
}
|
||||
|
||||
230
client/internal/profilemanager/service_test.go
Normal file
230
client/internal/profilemanager/service_test.go
Normal file
@@ -0,0 +1,230 @@
|
||||
package profilemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
// withTestSM wires up patched globals + a clean config dir and returns a
|
||||
// fully initialized ServiceManager plus the username we are scoped to.
|
||||
func withTestSM(t *testing.T, fn func(sm *ServiceManager, username string)) {
|
||||
t.Helper()
|
||||
withTempConfigDir(t, func(configDir string) {
|
||||
withPatchedGlobals(t, configDir, func() {
|
||||
u, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
sm := &ServiceManager{}
|
||||
require.NoError(t, sm.CreateDefaultProfile())
|
||||
fn(sm, u.Username)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestServiceProfile_ExactID(t *testing.T) {
|
||||
withTestSM(t, func(sm *ServiceManager, username string) {
|
||||
created, err := sm.AddProfile("work", username)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := sm.ResolveProfile(created.ID.String(), username)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, created.ID, got.ID)
|
||||
assert.Equal(t, "work", got.Name)
|
||||
})
|
||||
}
|
||||
|
||||
func TestServiceProfile_IDPrefix(t *testing.T) {
|
||||
withTestSM(t, func(sm *ServiceManager, username string) {
|
||||
created, err := sm.AddProfile("work", username)
|
||||
require.NoError(t, err)
|
||||
|
||||
prefix := created.ID[:4]
|
||||
got, err := sm.ResolveProfile(prefix.String(), username)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, created.ID, got.ID)
|
||||
})
|
||||
}
|
||||
|
||||
func TestServiceProfile_AmbiguousPrefix(t *testing.T) {
|
||||
withTestSM(t, func(sm *ServiceManager, username string) {
|
||||
// Plant two profiles whose IDs share a known prefix by writing
|
||||
// the files directly, since generated IDs are random.
|
||||
configDir, err := sm.getConfigDir(username)
|
||||
require.NoError(t, err)
|
||||
for _, id := range []string{"abcd1111aaaa", "abcd2222bbbb"} {
|
||||
path := filepath.Join(configDir, id+".json")
|
||||
require.NoError(t, util.WriteJson(context.Background(), path, &Config{Name: id}))
|
||||
}
|
||||
|
||||
_, err = sm.ResolveProfile("abcd", username)
|
||||
var amb *ErrAmbiguousHandle
|
||||
require.ErrorAs(t, err, &amb)
|
||||
assert.Equal(t, AmbiguityKindIDPrefix, amb.Kind)
|
||||
assert.Len(t, amb.Candidates, 2)
|
||||
})
|
||||
}
|
||||
|
||||
func TestServiceProfile_ExactNameUnique(t *testing.T) {
|
||||
withTestSM(t, func(sm *ServiceManager, username string) {
|
||||
_, err := sm.AddProfile("work", username)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := sm.ResolveProfile("work", username)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "work", got.Name)
|
||||
})
|
||||
}
|
||||
|
||||
func TestServiceProfile_AmbiguousName(t *testing.T) {
|
||||
withTestSM(t, func(sm *ServiceManager, username string) {
|
||||
_, err := sm.AddProfile("work", username)
|
||||
require.NoError(t, err)
|
||||
_, err = sm.AddProfile("work", username)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = sm.ResolveProfile("work", username)
|
||||
var amb *ErrAmbiguousHandle
|
||||
require.ErrorAs(t, err, &amb)
|
||||
assert.Equal(t, AmbiguityKindName, amb.Kind)
|
||||
assert.Len(t, amb.Candidates, 2)
|
||||
})
|
||||
}
|
||||
|
||||
func TestServiceProfile_NotFound(t *testing.T) {
|
||||
withTestSM(t, func(sm *ServiceManager, username string) {
|
||||
_, err := sm.ResolveProfile("nope", username)
|
||||
assert.ErrorIs(t, err, ErrProfileNotFound)
|
||||
})
|
||||
}
|
||||
|
||||
func TestServiceProfile_DefaultByExactID(t *testing.T) {
|
||||
withTestSM(t, func(sm *ServiceManager, username string) {
|
||||
got, err := sm.ResolveProfile(defaultProfileName, username)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, defaultProfileName, got.ID.String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestServiceProfile_LegacyFilenameCoexists(t *testing.T) {
|
||||
// Legacy profiles stored as <name>.json with no "name" JSON field
|
||||
// should still be discoverable by name and removable by name.
|
||||
withTestSM(t, func(sm *ServiceManager, username string) {
|
||||
configDir, err := sm.getConfigDir(username)
|
||||
require.NoError(t, err)
|
||||
path := filepath.Join(configDir, "legacy.json")
|
||||
require.NoError(t, util.WriteJson(context.Background(), path, &Config{}))
|
||||
|
||||
got, err := sm.ResolveProfile("legacy", username)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "legacy", got.ID.String())
|
||||
// Name falls back to the filename stem when JSON omits it.
|
||||
assert.Equal(t, "legacy", got.Name)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAddProfile_AllowsDuplicateWithFlag(t *testing.T) {
|
||||
withTestSM(t, func(sm *ServiceManager, username string) {
|
||||
first, err := sm.AddProfile("work", username)
|
||||
require.NoError(t, err)
|
||||
|
||||
second, err := sm.AddProfile("work", username)
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, first.ID, second.ID)
|
||||
assert.Equal(t, "work", second.Name)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAddProfile_RejectsInvalidNames(t *testing.T) {
|
||||
withTestSM(t, func(sm *ServiceManager, username string) {
|
||||
cases := []string{
|
||||
"", // empty
|
||||
"\x00\x01", // only control chars (becomes empty)
|
||||
strings.Repeat("a", maxProfileNameLen+1), // too long
|
||||
}
|
||||
for _, name := range cases {
|
||||
_, err := sm.AddProfile(name, username)
|
||||
assert.Error(t, err, "expected error for %q", name)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRemoveProfile_RejectsInvalidID(t *testing.T) {
|
||||
withTestSM(t, func(sm *ServiceManager, username string) {
|
||||
err := sm.RemoveProfile("../escape", username)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSanitizeDisplayName(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{"work", "work", false},
|
||||
{"My Work Account", "My Work Account", false},
|
||||
{"emoji 🚀 ok", "emoji 🚀 ok", false},
|
||||
{"漢字テスト", "漢字テスト", false},
|
||||
{"with\x00null", "withnull", false},
|
||||
{"\x01\x02\x03", "", true},
|
||||
{"", "", true},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
got, err := sanitizeDisplayName(tc.in)
|
||||
if tc.wantErr {
|
||||
assert.Error(t, err, "case %q", tc.in)
|
||||
continue
|
||||
}
|
||||
assert.NoError(t, err, "case %q", tc.in)
|
||||
assert.Equal(t, tc.want, got, "case %q", tc.in)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsValidProfileFilenameStem(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
want bool
|
||||
}{
|
||||
{"default", true},
|
||||
{"abc123def456", true},
|
||||
{"legacy-name", true},
|
||||
{"legacy_name", true},
|
||||
{"", false},
|
||||
{"..", false},
|
||||
{"../etc", false},
|
||||
{"foo/bar", false},
|
||||
{`foo\bar`, false},
|
||||
{"with space", false},
|
||||
{"with.dot", false},
|
||||
{strings.Repeat("a", maxProfileIDLen+1), false},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
got := IsValidProfileFilenameStem(ID(tc.in))
|
||||
assert.Equal(t, tc.want, got, "case %q", tc.in)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveProfile_DeletesStateFile(t *testing.T) {
|
||||
withTestSM(t, func(sm *ServiceManager, username string) {
|
||||
created, err := sm.AddProfile("work", username)
|
||||
require.NoError(t, err)
|
||||
|
||||
configDir, err := sm.getConfigDir(username)
|
||||
require.NoError(t, err)
|
||||
statePath := filepath.Join(configDir, created.ID.String()+".state.json")
|
||||
require.NoError(t, os.WriteFile(statePath, []byte(`{"email":"a@b"}`), 0600))
|
||||
|
||||
require.NoError(t, sm.RemoveProfile(created.ID, username))
|
||||
_, err = os.Stat(statePath)
|
||||
assert.True(t, errors.Is(err, os.ErrNotExist), "state file should be removed")
|
||||
})
|
||||
}
|
||||
@@ -13,13 +13,20 @@ type ProfileState struct {
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
func (pm *ProfileManager) GetProfileState(profileName string) (*ProfileState, error) {
|
||||
// GetProfileState reads the per-profile state file keyed by profile ID.
|
||||
// The state file lives in the user's config directory. Legacy state files
|
||||
// keyed by the old profile name remain readable.
|
||||
func (pm *ProfileManager) GetProfileState(id ID) (*ProfileState, error) {
|
||||
configDir, err := getConfigDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get config directory: %w", err)
|
||||
}
|
||||
|
||||
stateFile := filepath.Join(configDir, profileName+".state.json")
|
||||
if id != defaultProfileName && !IsValidProfileFilenameStem(id) {
|
||||
return nil, fmt.Errorf("invalid profile ID: %q", id)
|
||||
}
|
||||
|
||||
stateFile := filepath.Join(configDir, id.String()+".state.json")
|
||||
stateFileExists, err := fileExists(stateFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if profile state file exists: %w", err)
|
||||
@@ -51,7 +58,12 @@ func (pm *ProfileManager) SetActiveProfileState(state *ProfileState) error {
|
||||
return fmt.Errorf("get active profile: %w", err)
|
||||
}
|
||||
|
||||
stateFile := filepath.Join(configDir, activeProf.Name+".state.json")
|
||||
id := activeProf.ID
|
||||
if id != defaultProfileName && !IsValidProfileFilenameStem(id) {
|
||||
return fmt.Errorf("invalid active profile ID: %q", id)
|
||||
}
|
||||
|
||||
stateFile := filepath.Join(configDir, id.String()+".state.json")
|
||||
err = util.WriteJsonWithRestrictedPermission(context.Background(), stateFile, state)
|
||||
if err != nil {
|
||||
return fmt.Errorf("write profile state: %w", err)
|
||||
|
||||
@@ -700,13 +700,6 @@ func resolveURLsToIPs(urls []string) []net.IP {
|
||||
|
||||
// updateRouteSelectorFromManagement updates the route selector based on the isSelected status from the management server
|
||||
func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HAMap) {
|
||||
// An explicit user "deselect all" must not be overridden by management auto-apply.
|
||||
// Auto-applying an exit node here would call SelectRoutes, which clears the
|
||||
// deselect-all flag and re-enables every route the user turned off.
|
||||
if m.routeSelector.IsDeselectAll() {
|
||||
return
|
||||
}
|
||||
|
||||
exitNodeInfo := m.collectExitNodeInfo(clientRoutes)
|
||||
if len(exitNodeInfo.allIDs) == 0 {
|
||||
return
|
||||
|
||||
@@ -1,71 +0,0 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
func exitNodeRoutes(netID route.NetID, skipAutoApply bool) route.HAMap {
|
||||
haID := route.HAUniqueID(string(netID) + "|0.0.0.0/0")
|
||||
return route.HAMap{
|
||||
haID: []*route.Route{
|
||||
{
|
||||
ID: "r-" + route.ID(netID),
|
||||
NetID: netID,
|
||||
Network: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
NetworkType: route.IPv4Network,
|
||||
Enabled: true,
|
||||
SkipAutoApply: skipAutoApply,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateRouteSelectorFromManagement(t *testing.T) {
|
||||
t.Run("management auto-apply selects exit node without user selection", func(t *testing.T) {
|
||||
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
|
||||
routes := exitNodeRoutes("exit1", false)
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
require.True(t, m.routeSelector.IsSelected("exit1"), "auto-apply exit node should be selected")
|
||||
require.Len(t, m.routeSelector.FilterSelectedExitNodes(routes), 1, "selected exit node should pass the filter")
|
||||
})
|
||||
|
||||
t.Run("management SkipAutoApply leaves exit node deselected", func(t *testing.T) {
|
||||
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
|
||||
routes := exitNodeRoutes("exit1", true)
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
require.False(t, m.routeSelector.IsSelected("exit1"), "SkipAutoApply exit node should not be selected")
|
||||
require.Empty(t, m.routeSelector.FilterSelectedExitNodes(routes), "deselected exit node should be filtered out")
|
||||
})
|
||||
|
||||
t.Run("user selection is not overridden by management", func(t *testing.T) {
|
||||
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
|
||||
require.NoError(t, m.routeSelector.SelectRoutes([]route.NetID{"exit1"}, true, []route.NetID{"exit1"}))
|
||||
routes := exitNodeRoutes("exit1", true)
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
require.True(t, m.routeSelector.IsSelected("exit1"), "explicit user selection must survive a management sync that wants to skip auto-apply")
|
||||
require.Len(t, m.routeSelector.FilterSelectedExitNodes(routes), 1, "user-selected exit node should pass the filter")
|
||||
})
|
||||
|
||||
t.Run("deselect-all is preserved across a management sync", func(t *testing.T) {
|
||||
m := &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
|
||||
m.routeSelector.DeselectAllRoutes()
|
||||
routes := exitNodeRoutes("exit1", false)
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
require.True(t, m.routeSelector.IsDeselectAll(), "an explicit deselect-all must not be cleared by management auto-apply")
|
||||
require.Empty(t, m.routeSelector.FilterSelectedExitNodes(routes), "no routes should be selected while deselect-all is set")
|
||||
})
|
||||
}
|
||||
@@ -116,14 +116,6 @@ func (rs *RouteSelector) DeselectAllRoutes() {
|
||||
clear(rs.selectedRoutes)
|
||||
}
|
||||
|
||||
// IsDeselectAll reports whether the user has explicitly deselected all routes.
|
||||
func (rs *RouteSelector) IsDeselectAll() bool {
|
||||
rs.mu.RLock()
|
||||
defer rs.mu.RUnlock()
|
||||
|
||||
return rs.deselectAll
|
||||
}
|
||||
|
||||
// IsSelected checks if a specific route is selected.
|
||||
func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
|
||||
rs.mu.RLock()
|
||||
|
||||
@@ -1,50 +0,0 @@
|
||||
//go:build windows || darwin
|
||||
|
||||
package mdm
|
||||
|
||||
import "strings"
|
||||
|
||||
// allKeys is the set of recognised MDM keys. Unknown keys in a managed
|
||||
// configuration are ignored but logged. Lives in this build-tagged file
|
||||
// (windows || darwin) because only desktop loaders need the
|
||||
// canonicalisation table that consumes it; including it unconditionally
|
||||
// would trigger the `unused` golangci-lint check on platforms that
|
||||
// don't import canonical_loaders.go.
|
||||
var allKeys = []string{
|
||||
KeyManagementURL,
|
||||
KeyDisableUpdateSettings,
|
||||
KeyDisableProfiles,
|
||||
KeyDisableNetworks,
|
||||
KeyDisableClientRoutes,
|
||||
KeyDisableServerRoutes,
|
||||
KeyBlockInbound,
|
||||
KeyDisableMetricsCollection,
|
||||
KeyAllowServerSSH,
|
||||
KeyDisableAutoConnect,
|
||||
KeyPreSharedKey,
|
||||
KeyRosenpassEnabled,
|
||||
KeyRosenpassPermissive,
|
||||
KeyWireguardPort,
|
||||
KeySplitTunnelMode,
|
||||
KeySplitTunnelApps,
|
||||
}
|
||||
|
||||
// canonicalKey maps the lowercase form of a managed-config value name to
|
||||
// its canonical mdm.Key* form. Admins commonly write PascalCase value
|
||||
// names in ADMX / Group Policy ("ManagementURL"); the iOS/AppConfig and
|
||||
// macOS plist conventions are camelCase ("managementURL"); both must
|
||||
// resolve to the same Policy lookup.
|
||||
//
|
||||
// Lives in a desktop-loader-only file (build tag `windows || darwin`)
|
||||
// because no other build path consumes it. Linux / FreeBSD / mobile
|
||||
// builds don't ship a platform loader that reads arbitrary-case key
|
||||
// names, so they don't need the canonicalisation table — and including
|
||||
// the var unconditionally would trigger the `unused` golangci-lint
|
||||
// check on those platforms.
|
||||
var canonicalKey = func() map[string]string {
|
||||
m := make(map[string]string, len(allKeys))
|
||||
for _, k := range allKeys {
|
||||
m[strings.ToLower(k)] = k
|
||||
}
|
||||
return m
|
||||
}()
|
||||
@@ -1,247 +0,0 @@
|
||||
// Package mdm reads MDM-managed configuration from platform-native sources
|
||||
// (plist on macOS, registry on Windows, UserDefaults on iOS,
|
||||
// RestrictionsManager on Android). The returned Policy is consumed by
|
||||
// profilemanager.Config.apply() as the highest-priority override layer.
|
||||
//
|
||||
// An empty Policy (no source present, or source present with zero keys)
|
||||
// means no MDM enforcement is active and the client behaves as if the
|
||||
// feature did not exist.
|
||||
package mdm
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strconv"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Well-known policy keys. Names mirror the corresponding ConfigInput Go field
|
||||
// names (lowerCamelCase) so the daemon can map a Policy key directly to a
|
||||
// configuration field.
|
||||
const (
|
||||
KeyManagementURL = "managementURL"
|
||||
KeyDisableUpdateSettings = "disableUpdateSettings"
|
||||
KeyDisableProfiles = "disableProfiles"
|
||||
KeyDisableNetworks = "disableNetworks"
|
||||
KeyDisableClientRoutes = "disableClientRoutes"
|
||||
KeyDisableServerRoutes = "disableServerRoutes"
|
||||
KeyBlockInbound = "blockInbound"
|
||||
KeyDisableMetricsCollection = "disableMetricsCollection"
|
||||
KeyAllowServerSSH = "allowServerSSH"
|
||||
KeyDisableAutoConnect = "disableAutoConnect"
|
||||
KeyPreSharedKey = "preSharedKey"
|
||||
KeyRosenpassEnabled = "rosenpassEnabled"
|
||||
KeyRosenpassPermissive = "rosenpassPermissive"
|
||||
KeyWireguardPort = "wireguardPort"
|
||||
|
||||
// Split tunnel is modeled as a single conceptual policy with two
|
||||
// registry/plist values. KeySplitTunnelMode is the discriminator
|
||||
// ("allow" or "disallow"); KeySplitTunnelApps is a comma-separated
|
||||
// list of package names. The values are mutually exclusive by
|
||||
// construction — only one mode can be set at a time.
|
||||
KeySplitTunnelMode = "splitTunnelMode"
|
||||
KeySplitTunnelApps = "splitTunnelApps"
|
||||
)
|
||||
|
||||
// Split-tunnel mode literals (KeySplitTunnelMode values).
|
||||
const (
|
||||
SplitTunnelModeAllow = "allow"
|
||||
SplitTunnelModeDisallow = "disallow"
|
||||
)
|
||||
|
||||
// SecretKeys lists keys whose values must be redacted in logs.
|
||||
var SecretKeys = map[string]struct{}{
|
||||
KeyPreSharedKey: {},
|
||||
}
|
||||
|
||||
// boolStringLiterals enumerates the textual boolean encodings the
|
||||
// platform loaders may produce (Windows REG_SZ "true", iOS / Android
|
||||
// managed-config booleans-as-strings, etc.). Lookup keeps GetBool flat
|
||||
// (no nested switch on the string case).
|
||||
var boolStringLiterals = map[string]bool{
|
||||
"true": true,
|
||||
"1": true,
|
||||
"yes": true,
|
||||
"false": false,
|
||||
"0": false,
|
||||
"no": false,
|
||||
}
|
||||
|
||||
|
||||
// Policy holds MDM-managed settings read from the platform source. A nil or
|
||||
// empty Policy means no enforcement is active.
|
||||
type Policy struct {
|
||||
values map[string]any
|
||||
}
|
||||
|
||||
// NewPolicy constructs a Policy from a key→value map. Pass nil or an
|
||||
// empty map to construct an empty (no-enforcement) Policy. The returned
|
||||
// *Policy is always non-nil.
|
||||
func NewPolicy(values map[string]any) *Policy {
|
||||
if values == nil {
|
||||
values = map[string]any{}
|
||||
}
|
||||
return &Policy{values: values}
|
||||
}
|
||||
|
||||
// LoadPolicy reads the platform-native MDM configuration. Returns an
|
||||
// empty (but non-nil) Policy when no source is present, the source is
|
||||
// empty, or the platform is unsupported.
|
||||
//
|
||||
// Diagnostic logging differentiates the three states:
|
||||
// - source absent / unsupported platform: trace log only
|
||||
// - source present, zero keys: info "MDM enrolled (no managed keys)"
|
||||
// - source present, N keys: info "MDM enrolled with N managed keys: [...]"
|
||||
func LoadPolicy() *Policy {
|
||||
values, err := loadPlatformPolicy()
|
||||
if err != nil {
|
||||
log.Tracef("MDM policy load: %v", err)
|
||||
return &Policy{values: map[string]any{}}
|
||||
}
|
||||
if values == nil {
|
||||
return &Policy{values: map[string]any{}}
|
||||
}
|
||||
if len(values) == 0 {
|
||||
log.Info("MDM enrolled (no managed keys)")
|
||||
} else {
|
||||
log.Infof("MDM enrolled with %d managed key(s): %v", len(values), sortedKeys(values))
|
||||
}
|
||||
return &Policy{values: values}
|
||||
}
|
||||
|
||||
// IsEmpty reports whether the Policy has no managed keys.
|
||||
func (p *Policy) IsEmpty() bool {
|
||||
return p == nil || len(p.values) == 0
|
||||
}
|
||||
|
||||
// HasKey reports whether the given key is MDM-managed.
|
||||
func (p *Policy) HasKey(key string) bool {
|
||||
if p == nil {
|
||||
return false
|
||||
}
|
||||
_, ok := p.values[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ManagedKeys returns the sorted list of managed key names. Returns an empty
|
||||
// slice (not nil) on an empty Policy.
|
||||
func (p *Policy) ManagedKeys() []string {
|
||||
if p == nil {
|
||||
return []string{}
|
||||
}
|
||||
return sortedKeys(p.values)
|
||||
}
|
||||
|
||||
// GetString returns the managed value for key coerced to string, and whether
|
||||
// the key was set. A non-string value returns ("", false).
|
||||
func (p *Policy) GetString(key string) (string, bool) {
|
||||
if p == nil {
|
||||
return "", false
|
||||
}
|
||||
v, ok := p.values[key]
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
s, ok := v.(string)
|
||||
if !ok || s == "" {
|
||||
return "", false
|
||||
}
|
||||
return s, true
|
||||
}
|
||||
|
||||
// GetBool returns the managed value for key coerced to bool, and whether the
|
||||
// key was set. Accepts native bool and string literals "true"/"false"/"1"/"0".
|
||||
func (p *Policy) GetBool(key string) (bool, bool) {
|
||||
if p == nil {
|
||||
return false, false
|
||||
}
|
||||
v, ok := p.values[key]
|
||||
if !ok {
|
||||
return false, false
|
||||
}
|
||||
switch t := v.(type) {
|
||||
case bool:
|
||||
return t, true
|
||||
case string:
|
||||
b, known := boolStringLiterals[t]
|
||||
return b, known
|
||||
case int:
|
||||
return t != 0, true
|
||||
case int64:
|
||||
return t != 0, true
|
||||
}
|
||||
return false, false
|
||||
}
|
||||
|
||||
// GetInt returns the managed value for key as int64, and whether the key
|
||||
// was set. Accepts native int / int64 (as produced by the Windows registry
|
||||
// loader for REG_DWORD/REG_QWORD) and numeric strings (decimal).
|
||||
func (p *Policy) GetInt(key string) (int64, bool) {
|
||||
if p == nil {
|
||||
return 0, false
|
||||
}
|
||||
v, ok := p.values[key]
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
switch t := v.(type) {
|
||||
case int64:
|
||||
return t, true
|
||||
case int:
|
||||
return int64(t), true
|
||||
case int32:
|
||||
return int64(t), true
|
||||
case uint64:
|
||||
return int64(t), true
|
||||
case float64:
|
||||
return int64(t), true
|
||||
case string:
|
||||
if n, err := strconv.ParseInt(t, 10, 64); err == nil {
|
||||
return n, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetStringSlice returns the managed value for key as []string, and whether
|
||||
// the key was set. Accepts []string, []any (of strings), and a single string
|
||||
// (treated as a one-element list).
|
||||
func (p *Policy) GetStringSlice(key string) ([]string, bool) {
|
||||
if p == nil {
|
||||
return nil, false
|
||||
}
|
||||
v, ok := p.values[key]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
switch t := v.(type) {
|
||||
case []string:
|
||||
return append([]string(nil), t...), true
|
||||
case []any:
|
||||
out := make([]string, 0, len(t))
|
||||
for _, item := range t {
|
||||
s, ok := item.(string)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
out = append(out, s)
|
||||
}
|
||||
return out, true
|
||||
case string:
|
||||
return []string{t}, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// sortedKeys returns the keys of m as a deterministic, lexicographically
|
||||
// sorted slice. Used internally by Policy.ManagedKeys and LoadPolicy's
|
||||
// diagnostic log line so callers see a stable key order across runs
|
||||
// regardless of Go's randomised map iteration.
|
||||
func sortedKeys(m map[string]any) []string {
|
||||
out := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
out = append(out, k)
|
||||
}
|
||||
sort.Strings(out)
|
||||
return out
|
||||
}
|
||||
@@ -1,90 +0,0 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package mdm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"howett.net/plist"
|
||||
)
|
||||
|
||||
// policyPlistPath is the well-known location where macOS writes the
|
||||
// device-level mandatory MDM payload for NetBird. The path is fixed by
|
||||
// Apple convention: when an MDM provider (Jamf / Kandji / Mosyle /
|
||||
// Intune for Mac / Workspace ONE) pushes a Configuration Profile that
|
||||
// contains a com.apple.ManagedClient.preferences payload targeting the
|
||||
// bundle id io.netbird.client, the OS materializes the payload here.
|
||||
//
|
||||
// Read-only — only the OS (root) is supposed to write this file. The
|
||||
// loader sanity-checks the file mode and refuses to honour a world-
|
||||
// writable plist, as a defense against tampered installs.
|
||||
const policyPlistPath = "/Library/Managed Preferences/io.netbird.client.plist"
|
||||
|
||||
// loadPlatformPolicy reads the MDM-managed configuration from the macOS
|
||||
// managed-preferences plist at policyPlistPath. Returns:
|
||||
// - (nil, nil) when the plist is absent (device not MDM-enrolled for
|
||||
// NetBird, or admin has not yet pushed a payload)
|
||||
// - (map, nil) with N entries when N managed values are present
|
||||
// (N may be 0 — empty plist still signals enrollment to the caller)
|
||||
// - (nil, err) on permission / parse / safety errors (including
|
||||
// refusal to read a world-writable plist)
|
||||
//
|
||||
// Top-level plist keys are canonicalised case-insensitively to the
|
||||
// package's internal mdm.Key* names; unknown keys are logged and
|
||||
// skipped so a stray entry in the payload does not block startup.
|
||||
// Native plist value types map naturally onto the Policy accessor
|
||||
// expectations (GetString / GetBool / GetInt / GetStringSlice).
|
||||
func loadPlatformPolicy() (map[string]any, error) {
|
||||
f, err := os.Open(policyPlistPath)
|
||||
if err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
// Not enrolled for NetBird. Caller treats nil as
|
||||
// "no MDM source present".
|
||||
//nolint:nilnil // (nil, nil) is the documented platform-absent sentinel; see LoadPolicy.
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("open %s: %w", policyPlistPath, err)
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := f.Close(); closeErr != nil {
|
||||
log.Warnf("MDM close plist %s: %v", policyPlistPath, closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
info, err := f.Stat()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("stat %s: %w", policyPlistPath, err)
|
||||
}
|
||||
// World-writable plist => tampered install. Refuse rather than
|
||||
// honour potentially attacker-controlled policy values.
|
||||
if info.Mode().Perm()&0o002 != 0 {
|
||||
return nil, fmt.Errorf("refusing to read world-writable MDM source %s (mode %o)",
|
||||
policyPlistPath, info.Mode().Perm())
|
||||
}
|
||||
|
||||
raw := make(map[string]any)
|
||||
if err := plist.NewDecoder(f).Decode(&raw); err != nil {
|
||||
return nil, fmt.Errorf("decode plist %s: %w", policyPlistPath, err)
|
||||
}
|
||||
|
||||
out := make(map[string]any, len(raw))
|
||||
for name, val := range raw {
|
||||
// macOS / AppConfig conventions both use camelCase for managed
|
||||
// preferences keys; canonicalize to the mdm.Key* form so a key
|
||||
// written as "ManagementURL" (PascalCase, rare on macOS but
|
||||
// possible if the admin reused an ADMX-style name) still
|
||||
// resolves.
|
||||
canonical, known := canonicalKey[strings.ToLower(name)]
|
||||
if !known {
|
||||
log.Warnf("MDM ignoring unknown plist key %s: %s", policyPlistPath, name)
|
||||
continue
|
||||
}
|
||||
out[canonical] = val
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
//go:build ios || android
|
||||
|
||||
package mdm
|
||||
|
||||
// loadPlatformPolicy is unused on mobile: the native layer (Swift on iOS,
|
||||
// Kotlin/Java on Android) reads the OS managed-config store and pushes the
|
||||
// resulting dictionary in-process via a gomobile entry point that lands in
|
||||
// Phase 5 / Phase 6. The stub keeps the package compilable for mobile
|
||||
// builds and returns (nil, nil) — the platform-absent sentinel that
|
||||
// LoadPolicy in policy.go treats as "no MDM source present".
|
||||
func loadPlatformPolicy() (map[string]any, error) {
|
||||
//nolint:nilnil // (nil, nil) is the documented platform-absent sentinel; see LoadPolicy.
|
||||
return nil, nil
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
//go:build !windows && !darwin && !ios && !android
|
||||
|
||||
package mdm
|
||||
|
||||
// loadPlatformPolicy returns no policy on platforms without an MDM channel
|
||||
// (Linux, FreeBSD). MDM enforcement is off and the client behaves as if
|
||||
// the feature did not exist. Returns (nil, nil) — the platform-absent
|
||||
// sentinel the caller (LoadPolicy in policy.go) treats as "no MDM
|
||||
// source present"; an error here would just translate to the same
|
||||
// outcome with an extra log line.
|
||||
func loadPlatformPolicy() (map[string]any, error) {
|
||||
//nolint:nilnil // (nil, nil) is the documented platform-absent sentinel; see LoadPolicy.
|
||||
return nil, nil
|
||||
}
|
||||
@@ -1,160 +0,0 @@
|
||||
package mdm
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPolicy_NilSafe(t *testing.T) {
|
||||
var p *Policy
|
||||
assert.True(t, p.IsEmpty())
|
||||
assert.False(t, p.HasKey(KeyManagementURL))
|
||||
assert.Empty(t, p.ManagedKeys())
|
||||
|
||||
_, ok := p.GetString(KeyManagementURL)
|
||||
assert.False(t, ok)
|
||||
_, ok = p.GetBool(KeyDisableProfiles)
|
||||
assert.False(t, ok)
|
||||
_, ok = p.GetStringSlice(KeySplitTunnelApps)
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestPolicy_Empty(t *testing.T) {
|
||||
p := NewPolicy(nil)
|
||||
require.NotNil(t, p)
|
||||
assert.True(t, p.IsEmpty())
|
||||
assert.False(t, p.HasKey(KeyManagementURL))
|
||||
assert.Empty(t, p.ManagedKeys())
|
||||
}
|
||||
|
||||
func TestPolicy_HasKey(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{
|
||||
KeyManagementURL: "https://corp.example.com",
|
||||
KeyDisableProfiles: true,
|
||||
})
|
||||
assert.False(t, p.IsEmpty())
|
||||
assert.True(t, p.HasKey(KeyManagementURL))
|
||||
assert.True(t, p.HasKey(KeyDisableProfiles))
|
||||
assert.False(t, p.HasKey(KeyPreSharedKey))
|
||||
}
|
||||
|
||||
func TestPolicy_ManagedKeysSorted(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{
|
||||
KeyDisableProfiles: true,
|
||||
KeyManagementURL: "https://x",
|
||||
KeyAllowServerSSH: false,
|
||||
})
|
||||
got := p.ManagedKeys()
|
||||
assert.Equal(t, []string{KeyAllowServerSSH, KeyDisableProfiles, KeyManagementURL}, got)
|
||||
}
|
||||
|
||||
func TestPolicy_GetString(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{
|
||||
KeyManagementURL: "https://corp.example.com",
|
||||
KeyDisableProfiles: true, // wrong type for GetString
|
||||
KeyPreSharedKey: "", // empty rejected
|
||||
})
|
||||
v, ok := p.GetString(KeyManagementURL)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "https://corp.example.com", v)
|
||||
|
||||
_, ok = p.GetString(KeyDisableProfiles)
|
||||
assert.False(t, ok, "non-string value must not be reported as string")
|
||||
|
||||
_, ok = p.GetString(KeyPreSharedKey)
|
||||
assert.False(t, ok, "empty string treated as unset")
|
||||
|
||||
_, ok = p.GetString("nonexistent")
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestPolicy_GetBool(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
raw any
|
||||
want bool
|
||||
ok bool
|
||||
}{
|
||||
{"native true", true, true, true},
|
||||
{"native false", false, false, true},
|
||||
{"string true", "true", true, true},
|
||||
{"string false", "false", false, true},
|
||||
{"string 1", "1", true, true},
|
||||
{"string 0", "0", false, true},
|
||||
{"string yes", "yes", true, true},
|
||||
{"string no", "no", false, true},
|
||||
{"int nonzero", 1, true, true},
|
||||
{"int zero", 0, false, true},
|
||||
{"int64 nonzero", int64(2), true, true},
|
||||
{"int64 zero", int64(0), false, true},
|
||||
{"string garbage", "maybe", false, false},
|
||||
{"float unsupported", 1.0, false, false},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{KeyDisableProfiles: c.raw})
|
||||
got, ok := p.GetBool(KeyDisableProfiles)
|
||||
assert.Equal(t, c.ok, ok)
|
||||
if c.ok {
|
||||
assert.Equal(t, c.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
_, ok := NewPolicy(nil).GetBool(KeyDisableProfiles)
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestPolicy_GetStringSlice(t *testing.T) {
|
||||
t.Run("native string slice", func(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{
|
||||
KeySplitTunnelApps: []string{"com.a", "com.b"},
|
||||
})
|
||||
got, ok := p.GetStringSlice(KeySplitTunnelApps)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, []string{"com.a", "com.b"}, got)
|
||||
})
|
||||
|
||||
t.Run("any slice of strings", func(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{
|
||||
KeySplitTunnelApps: []any{"com.a", "com.b"},
|
||||
})
|
||||
got, ok := p.GetStringSlice(KeySplitTunnelApps)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, []string{"com.a", "com.b"}, got)
|
||||
})
|
||||
|
||||
t.Run("single string lifts to one-element slice", func(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{
|
||||
KeySplitTunnelApps: "com.a",
|
||||
})
|
||||
got, ok := p.GetStringSlice(KeySplitTunnelApps)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, []string{"com.a"}, got)
|
||||
})
|
||||
|
||||
t.Run("mixed any slice rejected", func(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{
|
||||
KeySplitTunnelApps: []any{"com.a", 1},
|
||||
})
|
||||
_, ok := p.GetStringSlice(KeySplitTunnelApps)
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("missing key", func(t *testing.T) {
|
||||
p := NewPolicy(nil)
|
||||
_, ok := p.GetStringSlice(KeySplitTunnelApps)
|
||||
assert.False(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoadPolicy_PlatformStubReturnsEmpty(t *testing.T) {
|
||||
// loadPlatformPolicy is a stub on every OS for Phase 1. LoadPolicy must
|
||||
// degrade gracefully and never return nil.
|
||||
p := LoadPolicy()
|
||||
require.NotNil(t, p)
|
||||
assert.True(t, p.IsEmpty())
|
||||
assert.Empty(t, p.ManagedKeys())
|
||||
}
|
||||
@@ -1,108 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package mdm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
)
|
||||
|
||||
// policyRegistryPath is the well-known MDM policy registry key for NetBird.
|
||||
// Admins push values here through Group Policy, Intune ADMX ingestion, an
|
||||
// Intune custom Registry CSP profile, or `reg add` during MSI deployment.
|
||||
// Listed in the project's docs/mdm/netbird.admx schema.
|
||||
const policyRegistryPath = `Software\Policies\NetBird`
|
||||
|
||||
// readRegistryValue reads a single value under policyRegistryPath and,
|
||||
// on success, stores the type-coerced result in out[canonical]. Type
|
||||
// coercion mirrors loadPlatformPolicy's documented mapping:
|
||||
// - REG_SZ / REG_EXPAND_SZ -> string (REG_EXPAND_SZ is expanded by the API)
|
||||
// - REG_DWORD / REG_QWORD -> int64
|
||||
// - REG_MULTI_SZ -> []string
|
||||
//
|
||||
// Unsupported value types and per-value read failures are logged at
|
||||
// warn level and skipped — one malformed value must not block the
|
||||
// surrounding loop. Extracted from loadPlatformPolicy to keep that
|
||||
// function's cognitive complexity in check.
|
||||
func readRegistryValue(k registry.Key, name, canonical string, out map[string]any) {
|
||||
_, valType, err := k.GetValue(name, nil)
|
||||
if err != nil {
|
||||
log.Warnf("MDM stat %s\\%s: %v", policyRegistryPath, name, err)
|
||||
return
|
||||
}
|
||||
switch valType {
|
||||
case registry.SZ, registry.EXPAND_SZ:
|
||||
if v, _, err := k.GetStringValue(name); err == nil {
|
||||
out[canonical] = v
|
||||
} else {
|
||||
log.Warnf("MDM read string %s\\%s: %v", policyRegistryPath, name, err)
|
||||
}
|
||||
case registry.DWORD, registry.QWORD:
|
||||
if v, _, err := k.GetIntegerValue(name); err == nil {
|
||||
// uint64 from the registry API; Policy.GetBool / GetInt
|
||||
// helpers consume int64, so narrow safely.
|
||||
out[canonical] = int64(v)
|
||||
} else {
|
||||
log.Warnf("MDM read int %s\\%s: %v", policyRegistryPath, name, err)
|
||||
}
|
||||
case registry.MULTI_SZ:
|
||||
if v, _, err := k.GetStringsValue(name); err == nil {
|
||||
out[canonical] = v
|
||||
} else {
|
||||
log.Warnf("MDM read multi-string %s\\%s: %v", policyRegistryPath, name, err)
|
||||
}
|
||||
default:
|
||||
log.Warnf("MDM ignoring unsupported registry value type %d at %s\\%s",
|
||||
valType, policyRegistryPath, name)
|
||||
}
|
||||
}
|
||||
|
||||
// loadPlatformPolicy reads the MDM-managed configuration from the
|
||||
// Windows registry under HKLM\Software\Policies\NetBird. Returns:
|
||||
// - (nil, nil) when the key is absent (device not MDM-enrolled for NetBird)
|
||||
// - (map, nil) with N entries when N managed values are set (N may be 0)
|
||||
// - (nil, err) on open / enumerate registry errors
|
||||
//
|
||||
// Per-value type coercion + skip-on-error is delegated to
|
||||
// readRegistryValue. Unknown value names are logged and skipped so a
|
||||
// malformed deployment does not block startup.
|
||||
func loadPlatformPolicy() (map[string]any, error) {
|
||||
k, err := registry.OpenKey(registry.LOCAL_MACHINE, policyRegistryPath, registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
if errors.Is(err, registry.ErrNotExist) {
|
||||
// Not enrolled. Caller treats nil as "no MDM source present".
|
||||
//nolint:nilnil // (nil, nil) is the documented platform-absent sentinel; see LoadPolicy.
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("open %s: %w", policyRegistryPath, err)
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := k.Close(); closeErr != nil {
|
||||
log.Warnf("MDM close registry key %s: %v", policyRegistryPath, closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
names, err := k.ReadValueNames(-1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("enumerate values of %s: %w", policyRegistryPath, err)
|
||||
}
|
||||
|
||||
out := make(map[string]any, len(names))
|
||||
for _, name := range names {
|
||||
// Canonicalize the registry value name against the known MDM key
|
||||
// set so Policy.HasKey lookups (which use the canonical names)
|
||||
// succeed regardless of the casing used by the admin's ADMX or
|
||||
// `reg add` command.
|
||||
canonical, known := canonicalKey[strings.ToLower(name)]
|
||||
if !known {
|
||||
log.Warnf("MDM ignoring unknown registry value %s\\%s", policyRegistryPath, name)
|
||||
continue
|
||||
}
|
||||
readRegistryValue(k, name, canonical, out)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
@@ -1,129 +0,0 @@
|
||||
package mdm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// DefaultReloadInterval is the production cadence at which the desktop daemon
|
||||
// re-reads the OS-native MDM policy. Picked to balance responsiveness against
|
||||
// registry/plist I/O overhead. Mobile builds use OS-side notifications
|
||||
// instead, hence anticipating the ticker mechanism entirely.
|
||||
const DefaultReloadInterval = 1 * time.Minute
|
||||
|
||||
// policyLoader is the indirection through which the ticker reads the
|
||||
// OS-native policy, both for the initial observation and on every tick.
|
||||
// Production points it at LoadPolicy; tests in this package override it to
|
||||
// feed a scripted sequence of policies without touching the real OS store.
|
||||
var policyLoader = LoadPolicy
|
||||
|
||||
// Ticker periodically re-reads the OS-native MDM policy via LoadPolicy and
|
||||
// invokes the onChange callback (supplied to Run) whenever the observed
|
||||
// Policy diverges from the last observation (added / removed / changed
|
||||
// keys). Launch with Run from a goroutine; cancel the supplied context
|
||||
// to stop.
|
||||
type Ticker struct {
|
||||
interval time.Duration
|
||||
prev *Policy
|
||||
}
|
||||
|
||||
// NewTicker constructs a Ticker that will re-read the OS-native policy
|
||||
// every reloadInterval once Run is called.
|
||||
// The initial snapshot is populated by calling policyLoader at
|
||||
// construction time so the first tick only fires
|
||||
// onChange when the policy actually changed since boot — without
|
||||
// this baseline the first tick would report every currently-managed
|
||||
// key as "added" and trigger a spurious engine restart.
|
||||
func NewTicker(reloadInterval time.Duration) *Ticker {
|
||||
return &Ticker{
|
||||
interval: reloadInterval,
|
||||
prev: policyLoader(),
|
||||
}
|
||||
}
|
||||
|
||||
// Run blocks until ctx is cancelled, polling the OS-native policy store at
|
||||
// the configured cadence and emitting log lines + onChange callback on
|
||||
// every observed diff. onChange must be non-nil.
|
||||
func (t *Ticker) Run(ctx context.Context, onChange func(prev, curr *Policy) error) {
|
||||
tk := time.NewTicker(t.interval)
|
||||
defer tk.Stop()
|
||||
log.Infof("MDM policy reload ticker started (interval=%s)", t.interval)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Info("MDM policy reload ticker stopped")
|
||||
return
|
||||
case <-tk.C:
|
||||
curr := policyLoader()
|
||||
if policiesEqual(t.prev, curr) {
|
||||
continue
|
||||
}
|
||||
added, removed, changed := diffPolicies(t.prev, curr)
|
||||
log.Infof("MDM policy changed: added=%v removed=%v changed=%v",
|
||||
added, removed, changed)
|
||||
prev := t.prev
|
||||
if err := onChange(prev, curr); err != nil {
|
||||
log.Errorf("MDM policy change handler failed (retrying in 1 minute): %v", err)
|
||||
continue
|
||||
}
|
||||
t.prev = curr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// policiesEqual reports whether two Policy instances carry the same
|
||||
// managed key set with identical values. Nil and empty policies
|
||||
// compare equal; one-nil/one-non-empty compare not equal; otherwise
|
||||
// the underlying values maps are compared with reflect.DeepEqual.
|
||||
func policiesEqual(a, b *Policy) bool {
|
||||
if a.IsEmpty() && b.IsEmpty() {
|
||||
return true
|
||||
}
|
||||
if a == nil || b == nil {
|
||||
return false
|
||||
}
|
||||
return reflect.DeepEqual(a.values, b.values)
|
||||
}
|
||||
|
||||
// diffPolicies returns the keys added in curr, removed from prev, and
|
||||
// whose values changed between prev and curr. Each slice is sorted
|
||||
// lexicographically for stable log output; value differences are
|
||||
// determined with reflect.DeepEqual.
|
||||
func diffPolicies(prev, curr *Policy) (added, removed, changed []string) {
|
||||
prevKVs := mapOf(prev)
|
||||
currKVs := mapOf(curr)
|
||||
for k := range currKVs {
|
||||
if _, ok := prevKVs[k]; !ok {
|
||||
added = append(added, k)
|
||||
} else if !reflect.DeepEqual(prevKVs[k], currKVs[k]) {
|
||||
changed = append(changed, k)
|
||||
}
|
||||
}
|
||||
for k := range prevKVs {
|
||||
if _, ok := currKVs[k]; !ok {
|
||||
removed = append(removed, k)
|
||||
}
|
||||
}
|
||||
sort.Strings(added)
|
||||
sort.Strings(removed)
|
||||
sort.Strings(changed)
|
||||
return added, removed, changed
|
||||
}
|
||||
|
||||
// mapOf returns a (possibly empty, never nil) copy of the underlying
|
||||
// values map of a Policy so callers outside this package can compare
|
||||
// keys/values across the type boundary. Returns an empty map on nil p.
|
||||
func mapOf(p *Policy) map[string]any {
|
||||
if p == nil {
|
||||
return map[string]any{}
|
||||
}
|
||||
out := make(map[string]any, len(p.values))
|
||||
for k, v := range p.values {
|
||||
out[k] = v
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -1,100 +0,0 @@
|
||||
package mdm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// testReloadInterval for speeding up the ticker cadence under `go test`
|
||||
const testReloadInterval = 1 * time.Second
|
||||
|
||||
// withPolicyLoader overrides the package-level policyLoader for the duration
|
||||
// of the test so the ticker observes a scripted policy instead of the real
|
||||
// OS-native store. The original loader is restored on cleanup.
|
||||
func withPolicyLoader(t *testing.T, fn func() *Policy) {
|
||||
t.Helper()
|
||||
prev := policyLoader
|
||||
policyLoader = fn
|
||||
t.Cleanup(func() { policyLoader = prev })
|
||||
}
|
||||
|
||||
func TestTicker_FiresOnChangeWithDelta(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
current := NewPolicy(nil) // initial observation: empty (no enforcement)
|
||||
withPolicyLoader(t, func() *Policy {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return current
|
||||
})
|
||||
|
||||
type change struct{ prev, curr *Policy }
|
||||
changes := make(chan change, 1)
|
||||
tk := NewTicker(testReloadInterval)
|
||||
require.Equal(t, testReloadInterval, tk.interval)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
tk.Run(ctx, func(prev, curr *Policy) error {
|
||||
select {
|
||||
case changes <- change{prev, curr}:
|
||||
default:
|
||||
}
|
||||
return nil
|
||||
})
|
||||
close(done)
|
||||
}()
|
||||
// Stop Run and wait for it to exit before returning, so the policyLoader
|
||||
// restore in t.Cleanup can't race the ticker goroutine still reading it.
|
||||
defer func() { cancel(); <-done }()
|
||||
|
||||
// Flip the OS-observed policy from empty to one managed key. The next
|
||||
// tick must detect the diff and invoke onChange.
|
||||
mu.Lock()
|
||||
current = NewPolicy(map[string]any{KeyManagementURL: "https://mdm.example.com:443"})
|
||||
mu.Unlock()
|
||||
|
||||
select {
|
||||
case c := <-changes:
|
||||
assert.True(t, c.prev.IsEmpty(), "prev should be the initial empty policy")
|
||||
assert.True(t, c.curr.HasKey(KeyManagementURL), "curr should carry the newly-pushed managed key")
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("onChange not invoked within 5s; ticker should fire every 1s under test")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTicker_NoCallbackWhenPolicyUnchanged(t *testing.T) {
|
||||
withPolicyLoader(t, func() *Policy {
|
||||
return NewPolicy(map[string]any{KeyBlockInbound: true})
|
||||
})
|
||||
|
||||
fired := make(chan struct{}, 1)
|
||||
tk := NewTicker(testReloadInterval)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
tk.Run(ctx, func(_, _ *Policy) error {
|
||||
select {
|
||||
case fired <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
return nil
|
||||
})
|
||||
close(done)
|
||||
}()
|
||||
defer func() { cancel(); <-done }()
|
||||
|
||||
// Over ~2 ticks at the 1s test cadence the policy never changes, so the
|
||||
// diff guard must suppress the callback entirely.
|
||||
select {
|
||||
case <-fired:
|
||||
t.Fatal("onChange fired despite an unchanged policy")
|
||||
case <-time.After(2500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -85,6 +85,8 @@ service DaemonService {
|
||||
|
||||
rpc AddProfile(AddProfileRequest) returns (AddProfileResponse) {}
|
||||
|
||||
rpc RenameProfile(RenameProfileRequest) returns (RenameProfileResponse) {}
|
||||
|
||||
rpc RemoveProfile(RemoveProfileRequest) returns (RemoveProfileResponse) {}
|
||||
|
||||
rpc ListProfiles(ListProfilesRequest) returns (ListProfilesResponse) {}
|
||||
@@ -314,13 +316,6 @@ message GetConfigResponse {
|
||||
int32 sshJWTCacheTTL = 26;
|
||||
|
||||
bool disable_ipv6 = 27;
|
||||
|
||||
// mDMManagedFields lists the names of configuration keys whose value is
|
||||
// currently enforced by an MDM policy. Names match mdm.Key* constants
|
||||
// (e.g. "managementURL", "disableClientRoutes"). UI/CLI clients should
|
||||
// render the corresponding inputs as read-only and display a "managed
|
||||
// by MDM" indicator.
|
||||
repeated string mDMManagedFields = 28;
|
||||
}
|
||||
|
||||
// PeerState contains the latest state of a peer
|
||||
@@ -622,11 +617,18 @@ message GetEventsResponse {
|
||||
}
|
||||
|
||||
message SwitchProfileRequest {
|
||||
// profileName is treated as a handle: exact ID, unique ID prefix, or
|
||||
// unique display name. The daemon resolves it server-side.
|
||||
optional string profileName = 1;
|
||||
optional string username = 2;
|
||||
}
|
||||
|
||||
message SwitchProfileResponse {}
|
||||
message SwitchProfileResponse {
|
||||
// id is the resolved on-disk ID of the profile that became active.
|
||||
// Lets CLI clients update their local active-profile state without
|
||||
// duplicating the resolution logic.
|
||||
string id = 1;
|
||||
}
|
||||
|
||||
message SetConfigRequest {
|
||||
string username = 1;
|
||||
@@ -693,17 +695,42 @@ message SetConfigResponse{}
|
||||
|
||||
message AddProfileRequest {
|
||||
string username = 1;
|
||||
// profileName carries the human-readable display name for the new
|
||||
// profile. The on-disk filename is a separately-generated ID.
|
||||
string profileName = 2;
|
||||
}
|
||||
|
||||
message AddProfileResponse {}
|
||||
message AddProfileResponse {
|
||||
// id is the generated on-disk ID of the new profile. CLI clients
|
||||
// display a truncated form, UI clients can ignore it.
|
||||
string id = 1;
|
||||
}
|
||||
|
||||
message RenameProfileRequest {
|
||||
string username = 1;
|
||||
// handle: an exact ID, a unique ID prefix, or a unique display name.
|
||||
string handle = 2;
|
||||
// newProfileName is the new human-readable display name for the profile.
|
||||
string newProfileName = 3;
|
||||
}
|
||||
|
||||
message RenameProfileResponse {
|
||||
// confirm the old profile name after resolving handle.
|
||||
string oldProfileName = 1;
|
||||
}
|
||||
|
||||
message RemoveProfileRequest {
|
||||
string username = 1;
|
||||
// profileName is treated as a handle: an exact ID, a unique ID
|
||||
// prefix, or a unique display name. Resolution happens server-side.
|
||||
string profileName = 2;
|
||||
}
|
||||
|
||||
message RemoveProfileResponse {}
|
||||
message RemoveProfileResponse {
|
||||
// id is the full resolved ID of the removed profile, so callers can
|
||||
// confirm exactly which profile a name/prefix handle resolved to.
|
||||
string id = 1;
|
||||
}
|
||||
|
||||
message ListProfilesRequest {
|
||||
string username = 1;
|
||||
@@ -716,6 +743,7 @@ message ListProfilesResponse {
|
||||
message Profile {
|
||||
string name = 1;
|
||||
bool is_active = 2;
|
||||
string id = 3;
|
||||
}
|
||||
|
||||
message GetActiveProfileRequest {}
|
||||
@@ -723,6 +751,7 @@ message GetActiveProfileRequest {}
|
||||
message GetActiveProfileResponse {
|
||||
string profileName = 1;
|
||||
string username = 2;
|
||||
string id = 3;
|
||||
}
|
||||
|
||||
message LogoutRequest {
|
||||
@@ -740,15 +769,6 @@ message GetFeaturesResponse{
|
||||
bool disable_networks = 3;
|
||||
}
|
||||
|
||||
// MDMManagedFieldsViolation is attached as a gRPC error detail on a
|
||||
// FailedPrecondition status returned from SetConfig (and similar mutating
|
||||
// RPCs) when the caller tries to modify one or more MDM-enforced fields.
|
||||
// The fields list contains the offending key names; the entire request is
|
||||
// rejected (no partial apply).
|
||||
message MDMManagedFieldsViolation {
|
||||
repeated string fields = 1;
|
||||
}
|
||||
|
||||
message TriggerUpdateRequest {}
|
||||
|
||||
message TriggerUpdateResponse {
|
||||
|
||||
@@ -45,6 +45,7 @@ const (
|
||||
DaemonService_SwitchProfile_FullMethodName = "/daemon.DaemonService/SwitchProfile"
|
||||
DaemonService_SetConfig_FullMethodName = "/daemon.DaemonService/SetConfig"
|
||||
DaemonService_AddProfile_FullMethodName = "/daemon.DaemonService/AddProfile"
|
||||
DaemonService_RenameProfile_FullMethodName = "/daemon.DaemonService/RenameProfile"
|
||||
DaemonService_RemoveProfile_FullMethodName = "/daemon.DaemonService/RemoveProfile"
|
||||
DaemonService_ListProfiles_FullMethodName = "/daemon.DaemonService/ListProfiles"
|
||||
DaemonService_GetActiveProfile_FullMethodName = "/daemon.DaemonService/GetActiveProfile"
|
||||
@@ -112,6 +113,7 @@ type DaemonServiceClient interface {
|
||||
SwitchProfile(ctx context.Context, in *SwitchProfileRequest, opts ...grpc.CallOption) (*SwitchProfileResponse, error)
|
||||
SetConfig(ctx context.Context, in *SetConfigRequest, opts ...grpc.CallOption) (*SetConfigResponse, error)
|
||||
AddProfile(ctx context.Context, in *AddProfileRequest, opts ...grpc.CallOption) (*AddProfileResponse, error)
|
||||
RenameProfile(ctx context.Context, in *RenameProfileRequest, opts ...grpc.CallOption) (*RenameProfileResponse, error)
|
||||
RemoveProfile(ctx context.Context, in *RemoveProfileRequest, opts ...grpc.CallOption) (*RemoveProfileResponse, error)
|
||||
ListProfiles(ctx context.Context, in *ListProfilesRequest, opts ...grpc.CallOption) (*ListProfilesResponse, error)
|
||||
GetActiveProfile(ctx context.Context, in *GetActiveProfileRequest, opts ...grpc.CallOption) (*GetActiveProfileResponse, error)
|
||||
@@ -422,6 +424,16 @@ func (c *daemonServiceClient) AddProfile(ctx context.Context, in *AddProfileRequ
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) RenameProfile(ctx context.Context, in *RenameProfileRequest, opts ...grpc.CallOption) (*RenameProfileResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(RenameProfileResponse)
|
||||
err := c.cc.Invoke(ctx, DaemonService_RenameProfile_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) RemoveProfile(ctx context.Context, in *RemoveProfileRequest, opts ...grpc.CallOption) (*RemoveProfileResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(RemoveProfileResponse)
|
||||
@@ -613,6 +625,7 @@ type DaemonServiceServer interface {
|
||||
SwitchProfile(context.Context, *SwitchProfileRequest) (*SwitchProfileResponse, error)
|
||||
SetConfig(context.Context, *SetConfigRequest) (*SetConfigResponse, error)
|
||||
AddProfile(context.Context, *AddProfileRequest) (*AddProfileResponse, error)
|
||||
RenameProfile(context.Context, *RenameProfileRequest) (*RenameProfileResponse, error)
|
||||
RemoveProfile(context.Context, *RemoveProfileRequest) (*RemoveProfileResponse, error)
|
||||
ListProfiles(context.Context, *ListProfilesRequest) (*ListProfilesResponse, error)
|
||||
GetActiveProfile(context.Context, *GetActiveProfileRequest) (*GetActiveProfileResponse, error)
|
||||
@@ -723,6 +736,9 @@ func (UnimplementedDaemonServiceServer) SetConfig(context.Context, *SetConfigReq
|
||||
func (UnimplementedDaemonServiceServer) AddProfile(context.Context, *AddProfileRequest) (*AddProfileResponse, error) {
|
||||
return nil, status.Error(codes.Unimplemented, "method AddProfile not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) RenameProfile(context.Context, *RenameProfileRequest) (*RenameProfileResponse, error) {
|
||||
return nil, status.Error(codes.Unimplemented, "method RenameProfile not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) RemoveProfile(context.Context, *RemoveProfileRequest) (*RemoveProfileResponse, error) {
|
||||
return nil, status.Error(codes.Unimplemented, "method RemoveProfile not implemented")
|
||||
}
|
||||
@@ -1237,6 +1253,24 @@ func _DaemonService_AddProfile_Handler(srv interface{}, ctx context.Context, dec
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _DaemonService_RenameProfile_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(RenameProfileRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DaemonServiceServer).RenameProfile(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: DaemonService_RenameProfile_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).RenameProfile(ctx, req.(*RenameProfileRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _DaemonService_RemoveProfile_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(RemoveProfileRequest)
|
||||
if err := dec(in); err != nil {
|
||||
@@ -1567,6 +1601,10 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
|
||||
MethodName: "AddProfile",
|
||||
Handler: _DaemonService_AddProfile_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "RenameProfile",
|
||||
Handler: _DaemonService_RenameProfile_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "RemoveProfile",
|
||||
Handler: _DaemonService_RemoveProfile_Handler,
|
||||
|
||||
@@ -79,7 +79,7 @@ func TestPersistLoginOverrides(t *testing.T) {
|
||||
_, err := profilemanager.UpdateOrCreateConfig(seed)
|
||||
require.NoError(t, err, "seed config")
|
||||
|
||||
activeProf := &profilemanager.ActiveProfileState{Name: "default"}
|
||||
activeProf := &profilemanager.ActiveProfileState{ID: "default"}
|
||||
err = persistLoginOverrides(activeProf, tt.newMgmtURL, tt.newPSK)
|
||||
require.NoError(t, err, "persistLoginOverrides")
|
||||
|
||||
|
||||
@@ -1,419 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/mdm"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
// preSharedKeyRedactedSentinel is the value GetConfig returns in place
|
||||
// of an actual PSK, so a UI that round-trips the field back to the
|
||||
// daemon (via SetConfig / Login) can be distinguished from a deliberate
|
||||
// override. Any incoming PSK that equals this sentinel is treated as
|
||||
// a no-op echo, never as a conflict with the policy.
|
||||
const preSharedKeyRedactedSentinel = "**********"
|
||||
|
||||
// loadMDMPolicy is the indirection used by server handlers to read the
|
||||
// active MDM policy. Tests override this to inject a fake policy.
|
||||
var loadMDMPolicy = mdm.LoadPolicy
|
||||
|
||||
// conflictCheck is a value-aware comparison between a single field in
|
||||
// the incoming request and the corresponding MDM-enforced value. It
|
||||
// runs only when the field was actually set in the request (presence
|
||||
// already filtered upstream); ok=true reports the policy value, ok=false
|
||||
// means the policy is silent on the key — both are treated as conflicts
|
||||
// to be safe (an MDM key declared as managed must hold a value).
|
||||
type conflictCheck struct {
|
||||
key string
|
||||
check func(*mdm.Policy) (match bool)
|
||||
}
|
||||
|
||||
// onMDMPolicyChange is invoked by the MDM reload ticker every time the
|
||||
// OS-native managed-config store reports a diff vs the last observation.
|
||||
//
|
||||
// Restart sequence:
|
||||
// 1. Cancel the active engine context (terminates connectWithRetryRuns).
|
||||
// 2. Wait briefly for that goroutine to exit (giveUpChan is closed on exit).
|
||||
// 3. Re-resolve Config from disk + MDM policy (Config.apply re-runs
|
||||
// applyMDMPolicy with the freshly loaded Policy).
|
||||
// 4. Spawn a fresh connectWithRetryRuns with the new context and config.
|
||||
// 5. Broadcast a SystemEvent so any GUI / CLI subscriber (SubscribeEvents
|
||||
// RPC) can refresh its cached config view without polling.
|
||||
//
|
||||
// The callback runs in the ticker's own goroutine. Ticker has already
|
||||
// logged the per-key diff before invoking this hook.
|
||||
func (s *Server) onMDMPolicyChange(_, _ *mdm.Policy) error {
|
||||
log.Warn("MDM policy changed; restarting engine to apply new configuration")
|
||||
|
||||
// Hold s.mutex for the entire restart sequence (cancel + quiescence
|
||||
// wait + re-spawn). Any concurrent Up/Down/Status arriving while
|
||||
// MDM is restarting blocks on the Lock until we are done — they
|
||||
// then observe the post-restart state coherently. This is safe
|
||||
// because the connectWithRetryRuns goroutine no longer acquires
|
||||
// s.mutex in its defer (intent vs. goroutine-alive concerns are
|
||||
// fully separated; see the connectionGoroutineRunning helper).
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if !s.clientRunning {
|
||||
// The client is not running, so there's no engine to restart.
|
||||
return nil
|
||||
}
|
||||
if s.actCancel != nil {
|
||||
s.actCancel()
|
||||
}
|
||||
|
||||
// Wait for previous connectWithRetryRuns to exit so we don't end up
|
||||
// with two goroutines fighting over the same status recorder + engine.
|
||||
// The teardown engages a fan-out of engine goroutines (peer workers,
|
||||
// signal handler, route manager, ...). close(clientGiveUpChan)
|
||||
// happens in the function-scope defer of connectWithRetryRuns, on
|
||||
// every exit path (ctx cancel, backoff exhausted, panic) — see the
|
||||
// defer in server.go.
|
||||
if s.clientGiveUpChan != nil {
|
||||
select {
|
||||
case <-s.clientGiveUpChan:
|
||||
case <-time.After(10 * time.Second):
|
||||
return fmt.Errorf("failed to restart the engine due to timeout")
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.restartEngineForMDMLocked(); err != nil {
|
||||
log.Errorf("MDM restart failed: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// publishConfigChangedEvent has already fired inside
|
||||
// restartEngineForMDMLocked with source="mdm". Emit an MDM-specific
|
||||
// user-visible toast so the operator knows their IT policy was
|
||||
// applied (UserMessage != "" triggers the GUI notifier).
|
||||
s.statusRecorder.PublishEvent(
|
||||
proto.SystemEvent_INFO,
|
||||
proto.SystemEvent_SYSTEM,
|
||||
"MDM policy applied",
|
||||
"NetBird configuration was updated by your IT policy.",
|
||||
map[string]string{"source": "mdm", "type": "policy_applied"},
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
// publishConfigChangedEvent broadcasts a SystemEvent informing any active
|
||||
// SubscribeEvents subscriber (typically the GUI tray) that the daemon's
|
||||
// effective Config has been replaced and any cached client-side view
|
||||
// should be refreshed. Callers pass a stable `source` label so the GUI
|
||||
// can distinguish a startup spawn from a user-triggered Up or an
|
||||
// MDM-driven restart. Reusing the SYSTEM category keeps the proto enum
|
||||
// stable; metadata.type="config_changed" routes to the GUI's refresh
|
||||
// handler. UserMessage is left empty so the system tray does not toast
|
||||
// for every internal restart; the MDM path emits a separate
|
||||
// "policy_applied" event (with UserMessage) for that purpose.
|
||||
func (s *Server) publishConfigChangedEvent(source string) {
|
||||
if s.statusRecorder == nil {
|
||||
return
|
||||
}
|
||||
s.statusRecorder.PublishEvent(
|
||||
proto.SystemEvent_INFO,
|
||||
proto.SystemEvent_SYSTEM,
|
||||
fmt.Sprintf("daemon config changed (source=%s)", source),
|
||||
"",
|
||||
map[string]string{
|
||||
"source": source,
|
||||
"type": "config_changed",
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// restartEngineForMDMLocked re-resolves the active profile config
|
||||
// (re-running applyMDMPolicy via Config.apply) and re-spawns
|
||||
// connectWithRetryRuns. Mirrors the tail of Server.Start so a runtime
|
||||
// MDM change behaves identically to a fresh boot under the new policy.
|
||||
//
|
||||
// MUST be called with s.mutex held — onMDMPolicyChange holds the lock
|
||||
// for the entire restart sequence (cancel + quiescence wait + re-spawn)
|
||||
// so concurrent Up/Down/Status RPCs observe a coherent post-restart
|
||||
// state.
|
||||
func (s *Server) restartEngineForMDMLocked() error {
|
||||
activeProf, err := s.profileManager.GetActiveProfileState()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get active profile state: %w", err)
|
||||
}
|
||||
config, _, err := s.getConfig(activeProf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get active profile config: %w", err)
|
||||
}
|
||||
|
||||
s.config = config
|
||||
s.statusRecorder.UpdateManagementAddress(config.ManagementURL.String())
|
||||
s.statusRecorder.UpdateRosenpass(config.RosenpassEnabled, config.RosenpassPermissive)
|
||||
s.statusRecorder.UpdateLazyConnection(config.LazyConnectionEnabled)
|
||||
|
||||
ctx, cancel := context.WithCancel(s.rootCtx)
|
||||
s.actCancel = cancel
|
||||
s.clientRunning = true
|
||||
s.clientRunningChan = make(chan struct{})
|
||||
s.clientGiveUpChan = make(chan struct{})
|
||||
log.Info("MDM restart: spawning connectWithRetryRuns with re-resolved config")
|
||||
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
|
||||
s.publishConfigChangedEvent("mdm")
|
||||
return nil
|
||||
}
|
||||
|
||||
// conflictBool builds a conflictCheck for a boolean MDM key. If p is nil
|
||||
// the field is treated as matching (no override requested); otherwise the
|
||||
// check returns true only when the policy contains the key and its
|
||||
// boolean value equals *p.
|
||||
func conflictBool(key string, p *bool) conflictCheck {
|
||||
return conflictCheck{
|
||||
key: key,
|
||||
check: func(pol *mdm.Policy) bool {
|
||||
if p == nil {
|
||||
return true // absent → match by definition
|
||||
}
|
||||
want, ok := pol.GetBool(key)
|
||||
return ok && want == *p
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// conflictString builds a conflictCheck for a string MDM key. An empty
|
||||
// `got` is treated as "field not set" (no override requested); otherwise
|
||||
// the check returns true only when the policy contains the key and its
|
||||
// value equals got.
|
||||
func conflictString(key, got string) conflictCheck {
|
||||
return conflictCheck{
|
||||
key: key,
|
||||
check: func(pol *mdm.Policy) bool {
|
||||
if got == "" {
|
||||
return true
|
||||
}
|
||||
want, ok := pol.GetString(key)
|
||||
return ok && want == got
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// conflictInt64 builds a conflictCheck for an integer MDM key. If p is
|
||||
// nil the field is treated as matching; otherwise the check returns
|
||||
// true only when the policy contains the key and its int value equals *p.
|
||||
func conflictInt64(key string, p *int64) conflictCheck {
|
||||
return conflictCheck{
|
||||
key: key,
|
||||
check: func(pol *mdm.Policy) bool {
|
||||
if p == nil {
|
||||
return true
|
||||
}
|
||||
want, ok := pol.GetInt(key)
|
||||
return ok && want == *p
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// resolveConflicts walks the per-field checks against the active MDM
|
||||
// policy and returns the names of keys whose requested value diverges
|
||||
// from the policy-enforced value. Keys not present in the policy are
|
||||
// skipped silently (the gate fires only for keys the admin has
|
||||
// actually pushed). Returns nil for an empty policy.
|
||||
func resolveConflicts(policy *mdm.Policy, checks []conflictCheck) []string {
|
||||
if policy.IsEmpty() {
|
||||
return nil
|
||||
}
|
||||
var conflicts []string
|
||||
for _, c := range checks {
|
||||
if !policy.HasKey(c.key) {
|
||||
continue
|
||||
}
|
||||
if !c.check(policy) {
|
||||
conflicts = append(conflicts, c.key)
|
||||
}
|
||||
}
|
||||
return conflicts
|
||||
}
|
||||
|
||||
// mdmManagedFieldConflicts returns the names of MDM-managed keys whose
|
||||
// requested value in the SetConfigRequest differs from the MDM-enforced
|
||||
// value. A field set to the same value the policy already enforces is
|
||||
// treated as a no-op echo (the GUI tray sends a full Config snapshot on
|
||||
// every toggle, so most fields in a typical request match the policy
|
||||
// exactly and must NOT be flagged as conflicts). The redacted PSK
|
||||
// sentinel ("**********") returned by GetConfig is recognised and
|
||||
// treated as no-op so the UI can safely round-trip it.
|
||||
func mdmManagedFieldConflicts(msg *proto.SetConfigRequest, policy *mdm.Policy) []string {
|
||||
if msg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// PSK round-trip echo: collapse the sentinel to empty so the
|
||||
// shared check treats it as "field not set".
|
||||
pskGot := ""
|
||||
if msg.OptionalPreSharedKey != nil && *msg.OptionalPreSharedKey != preSharedKeyRedactedSentinel {
|
||||
pskGot = *msg.OptionalPreSharedKey
|
||||
}
|
||||
|
||||
return resolveConflicts(policy, []conflictCheck{
|
||||
conflictString(mdm.KeyManagementURL, msg.ManagementUrl),
|
||||
conflictString(mdm.KeyPreSharedKey, pskGot),
|
||||
conflictBool(mdm.KeyRosenpassEnabled, msg.RosenpassEnabled),
|
||||
conflictBool(mdm.KeyRosenpassPermissive, msg.RosenpassPermissive),
|
||||
conflictBool(mdm.KeyDisableAutoConnect, msg.DisableAutoConnect),
|
||||
conflictBool(mdm.KeyAllowServerSSH, msg.ServerSSHAllowed),
|
||||
conflictBool(mdm.KeyDisableClientRoutes, msg.DisableClientRoutes),
|
||||
conflictBool(mdm.KeyDisableServerRoutes, msg.DisableServerRoutes),
|
||||
conflictBool(mdm.KeyBlockInbound, msg.BlockInbound),
|
||||
conflictInt64(mdm.KeyWireguardPort, msg.WireguardPort),
|
||||
})
|
||||
}
|
||||
|
||||
// setConfigRequestHasConfigOverrides reports whether the SetConfigRequest
|
||||
// carries ANY field that would actually mutate the persisted config.
|
||||
// The CLI builds a SetConfigRequest unconditionally on every
|
||||
// `netbird up` (see setupSetConfigReq in cmd/up.go) — a plain
|
||||
// `netbird up` produces a request with every field at its zero value;
|
||||
// the gate must skip such no-op invocations or it would always fire
|
||||
// even when the user did not pass any --flag. Returns false on a nil
|
||||
// msg; true when any management/admin URL, PSK, DNS/NAT list+clean
|
||||
// flag, interface/port/MTU, or any optional bool/duration field is set.
|
||||
func setConfigRequestHasConfigOverrides(msg *proto.SetConfigRequest) bool {
|
||||
if msg == nil {
|
||||
return false
|
||||
}
|
||||
return msg.ManagementUrl != "" ||
|
||||
msg.AdminURL != "" ||
|
||||
msg.OptionalPreSharedKey != nil ||
|
||||
len(msg.CustomDNSAddress) > 0 ||
|
||||
len(msg.NatExternalIPs) > 0 || msg.CleanNATExternalIPs ||
|
||||
len(msg.ExtraIFaceBlacklist) > 0 ||
|
||||
len(msg.DnsLabels) > 0 || msg.CleanDNSLabels ||
|
||||
msg.DnsRouteInterval != nil ||
|
||||
msg.RosenpassEnabled != nil ||
|
||||
msg.RosenpassPermissive != nil ||
|
||||
msg.InterfaceName != nil ||
|
||||
msg.WireguardPort != nil ||
|
||||
msg.Mtu != nil ||
|
||||
msg.DisableAutoConnect != nil ||
|
||||
msg.ServerSSHAllowed != nil ||
|
||||
msg.NetworkMonitor != nil ||
|
||||
msg.DisableClientRoutes != nil ||
|
||||
msg.DisableServerRoutes != nil ||
|
||||
msg.DisableDns != nil ||
|
||||
msg.DisableFirewall != nil ||
|
||||
msg.BlockLanAccess != nil ||
|
||||
msg.DisableNotifications != nil ||
|
||||
msg.LazyConnectionEnabled != nil ||
|
||||
msg.BlockInbound != nil ||
|
||||
msg.DisableIpv6 != nil ||
|
||||
msg.EnableSSHRoot != nil ||
|
||||
msg.EnableSSHSFTP != nil ||
|
||||
msg.EnableSSHLocalPortForwarding != nil ||
|
||||
msg.EnableSSHRemotePortForwarding != nil ||
|
||||
msg.DisableSSHAuth != nil ||
|
||||
msg.SshJWTCacheTTL != nil
|
||||
}
|
||||
|
||||
// loginRequestHasConfigOverrides reports whether the LoginRequest
|
||||
// carries ANY field that would mutate persisted daemon configuration
|
||||
// (as opposed to pure-auth fields like setupKey, hostname, hint,
|
||||
// profileName, username). Used by the Login handler to decide whether
|
||||
// the `--disable-update-settings` / MDM gates must run: a re-auth that
|
||||
// changes nothing about the configuration is always allowed.
|
||||
func loginRequestHasConfigOverrides(msg *proto.LoginRequest) bool {
|
||||
if msg == nil {
|
||||
return false
|
||||
}
|
||||
return msg.ManagementUrl != "" ||
|
||||
msg.AdminURL != "" ||
|
||||
msg.PreSharedKey != "" || //nolint:staticcheck // SA1019: legacy proto field still accepted by Login
|
||||
msg.OptionalPreSharedKey != nil ||
|
||||
len(msg.CustomDNSAddress) > 0 ||
|
||||
len(msg.NatExternalIPs) > 0 || msg.CleanNATExternalIPs ||
|
||||
msg.RosenpassEnabled != nil ||
|
||||
msg.InterfaceName != nil ||
|
||||
msg.WireguardPort != nil ||
|
||||
msg.DisableAutoConnect != nil ||
|
||||
msg.ServerSSHAllowed != nil ||
|
||||
msg.RosenpassPermissive != nil ||
|
||||
len(msg.ExtraIFaceBlacklist) > 0 ||
|
||||
msg.NetworkMonitor != nil ||
|
||||
msg.DnsRouteInterval != nil ||
|
||||
msg.DisableClientRoutes != nil ||
|
||||
msg.DisableServerRoutes != nil ||
|
||||
msg.DisableDns != nil ||
|
||||
msg.DisableFirewall != nil ||
|
||||
msg.BlockLanAccess != nil ||
|
||||
msg.DisableNotifications != nil ||
|
||||
len(msg.DnsLabels) > 0 || msg.CleanDNSLabels ||
|
||||
msg.LazyConnectionEnabled != nil ||
|
||||
msg.BlockInbound != nil
|
||||
}
|
||||
|
||||
// loginRequestMDMConflicts mirrors mdmManagedFieldConflicts but for the
|
||||
// LoginRequest surface. Same value-aware semantics: a field set to the
|
||||
// MDM-enforced value is a no-op echo, not a conflict; only a divergent
|
||||
// value is flagged. PSK has two proto fields — PreSharedKey (deprecated)
|
||||
// and OptionalPreSharedKey (current); either route trips the gate if it
|
||||
// diverges from the MDM-enforced PSK. OptionalPreSharedKey wins when
|
||||
// both are set; the redaction sentinel ("**********") is accepted as
|
||||
// a no-op echo.
|
||||
func loginRequestMDMConflicts(msg *proto.LoginRequest, policy *mdm.Policy) []string {
|
||||
if msg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Collapse the two PSK fields + the redaction sentinel down to a
|
||||
// single "got" string the shared check can compare against the
|
||||
// policy: OptionalPreSharedKey wins if set; PreSharedKey (deprecated)
|
||||
// is the fallback; sentinel echo is treated as "field not set".
|
||||
pskGot := ""
|
||||
if msg.OptionalPreSharedKey != nil {
|
||||
pskGot = *msg.OptionalPreSharedKey
|
||||
} else if msg.PreSharedKey != "" { //nolint:staticcheck // SA1019: legacy proto field still accepted by Login
|
||||
pskGot = msg.PreSharedKey //nolint:staticcheck // SA1019
|
||||
}
|
||||
if pskGot == preSharedKeyRedactedSentinel {
|
||||
pskGot = ""
|
||||
}
|
||||
|
||||
return resolveConflicts(policy, []conflictCheck{
|
||||
conflictString(mdm.KeyManagementURL, msg.ManagementUrl),
|
||||
conflictString(mdm.KeyPreSharedKey, pskGot),
|
||||
conflictBool(mdm.KeyRosenpassEnabled, msg.RosenpassEnabled),
|
||||
conflictBool(mdm.KeyRosenpassPermissive, msg.RosenpassPermissive),
|
||||
conflictBool(mdm.KeyDisableAutoConnect, msg.DisableAutoConnect),
|
||||
conflictBool(mdm.KeyAllowServerSSH, msg.ServerSSHAllowed),
|
||||
conflictBool(mdm.KeyDisableClientRoutes, msg.DisableClientRoutes),
|
||||
conflictBool(mdm.KeyDisableServerRoutes, msg.DisableServerRoutes),
|
||||
conflictBool(mdm.KeyBlockInbound, msg.BlockInbound),
|
||||
conflictInt64(mdm.KeyWireguardPort, msg.WireguardPort),
|
||||
})
|
||||
}
|
||||
|
||||
// rejectMDMManagedFieldConflicts returns a FailedPrecondition gRPC error
|
||||
// with an MDMManagedFieldsViolation detail when any of the requested
|
||||
// fields tries to change an MDM-enforced value to something else, and
|
||||
// nil otherwise. The whole request is rejected on any conflict; non-
|
||||
// conflicting fields in the same request are not applied either (no
|
||||
// partial apply).
|
||||
func rejectMDMManagedFieldConflicts(conflicts []string) error {
|
||||
if len(conflicts) == 0 {
|
||||
return nil
|
||||
}
|
||||
log.Warnf("MDM rejected request: tried to modify %d managed key(s): %v",
|
||||
len(conflicts), conflicts)
|
||||
st := gstatus.New(
|
||||
codes.FailedPrecondition,
|
||||
fmt.Sprintf("fields managed by MDM cannot be modified: %v", conflicts),
|
||||
)
|
||||
detailed, err := st.WithDetails(&proto.MDMManagedFieldsViolation{Fields: conflicts})
|
||||
if err != nil {
|
||||
// Detail attachment is best-effort; fall back to the plain status
|
||||
// so the caller still gets a usable FailedPrecondition.
|
||||
return st.Err()
|
||||
}
|
||||
return detailed.Err()
|
||||
}
|
||||
@@ -30,7 +30,7 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.checkNetworksDisabled() {
|
||||
if s.networksDisabled {
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
|
||||
}
|
||||
|
||||
@@ -143,7 +143,7 @@ func (s *Server) SelectNetworks(_ context.Context, req *proto.SelectNetworksRequ
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.checkNetworksDisabled() {
|
||||
if s.networksDisabled {
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
|
||||
}
|
||||
|
||||
@@ -195,7 +195,7 @@ func (s *Server) DeselectNetworks(_ context.Context, req *proto.SelectNetworksRe
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.checkNetworksDisabled() {
|
||||
if s.networksDisabled {
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled)
|
||||
}
|
||||
|
||||
|
||||
@@ -24,7 +24,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/expose"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
sleephandler "github.com/netbirdio/netbird/client/internal/sleep/handler"
|
||||
"github.com/netbirdio/netbird/client/mdm"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
@@ -72,13 +71,7 @@ type Server struct {
|
||||
mutex sync.Mutex
|
||||
config *profilemanager.Config
|
||||
proto.UnimplementedDaemonServiceServer
|
||||
// clientRunning tracks "the daemon wants to be connected" — set true by
|
||||
// Start / Up, cleared by Down / Logout. Persists across retry
|
||||
// loops, signal disconnects, and ErrResetConnection cycles. NOT
|
||||
// changed by connectWithRetryRuns goroutine exit — for that
|
||||
// (goroutine-still-alive) check, see connectionGoroutineRunning() which
|
||||
// derives from clientGiveUpChan close state. Protected by s.mutex.
|
||||
clientRunning bool
|
||||
clientRunning bool // protected by mutex
|
||||
clientRunningChan chan struct{}
|
||||
clientGiveUpChan chan struct{} // closed when connectWithRetryRuns goroutine exits
|
||||
|
||||
@@ -105,11 +98,6 @@ type Server struct {
|
||||
|
||||
sleepHandler *sleephandler.SleepHandler
|
||||
|
||||
// mdmTicker periodically re-reads the OS-native MDM policy and triggers
|
||||
// an engine restart when the policy changes. Launched once by Start;
|
||||
// stopped by the rootCtx cancellation.
|
||||
mdmTicker *mdm.Ticker
|
||||
|
||||
updateManager *updater.Manager
|
||||
|
||||
jwtCache *jwtCache
|
||||
@@ -167,17 +155,6 @@ func (s *Server) Start() error {
|
||||
s.updateManager.CheckUpdateSuccess(s.rootCtx)
|
||||
}
|
||||
|
||||
// MDM policy reload ticker: every minute the desktop daemon re-reads
|
||||
// the OS-native managed-config store and, on diff vs the previous
|
||||
// observation, cancels the active engine context so connectWithRetry-
|
||||
// Runs re-resolves Config (re-running profilemanager.Config.apply which
|
||||
// applies the freshly-read MDM policy as the last layer) and brings
|
||||
// the engine back with the new values.
|
||||
if s.mdmTicker == nil {
|
||||
s.mdmTicker = mdm.NewTicker(mdm.DefaultReloadInterval)
|
||||
go s.mdmTicker.Run(s.rootCtx, s.onMDMPolicyChange)
|
||||
}
|
||||
|
||||
// if current state contains any error, return it
|
||||
// in all other cases we can continue execution only if status is idle and up command was
|
||||
// not in the progress or already successfully established connection.
|
||||
@@ -236,27 +213,17 @@ func (s *Server) Start() error {
|
||||
s.clientRunningChan = make(chan struct{})
|
||||
s.clientGiveUpChan = make(chan struct{})
|
||||
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
|
||||
s.publishConfigChangedEvent("startup")
|
||||
return nil
|
||||
}
|
||||
|
||||
// connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional
|
||||
// mechanism to keep the client connected even when the connection is lost.
|
||||
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
|
||||
//
|
||||
// The goroutine's exit is signalled to the daemon via close(giveUpChan)
|
||||
// — placed in the function-scope defer so every return path (panic,
|
||||
// DisableAutoConnect early-exit, backoff exhausted, ctx cancel) closes
|
||||
// it. Callers that need to observe "is the goroutine still alive?" use
|
||||
// Server.connectionGoroutineRunning() which non-blockingly checks the close state
|
||||
// of clientGiveUpChan. The defer does NOT touch s.mutex; the daemon's
|
||||
// "intent" (clientRunning) is maintained by the RPC handlers, not by this
|
||||
// goroutine.
|
||||
func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}, giveUpChan chan struct{}) {
|
||||
defer func() {
|
||||
if giveUpChan != nil {
|
||||
close(giveUpChan)
|
||||
}
|
||||
s.mutex.Lock()
|
||||
s.clientRunning = false
|
||||
s.mutex.Unlock()
|
||||
}()
|
||||
|
||||
if s.config.DisableAutoConnect {
|
||||
@@ -302,26 +269,9 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil
|
||||
if err := backoff.Retry(runOperation, backOff); err != nil {
|
||||
log.Errorf("operation failed: %v", err)
|
||||
}
|
||||
// giveUpChan is closed by the function-scope defer.
|
||||
}
|
||||
|
||||
// connectionGoroutineRunning reports whether the connectWithRetryRuns goroutine is
|
||||
// still running. Returns false when no goroutine has ever been started
|
||||
// AND when the most recent one has already closed clientGiveUpChan on
|
||||
// exit (whether due to ctx cancel, DisableAutoConnect single-shot
|
||||
// completion, or backoff retry exhaustion).
|
||||
//
|
||||
// MUST be called with s.mutex held — accesses s.clientGiveUpChan which
|
||||
// is written by Start/Up under the same lock.
|
||||
func (s *Server) connectionGoroutineRunning() bool {
|
||||
if s.clientGiveUpChan == nil {
|
||||
return false
|
||||
}
|
||||
select {
|
||||
case <-s.clientGiveUpChan:
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
if giveUpChan != nil {
|
||||
close(giveUpChan)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -354,85 +304,53 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
// Skip the update-settings gate when the request carries no actual
|
||||
// overrides: the CLI builds a SetConfigRequest unconditionally on
|
||||
// every `netbird up` (setupSetConfigReq in cmd/up.go), so a plain
|
||||
// `netbird up` would otherwise always trip the gate and surface a
|
||||
// misleading "setConfig method is not available" warning, even when
|
||||
// the user did not pass any config flag.
|
||||
if setConfigRequestHasConfigOverrides(msg) {
|
||||
if s.checkUpdateSettingsDisabled() {
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errUpdateSettingsDisabled)
|
||||
}
|
||||
if s.checkUpdateSettingsDisabled() {
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errUpdateSettingsDisabled)
|
||||
}
|
||||
|
||||
// MDM gate: refuse the whole request if any of its fields is enforced
|
||||
// by the active MDM policy. The error carries an MDMManagedFields-
|
||||
// Violation detail listing the offending key names. Non-conflicting
|
||||
// fields in the same request are not applied either.
|
||||
policy := loadMDMPolicy()
|
||||
if err := rejectMDMManagedFieldConflicts(mdmManagedFieldConflicts(msg, policy)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
config, err := setConfigInputFromRequest(msg)
|
||||
resolved, err := s.resolveProfileHandle(msg.ProfileName, msg.Username)
|
||||
if err != nil {
|
||||
log.Errorf("failed to resolve profile %q: %v", msg.ProfileName, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err := profilemanager.UpdateConfig(config); err != nil {
|
||||
log.Errorf("failed to update profile config: %v", err)
|
||||
return nil, fmt.Errorf("failed to update profile config: %w", err)
|
||||
profPath := resolved.Path
|
||||
if profPath == "" {
|
||||
profPath = profilemanager.DefaultConfigPath
|
||||
}
|
||||
|
||||
return &proto.SetConfigResponse{}, nil
|
||||
}
|
||||
|
||||
// setConfigInputFromRequest translates a SetConfigRequest into the
|
||||
// profilemanager.ConfigInput that profilemanager.UpdateConfig consumes.
|
||||
// Pure mapping with no business logic beyond presence-aware copying of
|
||||
// optional fields and the "empty / clean" semantics for the two slice
|
||||
// fields (DNS labels, NAT external IPs). Extracted from SetConfig to
|
||||
// keep the handler's cognitive complexity below the SonarCube
|
||||
// threshold; the body is intentionally linear because each proto
|
||||
// field is its own optional case. Returns the resolved ConfigInput
|
||||
// and a non-nil error only when the active profile file path cannot
|
||||
// be determined.
|
||||
func setConfigInputFromRequest(msg *proto.SetConfigRequest) (profilemanager.ConfigInput, error) {
|
||||
var config profilemanager.ConfigInput
|
||||
|
||||
profState := profilemanager.ActiveProfileState{
|
||||
Name: msg.ProfileName,
|
||||
Username: msg.Username,
|
||||
}
|
||||
profPath, err := profState.FilePath()
|
||||
if err != nil {
|
||||
log.Errorf("failed to get active profile file path: %v", err)
|
||||
return config, fmt.Errorf("failed to get active profile file path: %w", err)
|
||||
}
|
||||
config.ConfigPath = profPath
|
||||
|
||||
if msg.ManagementUrl != "" {
|
||||
config.ManagementURL = msg.ManagementUrl
|
||||
}
|
||||
|
||||
if msg.AdminURL != "" {
|
||||
config.AdminURL = msg.AdminURL
|
||||
}
|
||||
|
||||
if msg.InterfaceName != nil {
|
||||
config.InterfaceName = msg.InterfaceName
|
||||
}
|
||||
|
||||
if msg.WireguardPort != nil {
|
||||
wgPort := int(*msg.WireguardPort)
|
||||
config.WireguardPort = &wgPort
|
||||
}
|
||||
if msg.OptionalPreSharedKey != nil && *msg.OptionalPreSharedKey != "" {
|
||||
config.PreSharedKey = msg.OptionalPreSharedKey
|
||||
|
||||
if msg.OptionalPreSharedKey != nil {
|
||||
if *msg.OptionalPreSharedKey != "" {
|
||||
config.PreSharedKey = msg.OptionalPreSharedKey
|
||||
}
|
||||
}
|
||||
|
||||
if msg.CleanDNSLabels {
|
||||
config.DNSLabels = domain.List{}
|
||||
|
||||
} else if msg.DnsLabels != nil {
|
||||
config.DNSLabels = domain.FromPunycodeList(msg.DnsLabels)
|
||||
dnsLabels := domain.FromPunycodeList(msg.DnsLabels)
|
||||
config.DNSLabels = dnsLabels
|
||||
}
|
||||
|
||||
if msg.CleanNATExternalIPs {
|
||||
@@ -445,6 +363,7 @@ func setConfigInputFromRequest(msg *proto.SetConfigRequest) (profilemanager.Conf
|
||||
if string(msg.CustomDNSAddress) == "empty" {
|
||||
config.CustomDNSAddress = []byte{}
|
||||
}
|
||||
|
||||
config.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist
|
||||
|
||||
if msg.DnsRouteInterval != nil {
|
||||
@@ -477,31 +396,22 @@ func setConfigInputFromRequest(msg *proto.SetConfigRequest) (profilemanager.Conf
|
||||
ttl := int(*msg.SshJWTCacheTTL)
|
||||
config.SSHJWTCacheTTL = &ttl
|
||||
}
|
||||
|
||||
if msg.Mtu != nil {
|
||||
mtu := uint16(*msg.Mtu)
|
||||
config.MTU = &mtu
|
||||
}
|
||||
return config, nil
|
||||
|
||||
if _, err := profilemanager.UpdateConfig(config); err != nil {
|
||||
log.Errorf("failed to update profile config: %v", err)
|
||||
return nil, fmt.Errorf("failed to update profile config: %w", err)
|
||||
}
|
||||
|
||||
return &proto.SetConfigResponse{}, nil
|
||||
}
|
||||
|
||||
// Login uses setup key to prepare configuration for the daemon.
|
||||
func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*proto.LoginResponse, error) {
|
||||
// Config-override gates. LoginRequest carries the same surface as
|
||||
// SetConfigRequest (managementUrl, PSK, ssh/rosenpass/port toggles,
|
||||
// ...), so the same protections must apply. Without these the CLI
|
||||
// command `netbird up --management-url=X` (which falls through to
|
||||
// Login when SetConfig is rejected — see cmd/up.go) would silently
|
||||
// bypass `--disable-update-settings` and any MDM policy.
|
||||
if loginRequestHasConfigOverrides(msg) {
|
||||
if s.checkUpdateSettingsDisabled() {
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errUpdateSettingsDisabled)
|
||||
}
|
||||
policy := loadMDMPolicy()
|
||||
if err := rejectMDMManagedFieldConflicts(loginRequestMDMConflicts(msg, policy)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
if s.actCancel != nil {
|
||||
s.actCancel()
|
||||
@@ -535,30 +445,9 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
||||
}
|
||||
|
||||
if msg.ProfileName != nil {
|
||||
if *msg.ProfileName != "default" && (msg.Username == nil || *msg.Username == "") {
|
||||
log.Errorf("profile name is set to %s, but username is not provided", *msg.ProfileName)
|
||||
return nil, fmt.Errorf("profile name is set to %s, but username is not provided", *msg.ProfileName)
|
||||
}
|
||||
|
||||
var username string
|
||||
if *msg.ProfileName != "default" {
|
||||
username = *msg.Username
|
||||
}
|
||||
|
||||
if *msg.ProfileName != activeProf.Name && username != activeProf.Username {
|
||||
if s.checkProfilesDisabled() {
|
||||
log.Errorf("profiles are disabled, you cannot use this feature without profiles enabled")
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
|
||||
}
|
||||
|
||||
log.Infof("switching to profile %s for user '%s'", *msg.ProfileName, username)
|
||||
if err := s.profileManager.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
Name: *msg.ProfileName,
|
||||
Username: username,
|
||||
}); err != nil {
|
||||
log.Errorf("failed to set active profile state: %v", err)
|
||||
return nil, fmt.Errorf("failed to set active profile state: %w", err)
|
||||
}
|
||||
if _, err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
|
||||
log.Errorf("failed to switch profile: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -568,7 +457,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
||||
return nil, fmt.Errorf("failed to get active profile state: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("active profile: %s for %s", activeProf.Name, activeProf.Username)
|
||||
log.Infof("active profile: %s for %s", activeProf.ID, activeProf.Username)
|
||||
|
||||
s.mutex.Lock()
|
||||
|
||||
@@ -741,13 +630,7 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
|
||||
// Up starts engine work in the daemon.
|
||||
func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpResponse, error) {
|
||||
s.mutex.Lock()
|
||||
// clientRunning is the daemon-intent flag (set by previous Up/Start, cleared
|
||||
// by Down). connectionGoroutineRunning() reports whether the previous retry-loop
|
||||
// goroutine is still trying. When intent is up AND goroutine is alive,
|
||||
// the existing engine is on the job — just wait for it. When intent
|
||||
// is up but the goroutine has given up (backoff exhausted) OR when
|
||||
// intent is down, fall through to spawn a fresh retry loop.
|
||||
if s.clientRunning && s.connectionGoroutineRunning() {
|
||||
if s.clientRunning {
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
status, err := state.Status()
|
||||
if err != nil {
|
||||
@@ -806,10 +689,10 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
|
||||
}
|
||||
|
||||
if msg != nil && msg.ProfileName != nil {
|
||||
if err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
|
||||
if _, err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
|
||||
s.mutex.Unlock()
|
||||
log.Errorf("failed to switch profile: %v", err)
|
||||
return nil, fmt.Errorf("failed to switch profile: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -820,7 +703,7 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
|
||||
return nil, fmt.Errorf("failed to get active profile state: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("active profile: %s for %s", activeProf.Name, activeProf.Username)
|
||||
log.Infof("active profile: %s for %s", activeProf.ID, activeProf.Username)
|
||||
|
||||
config, _, err := s.getConfig(activeProf)
|
||||
if err != nil {
|
||||
@@ -838,7 +721,6 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
|
||||
s.clientGiveUpChan = make(chan struct{})
|
||||
|
||||
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
|
||||
s.publishConfigChangedEvent("up_rpc")
|
||||
|
||||
s.mutex.Unlock()
|
||||
return s.waitForUp(callerCtx)
|
||||
@@ -864,34 +746,60 @@ func (s *Server) waitForUp(callerCtx context.Context) (*proto.UpResponse, error)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) switchProfileIfNeeded(profileName string, userName *string, activeProf *profilemanager.ActiveProfileState) error {
|
||||
if profileName != "default" && (userName == nil || *userName == "") {
|
||||
log.Errorf("profile name is set to %s, but username is not provided", profileName)
|
||||
return fmt.Errorf("profile name is set to %s, but username is not provided", profileName)
|
||||
// resolveProfileHandle resolves a wire-level profile handle (display
|
||||
// name, ID, or unique ID prefix) to a concrete profile. Returns gRPC
|
||||
// status errors so handlers can return them directly.
|
||||
func (s *Server) resolveProfileHandle(handle, username string) (*profilemanager.Profile, error) {
|
||||
p, err := s.profileManager.ResolveProfile(handle, username)
|
||||
if err == nil {
|
||||
return p, nil
|
||||
}
|
||||
var amb *profilemanager.ErrAmbiguousHandle
|
||||
if errors.As(err, &amb) {
|
||||
return nil, gstatus.Errorf(codes.InvalidArgument, "%v", amb)
|
||||
}
|
||||
if errors.Is(err, profilemanager.ErrProfileNotFound) {
|
||||
return nil, gstatus.Errorf(codes.NotFound, "profile %q not found", handle)
|
||||
}
|
||||
return nil, fmt.Errorf("resolve profile: %w", err)
|
||||
}
|
||||
|
||||
// switchProfileIfNeeded resolves the user-supplied handle, updates the
|
||||
// active profile state if it differs from the current one, and returns
|
||||
// the resolved profile so callers can include its ID in RPC responses.
|
||||
func (s *Server) switchProfileIfNeeded(handle string, userName *string, activeProf *profilemanager.ActiveProfileState) (*profilemanager.Profile, error) {
|
||||
if handle != profilemanager.DefaultProfileName && (userName == nil || *userName == "") {
|
||||
log.Errorf("profile name is set to %s, but username is not provided", handle)
|
||||
return nil, fmt.Errorf("profile name is set to %s, but username is not provided", handle)
|
||||
}
|
||||
|
||||
var username string
|
||||
if profileName != "default" {
|
||||
if handle != profilemanager.DefaultProfileName {
|
||||
username = *userName
|
||||
}
|
||||
|
||||
if profileName != activeProf.Name || username != activeProf.Username {
|
||||
resolved, err := s.resolveProfileHandle(handle, username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resolved.ID != activeProf.ID || username != activeProf.Username {
|
||||
if s.checkProfilesDisabled() {
|
||||
log.Errorf("profiles are disabled, you cannot use this feature without profiles enabled")
|
||||
return gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
|
||||
}
|
||||
|
||||
log.Infof("switching to profile %s for user %s", profileName, username)
|
||||
log.Infof("switching to profile %s (%s) for user %s", resolved.Name, resolved.ID, username)
|
||||
if err := s.profileManager.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
Name: profileName,
|
||||
ID: resolved.ID,
|
||||
Username: username,
|
||||
}); err != nil {
|
||||
log.Errorf("failed to set active profile state: %v", err)
|
||||
return fmt.Errorf("failed to set active profile state: %w", err)
|
||||
return nil, fmt.Errorf("failed to set active profile state: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return resolved, nil
|
||||
}
|
||||
|
||||
// SwitchProfile switches the active profile in the daemon.
|
||||
@@ -906,9 +814,9 @@ func (s *Server) SwitchProfile(callerCtx context.Context, msg *proto.SwitchProfi
|
||||
}
|
||||
|
||||
if msg != nil && msg.ProfileName != nil {
|
||||
if err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
|
||||
if _, err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
|
||||
log.Errorf("failed to switch profile: %v", err)
|
||||
return nil, fmt.Errorf("failed to switch profile: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
activeProf, err = s.profileManager.GetActiveProfileState()
|
||||
@@ -924,7 +832,7 @@ func (s *Server) SwitchProfile(callerCtx context.Context, msg *proto.SwitchProfi
|
||||
|
||||
s.config = config
|
||||
|
||||
return &proto.SwitchProfileResponse{}, nil
|
||||
return &proto.SwitchProfileResponse{Id: activeProf.ID.String()}, nil
|
||||
}
|
||||
|
||||
// Down engine work in the daemon.
|
||||
@@ -967,12 +875,6 @@ func (s *Server) cleanupConnection() error {
|
||||
return ErrServiceNotUp
|
||||
}
|
||||
|
||||
// Daemon intent flips to "down" — all callers (Down RPC,
|
||||
// Logout RPC handlers) tear down the connection because the user
|
||||
// explicitly asked for it. MDM restart does NOT go through this
|
||||
// path, so its clientRunning stays true.
|
||||
s.clientRunning = false
|
||||
|
||||
// Capture the engine reference before cancelling the context.
|
||||
// After actCancel(), the connectWithRetryRuns goroutine wakes up
|
||||
// and sets connectClient.engine = nil, causing connectClient.Stop()
|
||||
@@ -1014,22 +916,27 @@ func (s *Server) Logout(ctx context.Context, msg *proto.LogoutRequest) (*proto.L
|
||||
}
|
||||
|
||||
func (s *Server) handleProfileLogout(ctx context.Context, msg *proto.LogoutRequest) (*proto.LogoutResponse, error) {
|
||||
if err := s.validateProfileOperation(*msg.ProfileName, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if msg.Username == nil || *msg.Username == "" {
|
||||
return nil, gstatus.Errorf(codes.InvalidArgument, "username must be provided when profile name is specified")
|
||||
}
|
||||
username := *msg.Username
|
||||
|
||||
if err := s.logoutFromProfile(ctx, *msg.ProfileName, username); err != nil {
|
||||
log.Errorf("failed to logout from profile %s: %v", *msg.ProfileName, err)
|
||||
resolved, err := s.resolveProfileHandle(*msg.ProfileName, username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.validateProfileOperation(resolved.ID, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.logoutFromProfile(ctx, resolved); err != nil {
|
||||
log.Errorf("failed to logout from profile %s: %v", resolved.ID, err)
|
||||
return nil, gstatus.Errorf(codes.Internal, "logout: %v", err)
|
||||
}
|
||||
|
||||
activeProf, _ := s.profileManager.GetActiveProfileState()
|
||||
if activeProf != nil && activeProf.Name == *msg.ProfileName {
|
||||
if activeProf != nil && activeProf.ID == resolved.ID {
|
||||
if err := s.cleanupConnection(); err != nil && !errors.Is(err, ErrServiceNotUp) {
|
||||
log.Errorf("failed to cleanup connection: %v", err)
|
||||
}
|
||||
@@ -1091,30 +998,30 @@ func (s *Server) getConfig(activeProf *profilemanager.ActiveProfileState) (*prof
|
||||
return config, configExisted, nil
|
||||
}
|
||||
|
||||
func (s *Server) canRemoveProfile(profileName string) error {
|
||||
if profileName == profilemanager.DefaultProfileName {
|
||||
func (s *Server) canRemoveProfile(id profilemanager.ID) error {
|
||||
if id == profilemanager.DefaultProfileName {
|
||||
return fmt.Errorf("remove profile with reserved name: %s", profilemanager.DefaultProfileName)
|
||||
}
|
||||
|
||||
activeProf, err := s.profileManager.GetActiveProfileState()
|
||||
if err == nil && activeProf.Name == profileName {
|
||||
return fmt.Errorf("remove active profile: %s", profileName)
|
||||
if err == nil && activeProf.ID == id {
|
||||
return fmt.Errorf("remove active profile: %s", id)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) validateProfileOperation(profileName string, allowActiveProfile bool) error {
|
||||
func (s *Server) validateProfileOperation(id profilemanager.ID, allowActiveProfile bool) error {
|
||||
if s.checkProfilesDisabled() {
|
||||
return gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
|
||||
}
|
||||
|
||||
if profileName == "" {
|
||||
if id == "" {
|
||||
return gstatus.Errorf(codes.InvalidArgument, "profile name must be provided")
|
||||
}
|
||||
|
||||
if !allowActiveProfile {
|
||||
if err := s.canRemoveProfile(profileName); err != nil {
|
||||
if err := s.canRemoveProfile(id); err != nil {
|
||||
return gstatus.Errorf(codes.InvalidArgument, "%v", err)
|
||||
}
|
||||
}
|
||||
@@ -1122,25 +1029,20 @@ func (s *Server) validateProfileOperation(profileName string, allowActiveProfile
|
||||
return nil
|
||||
}
|
||||
|
||||
// logoutFromProfile logs out from a specific profile by loading its config and sending logout request
|
||||
func (s *Server) logoutFromProfile(ctx context.Context, profileName, username string) error {
|
||||
func (s *Server) logoutFromProfile(ctx context.Context, profile *profilemanager.Profile) error {
|
||||
activeProf, err := s.profileManager.GetActiveProfileState()
|
||||
if err == nil && activeProf.Name == profileName && s.connectClient != nil {
|
||||
if err == nil && activeProf.ID == profile.ID && s.connectClient != nil {
|
||||
return s.sendLogoutRequest(ctx)
|
||||
}
|
||||
|
||||
profileState := &profilemanager.ActiveProfileState{
|
||||
Name: profileName,
|
||||
Username: username,
|
||||
}
|
||||
profilePath, err := profileState.FilePath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get profile path: %w", err)
|
||||
cfgPath := profile.Path
|
||||
if cfgPath == "" {
|
||||
cfgPath = profilemanager.DefaultConfigPath
|
||||
}
|
||||
|
||||
config, err := profilemanager.GetConfig(profilePath)
|
||||
config, err := profilemanager.GetConfig(cfgPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("profile '%s' not found", profileName)
|
||||
return fmt.Errorf("profile '%s' not found", profile.ID)
|
||||
}
|
||||
|
||||
return s.sendLogoutRequestWithConfig(ctx, config)
|
||||
@@ -1176,14 +1078,10 @@ func (s *Server) Status(
|
||||
msg *proto.StatusRequest,
|
||||
) (*proto.StatusResponse, error) {
|
||||
s.mutex.Lock()
|
||||
// Only wait if the retry-loop goroutine is alive and making
|
||||
// progress. clientRunning=true with connectionGoroutineRunning=false means the
|
||||
// backoff has given up — there is nothing to wait for; let the
|
||||
// caller observe the failed status directly.
|
||||
alive := s.connectionGoroutineRunning()
|
||||
clientRunning := s.clientRunning
|
||||
s.mutex.Unlock()
|
||||
|
||||
if msg.WaitForReady != nil && *msg.WaitForReady && alive {
|
||||
if msg.WaitForReady != nil && *msg.WaitForReady && clientRunning {
|
||||
state := internal.CtxGetState(s.rootCtx)
|
||||
status, err := state.Status()
|
||||
if err != nil {
|
||||
@@ -1558,15 +1456,14 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
prof := profilemanager.ActiveProfileState{
|
||||
Name: req.ProfileName,
|
||||
Username: req.Username,
|
||||
}
|
||||
|
||||
cfgPath, err := prof.FilePath()
|
||||
resolved, err := s.resolveProfileHandle(req.ProfileName, req.Username)
|
||||
if err != nil {
|
||||
log.Errorf("failed to get active profile file path: %v", err)
|
||||
return nil, fmt.Errorf("failed to get active profile file path: %w", err)
|
||||
log.Errorf("failed to resolve profile %q: %v", req.ProfileName, err)
|
||||
return nil, err
|
||||
}
|
||||
cfgPath := resolved.Path
|
||||
if cfgPath == "" {
|
||||
cfgPath = profilemanager.DefaultConfigPath
|
||||
}
|
||||
|
||||
cfg, err := profilemanager.GetConfig(cfgPath)
|
||||
@@ -1654,7 +1551,6 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
|
||||
EnableSSHRemotePortForwarding: enableSSHRemotePortForwarding,
|
||||
DisableSSHAuth: disableSSHAuth,
|
||||
SshJWTCacheTTL: sshJWTCacheTTL,
|
||||
MDMManagedFields: cfg.Policy().ManagedKeys(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1671,12 +1567,39 @@ func (s *Server) AddProfile(ctx context.Context, msg *proto.AddProfileRequest) (
|
||||
return nil, gstatus.Errorf(codes.InvalidArgument, "profile name and username must be provided")
|
||||
}
|
||||
|
||||
if err := s.profileManager.AddProfile(msg.ProfileName, msg.Username); err != nil {
|
||||
created, err := s.profileManager.AddProfile(msg.ProfileName, msg.Username)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create profile: %v", err)
|
||||
return nil, fmt.Errorf("failed to create profile: %w", err)
|
||||
}
|
||||
|
||||
return &proto.AddProfileResponse{}, nil
|
||||
return &proto.AddProfileResponse{Id: created.ID.String()}, nil
|
||||
}
|
||||
|
||||
func (s *Server) RenameProfile(ctx context.Context, msg *proto.RenameProfileRequest) (*proto.RenameProfileResponse, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.checkProfilesDisabled() {
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
|
||||
}
|
||||
|
||||
if msg.Handle == "" || msg.Username == "" || msg.NewProfileName == "" {
|
||||
return nil, gstatus.Errorf(codes.InvalidArgument, "profile name, username and new profile name must be provided")
|
||||
}
|
||||
|
||||
resolved, err := s.resolveProfileHandle(msg.Handle, msg.Username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = s.profileManager.RenameProfile(resolved.ID, msg.Username, msg.NewProfileName)
|
||||
if err != nil {
|
||||
log.Errorf("failed to rename profile: %v", err)
|
||||
return nil, fmt.Errorf("failed to rename profile: %w", err)
|
||||
}
|
||||
|
||||
return &proto.RenameProfileResponse{OldProfileName: resolved.Name}, nil
|
||||
}
|
||||
|
||||
// RemoveProfile removes a profile from the daemon.
|
||||
@@ -1684,20 +1607,29 @@ func (s *Server) RemoveProfile(ctx context.Context, msg *proto.RemoveProfileRequ
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if err := s.validateProfileOperation(msg.ProfileName, false); err != nil {
|
||||
if s.checkProfilesDisabled() {
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
|
||||
}
|
||||
|
||||
if msg.ProfileName == "" {
|
||||
return nil, gstatus.Errorf(codes.InvalidArgument, "profile name must be provided")
|
||||
}
|
||||
|
||||
resolved, err := s.resolveProfileHandle(msg.ProfileName, msg.Username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.logoutFromProfile(ctx, msg.ProfileName, msg.Username); err != nil {
|
||||
log.Warnf("failed to logout from profile %s before removal: %v", msg.ProfileName, err)
|
||||
if err := s.logoutFromProfile(ctx, resolved); err != nil {
|
||||
log.Warnf("failed to logout from profile %s before removal: %v", resolved.ID, err)
|
||||
}
|
||||
|
||||
if err := s.profileManager.RemoveProfile(msg.ProfileName, msg.Username); err != nil {
|
||||
if err := s.profileManager.RemoveProfile(resolved.ID, msg.Username); err != nil {
|
||||
log.Errorf("failed to remove profile: %v", err)
|
||||
return nil, fmt.Errorf("failed to remove profile: %w", err)
|
||||
}
|
||||
|
||||
return &proto.RemoveProfileResponse{}, nil
|
||||
return &proto.RemoveProfileResponse{Id: resolved.ID.String()}, nil
|
||||
}
|
||||
|
||||
// ListProfiles lists all profiles in the daemon.
|
||||
@@ -1720,6 +1652,7 @@ func (s *Server) ListProfiles(ctx context.Context, msg *proto.ListProfilesReques
|
||||
}
|
||||
for i, profile := range profiles {
|
||||
response.Profiles[i] = &proto.Profile{
|
||||
Id: profile.ID.String(),
|
||||
Name: profile.Name,
|
||||
IsActive: profile.IsActive,
|
||||
}
|
||||
@@ -1728,7 +1661,9 @@ func (s *Server) ListProfiles(ctx context.Context, msg *proto.ListProfilesReques
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// GetActiveProfile returns the active profile in the daemon.
|
||||
// GetActiveProfile returns the active profile in the daemon. The ProfileName
|
||||
// field carries the display name for backwards compatibility with UI clients,
|
||||
// new callers should prefer Id.
|
||||
func (s *Server) GetActiveProfile(ctx context.Context, msg *proto.GetActiveProfileRequest) (*proto.GetActiveProfileResponse, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
@@ -1739,9 +1674,23 @@ func (s *Server) GetActiveProfile(ctx context.Context, msg *proto.GetActiveProfi
|
||||
return nil, fmt.Errorf("failed to get active profile state: %w", err)
|
||||
}
|
||||
|
||||
// Fallback to legacy name == ID
|
||||
displayName := activeProfile.ID.String()
|
||||
if activeProfile.ID != profilemanager.DefaultProfileName {
|
||||
if profiles, lerr := s.profileManager.ListProfiles(activeProfile.Username); lerr == nil {
|
||||
for _, p := range profiles {
|
||||
if p.ID == activeProfile.ID {
|
||||
displayName = p.Name
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &proto.GetActiveProfileResponse{
|
||||
ProfileName: activeProfile.Name,
|
||||
ProfileName: displayName,
|
||||
Username: activeProfile.Username,
|
||||
Id: activeProfile.ID.String(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1753,7 +1702,7 @@ func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest)
|
||||
features := &proto.GetFeaturesResponse{
|
||||
DisableProfiles: s.checkProfilesDisabled(),
|
||||
DisableUpdateSettings: s.checkUpdateSettingsDisabled(),
|
||||
DisableNetworks: s.checkNetworksDisabled(),
|
||||
DisableNetworks: s.networksDisabled,
|
||||
}
|
||||
|
||||
return features, nil
|
||||
@@ -1775,46 +1724,22 @@ func (s *Server) connect(ctx context.Context, config *profilemanager.Config, sta
|
||||
return nil
|
||||
}
|
||||
|
||||
// MDM authority: when the platform-native MDM source sets a kill switch
|
||||
// key (regardless of true/false value), that value wins. The CLI flag
|
||||
// supplied at service install time is the fallback used only when the
|
||||
// MDM source is silent on the key. This honors the "MDM decides
|
||||
// everything" semantic agreed for NET-1214 — an admin pushing
|
||||
// disableX=false via MDM explicitly re-enables the feature even on a
|
||||
// box installed with --disable-X.
|
||||
func (s *Server) checkProfilesDisabled() bool {
|
||||
if s.config != nil {
|
||||
if v, ok := s.config.Policy().GetBool(mdm.KeyDisableProfiles); ok {
|
||||
return v
|
||||
}
|
||||
// Check if the environment variable is set to disable profiles
|
||||
if s.profilesDisabled {
|
||||
return true
|
||||
}
|
||||
return s.profilesDisabled
|
||||
}
|
||||
|
||||
// checkNetworksDisabled reports whether the networks/exit-node feature
|
||||
// is disabled on this daemon instance. Resolved MDM-first: when the
|
||||
// active policy declares mdm.KeyDisableNetworks the policy value wins
|
||||
// (regardless of true/false), so an admin can re-enable the feature
|
||||
// via MDM even on a host that was installed with --disable-networks.
|
||||
// Falls back to the s.networksDisabled CLI flag when the policy is
|
||||
// silent on the key. Mirrors checkProfilesDisabled and
|
||||
// checkUpdateSettingsDisabled.
|
||||
func (s *Server) checkNetworksDisabled() bool {
|
||||
if s.config != nil {
|
||||
if v, ok := s.config.Policy().GetBool(mdm.KeyDisableNetworks); ok {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return s.networksDisabled
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *Server) checkUpdateSettingsDisabled() bool {
|
||||
if s.config != nil {
|
||||
if v, ok := s.config.Policy().GetBool(mdm.KeyDisableUpdateSettings); ok {
|
||||
return v
|
||||
}
|
||||
// Check if the environment variable is set to disable profiles
|
||||
if s.updateSettingsDisabled {
|
||||
return true
|
||||
}
|
||||
return s.updateSettingsDisabled
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *Server) startUpdateManagerForGUI() {
|
||||
|
||||
@@ -101,7 +101,6 @@ func TestCleanupConnection_ClearsConnectClient(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Nil(t, s.connectClient, "connectClient should be nil after cleanup")
|
||||
assert.False(t, s.clientRunning, "clientRunning should be cleared after cleanup (intent = down)")
|
||||
}
|
||||
|
||||
// TestCleanState_NilConnectClient validates that CleanState doesn't panic
|
||||
@@ -145,20 +144,17 @@ func TestDownThenUp_StaleRunningChan(t *testing.T) {
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
s.actCancel = cancel
|
||||
|
||||
// Simulate Down(): cleanupConnection sets connectClient = nil and
|
||||
// flips clientRunning to false (intent = down). The connectionGoroutineRunning state
|
||||
// remains independent of intent — derived from clientGiveUpChan.
|
||||
// Simulate Down(): cleanupConnection sets connectClient = nil
|
||||
s.mutex.Lock()
|
||||
err := s.cleanupConnection()
|
||||
s.mutex.Unlock()
|
||||
require.NoError(t, err)
|
||||
|
||||
// After cleanup: connectClient is nil, clientRunning is false (intent
|
||||
// cleared by cleanupConnection), connectionGoroutineRunning may still be true
|
||||
// (goroutine teardown is independent of the intent flag).
|
||||
// After cleanup: connectClient is nil, clientRunning still true
|
||||
// (goroutine hasn't exited yet)
|
||||
s.mutex.Lock()
|
||||
assert.Nil(t, s.connectClient, "connectClient should be nil after cleanup")
|
||||
assert.False(t, s.clientRunning, "clientRunning should be cleared by cleanupConnection (intent = down)")
|
||||
assert.True(t, s.clientRunning, "clientRunning still true until goroutine exits")
|
||||
s.mutex.Unlock()
|
||||
|
||||
// waitForUp() returns immediately due to stale closed clientRunningChan
|
||||
|
||||
@@ -97,7 +97,7 @@ func TestConnectWithRetryRuns(t *testing.T) {
|
||||
|
||||
pm := profilemanager.ServiceManager{}
|
||||
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
Name: "test-profile",
|
||||
ID: "test-profile",
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -158,7 +158,7 @@ func TestServer_Up(t *testing.T) {
|
||||
|
||||
pm := profilemanager.ServiceManager{}
|
||||
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
Name: profName,
|
||||
ID: profilemanager.ID(profName),
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -228,7 +228,7 @@ func TestServer_SubcribeEvents(t *testing.T) {
|
||||
|
||||
pm := profilemanager.ServiceManager{}
|
||||
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
Name: "default",
|
||||
ID: "default",
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@@ -1,198 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/mdm"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
// withMDMPolicy temporarily overrides the server-package loadMDMPolicy hook
|
||||
// so SetConfig observes the supplied Policy. Restores the original loader
|
||||
// at test cleanup.
|
||||
func withMDMPolicy(t *testing.T, policy *mdm.Policy) {
|
||||
t.Helper()
|
||||
prev := loadMDMPolicy
|
||||
loadMDMPolicy = func() *mdm.Policy { return policy }
|
||||
t.Cleanup(func() { loadMDMPolicy = prev })
|
||||
}
|
||||
|
||||
// setupServerWithProfile mirrors the boilerplate of TestSetConfig_AllFieldsSaved:
|
||||
// overrides profilemanager paths to a temp dir, seeds a profile, sets it
|
||||
// active, and constructs a Server instance. Returns the constructed server
|
||||
// plus context + profile name + username + cfgPath for the seeded profile.
|
||||
func setupServerWithProfile(t *testing.T) (s *Server, ctx context.Context, profName, username, cfgPath string) {
|
||||
t.Helper()
|
||||
tempDir := t.TempDir()
|
||||
|
||||
origDefaultProfileDir := profilemanager.DefaultConfigPathDir
|
||||
origDefaultConfigPath := profilemanager.DefaultConfigPath
|
||||
origActiveProfileStatePath := profilemanager.ActiveProfileStatePath
|
||||
profilemanager.ConfigDirOverride = tempDir
|
||||
profilemanager.DefaultConfigPathDir = tempDir
|
||||
profilemanager.ActiveProfileStatePath = tempDir + "/active_profile.json"
|
||||
profilemanager.DefaultConfigPath = filepath.Join(tempDir, "default.json")
|
||||
t.Cleanup(func() {
|
||||
profilemanager.DefaultConfigPathDir = origDefaultProfileDir
|
||||
profilemanager.ActiveProfileStatePath = origActiveProfileStatePath
|
||||
profilemanager.DefaultConfigPath = origDefaultConfigPath
|
||||
profilemanager.ConfigDirOverride = ""
|
||||
})
|
||||
|
||||
currUser, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
|
||||
profName = "test-profile-mdm"
|
||||
cfgPath = filepath.Join(tempDir, profName+".json")
|
||||
|
||||
_, err = profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||
ConfigPath: cfgPath,
|
||||
ManagementURL: "https://api.netbird.io:443",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
pm := profilemanager.ServiceManager{}
|
||||
require.NoError(t, pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
Name: profName,
|
||||
Username: currUser.Username,
|
||||
}))
|
||||
|
||||
ctx = context.Background()
|
||||
s = New(ctx, "console", "", false, false, false, false)
|
||||
return s, ctx, profName, currUser.Username, cfgPath
|
||||
}
|
||||
|
||||
// extractViolation pulls the MDMManagedFieldsViolation detail from a
|
||||
// FailedPrecondition error. Fails the test if absent or malformed.
|
||||
func extractViolation(t *testing.T, err error) *proto.MDMManagedFieldsViolation {
|
||||
t.Helper()
|
||||
require.Error(t, err)
|
||||
st, ok := gstatus.FromError(err)
|
||||
require.True(t, ok, "error must be a gRPC status: %v", err)
|
||||
require.Equal(t, codes.FailedPrecondition, st.Code(), "expected FailedPrecondition, got %s", st.Code())
|
||||
for _, d := range st.Details() {
|
||||
if v, ok := d.(*proto.MDMManagedFieldsViolation); ok {
|
||||
return v
|
||||
}
|
||||
}
|
||||
t.Fatalf("MDMManagedFieldsViolation detail not found on status; details: %v", st.Details())
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestSetConfig_MDMReject_SingleField(t *testing.T) {
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: "https://mdm.example.com:443",
|
||||
}))
|
||||
|
||||
s, ctx, profName, username, _ := setupServerWithProfile(t)
|
||||
|
||||
_, err := s.SetConfig(ctx, &proto.SetConfigRequest{
|
||||
ProfileName: profName,
|
||||
Username: username,
|
||||
ManagementUrl: "https://user.tried.this.com:443",
|
||||
})
|
||||
|
||||
v := extractViolation(t, err)
|
||||
assert.Equal(t, []string{mdm.KeyManagementURL}, v.GetFields())
|
||||
}
|
||||
|
||||
func TestSetConfig_MDMReject_MultipleFields(t *testing.T) {
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: "https://mdm.example.com:443",
|
||||
mdm.KeyBlockInbound: true,
|
||||
mdm.KeyRosenpassEnabled: true,
|
||||
}))
|
||||
|
||||
s, ctx, profName, username, _ := setupServerWithProfile(t)
|
||||
|
||||
blockInbound := false
|
||||
rosenpassEnabled := false
|
||||
_, err := s.SetConfig(ctx, &proto.SetConfigRequest{
|
||||
ProfileName: profName,
|
||||
Username: username,
|
||||
ManagementUrl: "https://user.tried.this.com:443",
|
||||
BlockInbound: &blockInbound,
|
||||
RosenpassEnabled: &rosenpassEnabled,
|
||||
})
|
||||
|
||||
v := extractViolation(t, err)
|
||||
assert.ElementsMatch(t, []string{
|
||||
mdm.KeyManagementURL,
|
||||
mdm.KeyBlockInbound,
|
||||
mdm.KeyRosenpassEnabled,
|
||||
}, v.GetFields())
|
||||
}
|
||||
|
||||
func TestSetConfig_MDMReject_AllOrNothing(t *testing.T) {
|
||||
// MDM enforces ManagementURL only; user request touches both the
|
||||
// enforced field AND a non-enforced field (RosenpassEnabled).
|
||||
// The whole request must be rejected — non-conflicting fields are not
|
||||
// applied either.
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: "https://mdm.example.com:443",
|
||||
}))
|
||||
|
||||
s, ctx, profName, username, cfgPath := setupServerWithProfile(t)
|
||||
|
||||
rosenpassEnabled := true
|
||||
_, err := s.SetConfig(ctx, &proto.SetConfigRequest{
|
||||
ProfileName: profName,
|
||||
Username: username,
|
||||
ManagementUrl: "https://user.tried.this.com:443",
|
||||
RosenpassEnabled: &rosenpassEnabled,
|
||||
})
|
||||
|
||||
v := extractViolation(t, err)
|
||||
assert.Equal(t, []string{mdm.KeyManagementURL}, v.GetFields())
|
||||
|
||||
// Confirm RosenpassEnabled was NOT applied even though it was not
|
||||
// in the conflict list: the request was rejected as a whole.
|
||||
reloaded, err := profilemanager.GetConfig(cfgPath)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, reloaded.RosenpassEnabled, "non-conflicting field must not be applied when request is rejected")
|
||||
}
|
||||
|
||||
func TestSetConfig_MDMAllow_NonManagedFields(t *testing.T) {
|
||||
// MDM enforces ManagementURL but the user only writes RosenpassEnabled.
|
||||
// Request must succeed.
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: "https://mdm.example.com:443",
|
||||
}))
|
||||
|
||||
s, ctx, profName, username, _ := setupServerWithProfile(t)
|
||||
|
||||
rosenpassEnabled := true
|
||||
resp, err := s.SetConfig(ctx, &proto.SetConfigRequest{
|
||||
ProfileName: profName,
|
||||
Username: username,
|
||||
RosenpassEnabled: &rosenpassEnabled,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
}
|
||||
|
||||
func TestSetConfig_MDMEmpty_NoEnforcement(t *testing.T) {
|
||||
// No MDM policy active: any field can be written.
|
||||
withMDMPolicy(t, mdm.NewPolicy(nil))
|
||||
|
||||
s, ctx, profName, username, _ := setupServerWithProfile(t)
|
||||
|
||||
resp, err := s.SetConfig(ctx, &proto.SetConfigRequest{
|
||||
ProfileName: profName,
|
||||
Username: username,
|
||||
ManagementUrl: "https://user.changed.url.com:443",
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
}
|
||||
@@ -47,7 +47,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
|
||||
pm := profilemanager.ServiceManager{}
|
||||
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
Name: profName,
|
||||
ID: profilemanager.ID(profName),
|
||||
Username: currUser.Username,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@@ -96,7 +96,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
DisableNotifications: &disableNotifications,
|
||||
LazyConnectionEnabled: &lazyConnectionEnabled,
|
||||
BlockInbound: &blockInbound,
|
||||
DisableIpv6: &disableIPv6,
|
||||
DisableIpv6: &disableIPv6,
|
||||
NatExternalIPs: []string{"1.2.3.4", "5.6.7.8"},
|
||||
CleanNATExternalIPs: false,
|
||||
CustomDNSAddress: []byte("1.1.1.1:53"),
|
||||
@@ -112,7 +112,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
profState := profilemanager.ActiveProfileState{
|
||||
Name: profName,
|
||||
ID: profilemanager.ID(profName),
|
||||
Username: currUser.Username,
|
||||
}
|
||||
cfgPath, err := profState.FilePath()
|
||||
|
||||
@@ -38,7 +38,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/mdm"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/ui/desktop"
|
||||
"github.com/netbirdio/netbird/client/ui/event"
|
||||
@@ -57,22 +56,8 @@ const (
|
||||
const (
|
||||
censoredPreSharedKey = "**********"
|
||||
maxSSHJWTCacheTTL = 86_400 // 24 hours in seconds
|
||||
// mdmFieldSuffix is appended to plain-text Entry widgets in the
|
||||
// advanced Settings window when the underlying field is enforced
|
||||
// by MDM, so the user sees the lock indicator inline next to the
|
||||
// value. Stripped before any read site that feeds the value back
|
||||
// into a SetConfig request (saveSettings / parseNumericSettings).
|
||||
mdmFieldSuffix = " (MDM)"
|
||||
)
|
||||
|
||||
// main is the entry point for the UI tray/client binary. Parses CLI
|
||||
// flags, initialises logging, builds the Fyne application and tray
|
||||
// icons, and constructs the service client (which may open a
|
||||
// requested UI window). When a window-mode flag is set the Fyne event
|
||||
// loop runs and main returns; otherwise main enforces single-instance
|
||||
// behaviour (signalling an existing instance to show its window when
|
||||
// present), sets up signal handling + default fonts, and runs the
|
||||
// system tray loop.
|
||||
func main() {
|
||||
flags := parseFlags()
|
||||
|
||||
@@ -330,13 +315,9 @@ type serviceClient struct {
|
||||
isUpdateIconActive bool
|
||||
isEnforcedUpdate bool
|
||||
lastNotifiedVersion string
|
||||
settingsEnabled bool
|
||||
profilesEnabled bool
|
||||
networksEnabled bool
|
||||
// networksMenuEnabled caches the last applied enabled-state of the
|
||||
// mNetworks + mExitNode submenu items. Combines features.DisableNetworks
|
||||
// AND s.connected — both must be true for the menus to be active.
|
||||
// Zero value (false) matches the Disable() call at AddMenuItem time.
|
||||
networksMenuEnabled bool
|
||||
showNetworks bool
|
||||
wNetworks fyne.Window
|
||||
wProfiles fyne.Window
|
||||
@@ -355,13 +336,6 @@ type serviceClient struct {
|
||||
updateContextCancel context.CancelFunc
|
||||
|
||||
connectCancel context.CancelFunc
|
||||
|
||||
// mdmManagedFields caches the names of MDM-enforced policy keys
|
||||
// surfaced by the daemon in GetConfigResponse. Each refresh of
|
||||
// daemon config (loadSettings, getSrvConfig, config_changed event)
|
||||
// updates this set and re-applies the lock/badge to the affected
|
||||
// menu items and settings-form widgets.
|
||||
mdmManagedFields map[string]bool
|
||||
}
|
||||
|
||||
type menuHandler struct {
|
||||
@@ -467,12 +441,15 @@ func (s *serviceClient) updateIcon() {
|
||||
}
|
||||
|
||||
func (s *serviceClient) showSettingsUI() {
|
||||
// DisableUpdateSettings no longer gates the window from opening:
|
||||
// the daemon blocks every actual mutation at SetConfig / Login,
|
||||
// so the window is safe to show as a read-only view. The previous
|
||||
// early-return also blocked Advanced Settings whenever update
|
||||
// editing was off, which conflated two distinct kill switches
|
||||
// (see comment in checkAndUpdateFeatures).
|
||||
// Check if update settings are disabled by daemon
|
||||
features, err := s.getFeatures()
|
||||
if err != nil {
|
||||
log.Errorf("failed to get features from daemon: %v", err)
|
||||
// Continue with default behavior if features can't be retrieved
|
||||
} else if features != nil && features.DisableUpdateSettings {
|
||||
log.Warn("Update settings are disabled by daemon")
|
||||
return
|
||||
}
|
||||
|
||||
// add settings window UI elements.
|
||||
s.wSettings = s.app.NewWindow("NetBird Settings")
|
||||
@@ -555,7 +532,7 @@ func (s *serviceClient) saveSettings() {
|
||||
return
|
||||
}
|
||||
|
||||
iMngURL := strings.TrimSpace(strings.TrimSuffix(s.iMngURL.Text, mdmFieldSuffix))
|
||||
iMngURL := strings.TrimSpace(s.iMngURL.Text)
|
||||
|
||||
if s.hasSettingsChanged(iMngURL, port, mtu) {
|
||||
if err := s.applySettingsChanges(iMngURL, port, mtu); err != nil {
|
||||
@@ -577,7 +554,7 @@ func (s *serviceClient) validateSettings() error {
|
||||
}
|
||||
|
||||
func (s *serviceClient) parseNumericSettings() (int64, int64, error) {
|
||||
port, err := strconv.ParseInt(strings.TrimSpace(strings.TrimSuffix(s.iInterfacePort.Text, mdmFieldSuffix)), 10, 64)
|
||||
port, err := strconv.ParseInt(s.iInterfacePort.Text, 10, 64)
|
||||
if err != nil {
|
||||
return 0, 0, errors.New("invalid interface port")
|
||||
}
|
||||
@@ -645,7 +622,7 @@ func (s *serviceClient) buildSetConfigRequest(iMngURL string, port, mtu int64) (
|
||||
}
|
||||
|
||||
req := &proto.SetConfigRequest{
|
||||
ProfileName: activeProf.Name,
|
||||
ProfileName: activeProf.ID.String(),
|
||||
Username: currUser.Username,
|
||||
}
|
||||
|
||||
@@ -686,15 +663,7 @@ func (s *serviceClient) buildSetConfigRequest(iMngURL string, port, mtu int64) (
|
||||
req.SshJWTCacheTTL = &sshJWTCacheTTL32
|
||||
}
|
||||
|
||||
// Only attach the PSK when the user actually typed something:
|
||||
// - "" means the field was left untouched (we deliberately render
|
||||
// an empty Text + placeholder hint to avoid leaking the daemon's
|
||||
// "**********" redaction through the password reveal toggle);
|
||||
// sending an empty pointer would tell the daemon to clear / overwrite
|
||||
// the on-disk or MDM-enforced PSK, which then trips the MDM
|
||||
// conflict gate when PSK is policy-managed.
|
||||
// - "**********" is the redacted echo (legacy non-MDM path); also a no-op.
|
||||
if s.iPreSharedKey.Text != "" && s.iPreSharedKey.Text != censoredPreSharedKey {
|
||||
if s.iPreSharedKey.Text != censoredPreSharedKey {
|
||||
req.OptionalPreSharedKey = &s.iPreSharedKey.Text
|
||||
}
|
||||
|
||||
@@ -818,13 +787,15 @@ func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginRe
|
||||
return nil, fmt.Errorf("get current user: %w", err)
|
||||
}
|
||||
|
||||
handle := activeProf.ID.String()
|
||||
|
||||
loginReq := &proto.LoginRequest{
|
||||
IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd",
|
||||
ProfileName: &activeProf.Name,
|
||||
ProfileName: &handle,
|
||||
Username: &currUser.Username,
|
||||
}
|
||||
|
||||
profileState, err := s.profileManager.GetProfileState(activeProf.Name)
|
||||
profileState, err := s.profileManager.GetProfileState(activeProf.ID)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get profile state for login hint: %v", err)
|
||||
} else if profileState.Email != "" {
|
||||
@@ -1067,13 +1038,6 @@ func (s *serviceClient) onTrayReady() {
|
||||
}
|
||||
|
||||
s.mProfile = newProfileMenu(*newProfileMenuArgs)
|
||||
// Seed the transition cache to match the actual default menu
|
||||
// state (visible / enabled). Without this, the first
|
||||
// checkAndUpdateFeatures tick that observes DisableProfiles=true
|
||||
// is a no-op (cache zero-value == desired-false) and the menu
|
||||
// never gets hidden — symptom: MDM enforces the kill switch but
|
||||
// the profile menu stays clickable.
|
||||
s.profilesEnabled = true
|
||||
|
||||
systray.AddSeparator()
|
||||
s.mUp = systray.AddMenuItem("Connect", "Connect")
|
||||
@@ -1093,18 +1057,18 @@ func (s *serviceClient) onTrayReady() {
|
||||
s.mCreateDebugBundle = s.mSettings.AddSubMenuItem("Create Debug Bundle", debugBundleMenuDescr)
|
||||
s.loadSettings()
|
||||
|
||||
// Disable profile menu if profiles are disabled by daemon.
|
||||
// DisableUpdateSettings is enforced at the daemon's SetConfig /
|
||||
// Login gates, not by hiding the UI — so the Settings menu (and
|
||||
// its Advanced Settings submenu, which has its own kill switch)
|
||||
// stays visible and the user can still inspect current values.
|
||||
// Disable settings menu if update settings are disabled by daemon
|
||||
features, err := s.getFeatures()
|
||||
if err != nil {
|
||||
log.Errorf("failed to get features from daemon: %v", err)
|
||||
// Continue with default behavior if features can't be retrieved
|
||||
} else if features != nil && features.DisableProfiles {
|
||||
s.mProfile.setEnabled(false)
|
||||
s.profilesEnabled = false
|
||||
} else {
|
||||
if features != nil && features.DisableUpdateSettings {
|
||||
s.setSettingsEnabled(false)
|
||||
}
|
||||
if features != nil && features.DisableProfiles {
|
||||
s.mProfile.setEnabled(false)
|
||||
}
|
||||
}
|
||||
|
||||
s.exitNodeMu.Lock()
|
||||
@@ -1138,20 +1102,13 @@ func (s *serviceClient) onTrayReady() {
|
||||
// update exit node menu in case service is already connected
|
||||
go s.updateExitNodes()
|
||||
|
||||
// Features (DisableProfiles, DisableUpdateSettings, DisableNetworks,
|
||||
// ...) only change in two ways: at service install time (CLI flag,
|
||||
// static) and at MDM ticker diff time. The daemon already publishes
|
||||
// a SystemEvent{type=config_changed} on every MDM-driven engine
|
||||
// restart, so the UI no longer needs to poll GetFeatures every 2 s.
|
||||
// A single fetch at startup covers the static CLI-flag case; the
|
||||
// event handler below covers MDM transitions. updateStatus stays in
|
||||
// the 2 s loop because connection / peer state genuinely change
|
||||
// continuously and have no event yet.
|
||||
s.checkAndUpdateFeatures()
|
||||
go func() {
|
||||
s.getSrvConfig()
|
||||
time.Sleep(100 * time.Millisecond) // To prevent race condition caused by systray not being fully initialized and ignoring setIcon
|
||||
for {
|
||||
// Check features before status so menus respect disable flags before being enabled
|
||||
s.checkAndUpdateFeatures()
|
||||
|
||||
err := s.updateStatus()
|
||||
if err != nil {
|
||||
log.Errorf("error while updating status: %v", err)
|
||||
@@ -1195,23 +1152,6 @@ func (s *serviceClient) onTrayReady() {
|
||||
s.onUpdateAvailable(newVersion, enforced)
|
||||
}
|
||||
})
|
||||
s.eventManager.AddHandler(func(event *proto.SystemEvent) {
|
||||
// Daemon emits a config_changed event after every engine spawn
|
||||
// (Server.Start, Server.Up, MDM ticker restart). Re-sync the
|
||||
// tray submenu checkboxes from the fresh daemon-side config so
|
||||
// the user does not have to restart the tray to see CLI- or
|
||||
// MDM-driven changes.
|
||||
if event.Category == proto.SystemEvent_SYSTEM && event.Metadata["type"] == "config_changed" {
|
||||
log.Infof("config_changed event received (source=%s); refreshing settings + features", event.Metadata["source"])
|
||||
s.loadSettings()
|
||||
// MDM-driven feature kill switches (DisableProfiles /
|
||||
// DisableUpdateSettings / DisableNetworks) ride the same
|
||||
// config_changed signal because the daemon re-applies its
|
||||
// MDM policy on every engine spawn. Pull them in here so
|
||||
// the UI is up to date without a periodic GetFeatures poll.
|
||||
s.checkAndUpdateFeatures()
|
||||
}
|
||||
})
|
||||
|
||||
go s.eventManager.Start(s.ctx)
|
||||
go s.eventHandler.listen(s.ctx)
|
||||
@@ -1275,6 +1215,18 @@ func (s *serviceClient) getSrvClient(timeout time.Duration) (proto.DaemonService
|
||||
return s.conn, nil
|
||||
}
|
||||
|
||||
// setSettingsEnabled enables or disables the settings menu based on the provided state
|
||||
func (s *serviceClient) setSettingsEnabled(enabled bool) {
|
||||
if s.mSettings != nil {
|
||||
if enabled {
|
||||
s.mSettings.Enable()
|
||||
} else {
|
||||
s.mSettings.Hide()
|
||||
s.mSettings.SetTooltip("Settings are disabled by daemon")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// checkAndUpdateFeatures checks the current features and updates the UI accordingly
|
||||
func (s *serviceClient) checkAndUpdateFeatures() {
|
||||
features, err := s.getFeatures()
|
||||
@@ -1286,11 +1238,12 @@ func (s *serviceClient) checkAndUpdateFeatures() {
|
||||
s.updateIndicationLock.Lock()
|
||||
defer s.updateIndicationLock.Unlock()
|
||||
|
||||
// DisableUpdateSettings is enforced server-side by the daemon gates
|
||||
// on SetConfig + Login: any attempt to mutate config from UI or
|
||||
// CLI is rejected at that layer. The UI deliberately keeps the
|
||||
// Settings menu visible so the user can still inspect current
|
||||
// values — read-only by virtue of the daemon refusing edits.
|
||||
// Update settings menu based on current features
|
||||
settingsEnabled := features == nil || !features.DisableUpdateSettings
|
||||
if s.settingsEnabled != settingsEnabled {
|
||||
s.settingsEnabled = settingsEnabled
|
||||
s.setSettingsEnabled(settingsEnabled)
|
||||
}
|
||||
|
||||
// Update profile menu based on current features
|
||||
if s.mProfile != nil {
|
||||
@@ -1301,23 +1254,14 @@ func (s *serviceClient) checkAndUpdateFeatures() {
|
||||
}
|
||||
}
|
||||
|
||||
// Update networks and exit node menus based on current features.
|
||||
// `networksEnabled` is the bare feature flag (read elsewhere, e.g. at
|
||||
// connection-status transitions). `networksMenuEnabled` is the
|
||||
// transition-cached state actually applied to the menu items —
|
||||
// it folds in the connection state so a Connected client with the
|
||||
// kill switch off shows the menus active, and only flips on diff.
|
||||
// Update networks and exit node menus based on current features
|
||||
s.networksEnabled = features == nil || !features.DisableNetworks
|
||||
desiredNetworksMenu := s.networksEnabled && s.connected
|
||||
if desiredNetworksMenu != s.networksMenuEnabled {
|
||||
s.networksMenuEnabled = desiredNetworksMenu
|
||||
if desiredNetworksMenu {
|
||||
s.mNetworks.Enable()
|
||||
s.mExitNode.Enable()
|
||||
} else {
|
||||
s.mNetworks.Disable()
|
||||
s.mExitNode.Disable()
|
||||
}
|
||||
if s.networksEnabled && s.connected {
|
||||
s.mNetworks.Enable()
|
||||
s.mExitNode.Enable()
|
||||
} else {
|
||||
s.mNetworks.Disable()
|
||||
s.mExitNode.Disable()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1367,7 +1311,7 @@ func (s *serviceClient) getSrvConfig() {
|
||||
}
|
||||
|
||||
srvCfg, err := conn.GetConfig(s.ctx, &proto.GetConfigRequest{
|
||||
ProfileName: activeProf.Name,
|
||||
ProfileName: activeProf.ID.String(),
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -1414,14 +1358,7 @@ func (s *serviceClient) getSrvConfig() {
|
||||
|
||||
if s.showAdvancedSettings {
|
||||
s.iMngURL.SetText(s.managementURL)
|
||||
// PSK is rendered with an empty Text and a hint via the
|
||||
// placeholder so the eye toggle never reveals literal asterisks
|
||||
// (the daemon returns the "**********" sentinel — writing that
|
||||
// into a PasswordEntry would surface the literal sentinel when
|
||||
// the user unmasks the field). The placeholder communicates the
|
||||
// configured / MDM-managed state without exposing any value.
|
||||
s.iPreSharedKey.SetText("")
|
||||
s.iPreSharedKey.SetPlaceHolder(preSharedKeyPlaceholder(srvCfg))
|
||||
s.iPreSharedKey.SetText(cfg.PreSharedKey)
|
||||
s.iInterfaceName.SetText(cfg.WgIface)
|
||||
s.iInterfacePort.SetText(strconv.Itoa(cfg.WgPort))
|
||||
if cfg.MTU != 0 {
|
||||
@@ -1431,15 +1368,7 @@ func (s *serviceClient) getSrvConfig() {
|
||||
s.iMTU.SetPlaceHolder(strconv.Itoa(int(iface.DefaultMTU)))
|
||||
}
|
||||
s.sRosenpassPermissive.SetChecked(cfg.RosenpassPermissive)
|
||||
// Re-baseline the enabled state on every refresh: when Rosenpass
|
||||
// is on the checkbox is editable, when it's off the field is
|
||||
// inert. Without an explicit Enable() here the control stays
|
||||
// stuck disabled after a previous refresh (or an MDM unlock) had
|
||||
// turned it off — applyMDMLocksToSettingsForm below adds the
|
||||
// MDM lock on top of this baseline.
|
||||
if cfg.RosenpassEnabled {
|
||||
s.sRosenpassPermissive.Enable()
|
||||
} else {
|
||||
if !cfg.RosenpassEnabled {
|
||||
s.sRosenpassPermissive.Disable()
|
||||
}
|
||||
s.sNetworkMonitor.SetChecked(*cfg.NetworkMonitor)
|
||||
@@ -1468,13 +1397,6 @@ func (s *serviceClient) getSrvConfig() {
|
||||
}
|
||||
}
|
||||
|
||||
// MDM locks must run before the mNotifications-nil early return:
|
||||
// the Settings window is rendered by a separate UI process launched
|
||||
// with --settings (see handleAdvancedSettingsClick), and that child
|
||||
// process does NOT run onReady — so its mNotifications is nil and
|
||||
// the early return below skipped the lock pass entirely.
|
||||
s.applyMDMLocks(srvCfg.MDMManagedFields)
|
||||
|
||||
if s.mNotifications == nil {
|
||||
return
|
||||
}
|
||||
@@ -1613,7 +1535,7 @@ func (s *serviceClient) loadSettings() {
|
||||
}
|
||||
|
||||
cfg, err := conn.GetConfig(s.ctx, &proto.GetConfigRequest{
|
||||
ProfileName: activeProf.Name,
|
||||
ProfileName: activeProf.ID.String(),
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -1659,129 +1581,6 @@ func (s *serviceClient) loadSettings() {
|
||||
if s.eventManager != nil {
|
||||
s.eventManager.SetNotificationsEnabled(s.mNotifications.Checked())
|
||||
}
|
||||
s.applyMDMLocks(cfg.MDMManagedFields)
|
||||
}
|
||||
|
||||
// applyMDMLocks disables and badges any tray submenu item or settings-
|
||||
// form widget whose underlying field is enforced by the active MDM
|
||||
// policy. Called from loadSettings (submenu refresh) and from
|
||||
// getSrvConfig (settings-window refresh). Locked items keep their value
|
||||
// already set by the surrounding refresh code — this routine only
|
||||
// flips the enabled state and the title suffix, never the value.
|
||||
func (s *serviceClient) applyMDMLocks(managed []string) {
|
||||
set := make(map[string]bool, len(managed))
|
||||
for _, k := range managed {
|
||||
set[k] = true
|
||||
}
|
||||
s.mdmManagedFields = set
|
||||
if len(managed) > 0 {
|
||||
log.Infof("MDM-managed UI fields: %v", managed)
|
||||
}
|
||||
|
||||
type submenuTarget struct {
|
||||
item *systray.MenuItem
|
||||
title string
|
||||
key string
|
||||
}
|
||||
for _, t := range []submenuTarget{
|
||||
{s.mAllowSSH, "Allow SSH", mdm.KeyAllowServerSSH},
|
||||
{s.mAutoConnect, "Connect on Startup", mdm.KeyDisableAutoConnect},
|
||||
{s.mEnableRosenpass, "Enable Quantum-Resistance", mdm.KeyRosenpassEnabled},
|
||||
{s.mBlockInbound, "Block Inbound Connections", mdm.KeyBlockInbound},
|
||||
} {
|
||||
if t.item == nil {
|
||||
continue
|
||||
}
|
||||
if set[t.key] {
|
||||
t.item.SetTitle(t.title + " (MDM)")
|
||||
t.item.Disable()
|
||||
} else {
|
||||
t.item.SetTitle(t.title)
|
||||
t.item.Enable()
|
||||
}
|
||||
}
|
||||
|
||||
s.applyMDMLocksToSettingsForm(set)
|
||||
}
|
||||
|
||||
// preSharedKeyPlaceholder returns the hint string shown in the PSK
|
||||
// Entry's placeholder slot. The placeholder is the only signal the
|
||||
// user gets that a PSK is configured, because the entry's Text is
|
||||
// forced to empty to keep the password reveal toggle from leaking
|
||||
// the daemon-returned "**********" redaction sentinel. Returns "" if
|
||||
// no PSK is present, "MDM-managed" if the key is enforced by MDM,
|
||||
// and "configured" otherwise.
|
||||
func preSharedKeyPlaceholder(cfg *proto.GetConfigResponse) string {
|
||||
if cfg == nil || cfg.PreSharedKey == "" {
|
||||
return ""
|
||||
}
|
||||
for _, k := range cfg.MDMManagedFields {
|
||||
if k == mdm.KeyPreSharedKey {
|
||||
return "MDM-managed"
|
||||
}
|
||||
}
|
||||
return "configured"
|
||||
}
|
||||
|
||||
// applyMDMLocksToSettingsForm disables the per-field input widgets in
|
||||
// the advanced Settings window when the corresponding MDM key is set.
|
||||
// For plain-text entries (Management URL, Interface Port) the visible
|
||||
// value is suffixed with " (MDM)" so the user sees the lock indicator
|
||||
// inline; for the password entry the suffix is skipped (a password
|
||||
// widget renders every char as a dot and the indicator would not be
|
||||
// readable). The widgets are created lazily by showSettingsUI, so
|
||||
// guard each ref against nil.
|
||||
func (s *serviceClient) applyMDMLocksToSettingsForm(set map[string]bool) {
|
||||
type entryTarget struct {
|
||||
entry *widget.Entry
|
||||
key string
|
||||
inlineTag bool
|
||||
}
|
||||
for _, t := range []entryTarget{
|
||||
{s.iMngURL, mdm.KeyManagementURL, true},
|
||||
{s.iPreSharedKey, mdm.KeyPreSharedKey, false},
|
||||
{s.iInterfacePort, mdm.KeyWireguardPort, true},
|
||||
} {
|
||||
if t.entry == nil {
|
||||
continue
|
||||
}
|
||||
if set[t.key] {
|
||||
if t.inlineTag && t.entry.Text != "" && !strings.HasSuffix(t.entry.Text, mdmFieldSuffix) {
|
||||
t.entry.SetText(t.entry.Text + mdmFieldSuffix)
|
||||
}
|
||||
t.entry.Disable()
|
||||
} else {
|
||||
if t.inlineTag {
|
||||
t.entry.SetText(strings.TrimSuffix(t.entry.Text, mdmFieldSuffix))
|
||||
}
|
||||
t.entry.Enable()
|
||||
}
|
||||
}
|
||||
type checkTarget struct {
|
||||
check *widget.Check
|
||||
key string
|
||||
}
|
||||
for _, t := range []checkTarget{
|
||||
{s.sDisableClientRoutes, mdm.KeyDisableClientRoutes},
|
||||
{s.sDisableServerRoutes, mdm.KeyDisableServerRoutes},
|
||||
} {
|
||||
if t.check == nil {
|
||||
continue
|
||||
}
|
||||
if set[t.key] {
|
||||
t.check.Disable()
|
||||
} else {
|
||||
t.check.Enable()
|
||||
}
|
||||
}
|
||||
if s.sRosenpassPermissive != nil && set[mdm.KeyRosenpassPermissive] {
|
||||
// MDM lock layered on top of the Rosenpass-on/off baseline
|
||||
// applied by getSrvConfig. No Enable() branch here: when the
|
||||
// MDM key is removed, the next getSrvConfig refresh re-baselines
|
||||
// the control on cfg.RosenpassEnabled and brings it back if
|
||||
// Rosenpass is on.
|
||||
s.sRosenpassPermissive.Disable()
|
||||
}
|
||||
}
|
||||
|
||||
// updateConfig updates the configuration parameters
|
||||
@@ -1813,7 +1612,7 @@ func (s *serviceClient) updateConfig() error {
|
||||
}
|
||||
|
||||
req := proto.SetConfigRequest{
|
||||
ProfileName: activeProf.Name,
|
||||
ProfileName: activeProf.ID.String(),
|
||||
Username: currUser.Username,
|
||||
DisableAutoConnect: &disableAutoStart,
|
||||
ServerSSHAllowed: &sshAllowed,
|
||||
|
||||
@@ -66,7 +66,7 @@ func (s *serviceClient) showProfilesUI() {
|
||||
} else {
|
||||
indicator.SetText("")
|
||||
}
|
||||
nameLabel.SetText(profile.Name)
|
||||
nameLabel.SetText(formatProfileLabel(profile, profiles))
|
||||
|
||||
// Configure Select/Active button
|
||||
selectBtn.SetText(func() string {
|
||||
@@ -88,7 +88,7 @@ func (s *serviceClient) showProfilesUI() {
|
||||
return
|
||||
}
|
||||
// switch
|
||||
err = s.switchProfile(profile.Name)
|
||||
err = s.switchProfile(profile.ID)
|
||||
if err != nil {
|
||||
log.Errorf("failed to switch profile: %v", err)
|
||||
dialog.ShowError(errors.New("failed to select profile"), s.wProfiles)
|
||||
@@ -130,7 +130,7 @@ func (s *serviceClient) showProfilesUI() {
|
||||
logoutBtn.Show()
|
||||
logoutBtn.SetText("Deregister")
|
||||
logoutBtn.OnTapped = func() {
|
||||
s.handleProfileLogout(profile.Name, refresh)
|
||||
s.handleProfileLogout(profile, refresh)
|
||||
}
|
||||
|
||||
// Remove profile
|
||||
@@ -144,7 +144,7 @@ func (s *serviceClient) showProfilesUI() {
|
||||
return
|
||||
}
|
||||
|
||||
err = s.removeProfile(profile.Name)
|
||||
err = s.removeProfile(profile.ID)
|
||||
if err != nil {
|
||||
log.Errorf("failed to remove profile: %v", err)
|
||||
dialog.ShowError(fmt.Errorf("failed to remove profile"), s.wProfiles)
|
||||
@@ -250,7 +250,7 @@ func (s *serviceClient) addProfile(profileName string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *serviceClient) switchProfile(profileName string) error {
|
||||
func (s *serviceClient) switchProfile(handle string) error {
|
||||
conn, err := s.getSrvClient(defaultFailTimeout)
|
||||
if err != nil {
|
||||
return fmt.Errorf(getClientFMT, err)
|
||||
@@ -261,15 +261,15 @@ func (s *serviceClient) switchProfile(profileName string) error {
|
||||
return fmt.Errorf("get current user: %w", err)
|
||||
}
|
||||
|
||||
if _, err := conn.SwitchProfile(s.ctx, &proto.SwitchProfileRequest{
|
||||
ProfileName: &profileName,
|
||||
resp, err := conn.SwitchProfile(s.ctx, &proto.SwitchProfileRequest{
|
||||
ProfileName: &handle,
|
||||
Username: &currUser.Username,
|
||||
}); err != nil {
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("switch profile failed: %w", err)
|
||||
}
|
||||
|
||||
err = s.profileManager.SwitchProfile(profileName)
|
||||
if err != nil {
|
||||
if err := s.profileManager.SwitchProfile(profilemanager.ID(resp.Id)); err != nil {
|
||||
return fmt.Errorf("switch profile: %w", err)
|
||||
}
|
||||
|
||||
@@ -299,10 +299,27 @@ func (s *serviceClient) removeProfile(profileName string) error {
|
||||
}
|
||||
|
||||
type Profile struct {
|
||||
ID string
|
||||
Name string
|
||||
IsActive bool
|
||||
}
|
||||
|
||||
// formatProfileLabel returns the display label for a profile. Profiles can
|
||||
// share the same Name, so when more than one profile in profiles carries this
|
||||
// Name, a short form of the ID is appended to disambiguate the entries.
|
||||
func formatProfileLabel(profile Profile, profiles []Profile) string {
|
||||
count := 0
|
||||
for _, p := range profiles {
|
||||
if p.Name == profile.Name {
|
||||
count++
|
||||
}
|
||||
}
|
||||
if count <= 1 {
|
||||
return profile.Name
|
||||
}
|
||||
return fmt.Sprintf("%s (%s)", profile.Name, profilemanager.ID(profile.ID).ShortID())
|
||||
}
|
||||
|
||||
func (s *serviceClient) getProfiles() ([]Profile, error) {
|
||||
conn, err := s.getSrvClient(defaultFailTimeout)
|
||||
if err != nil {
|
||||
@@ -324,6 +341,7 @@ func (s *serviceClient) getProfiles() ([]Profile, error) {
|
||||
|
||||
for _, profile := range profilesResp.Profiles {
|
||||
profiles = append(profiles, Profile{
|
||||
ID: profile.Id,
|
||||
Name: profile.Name,
|
||||
IsActive: profile.IsActive,
|
||||
})
|
||||
@@ -332,10 +350,10 @@ func (s *serviceClient) getProfiles() ([]Profile, error) {
|
||||
return profiles, nil
|
||||
}
|
||||
|
||||
func (s *serviceClient) handleProfileLogout(profileName string, refreshCallback func()) {
|
||||
func (s *serviceClient) handleProfileLogout(profile Profile, refreshCallback func()) {
|
||||
dialog.ShowConfirm(
|
||||
"Deregister",
|
||||
fmt.Sprintf("Are you sure you want to deregister from '%s'?", profileName),
|
||||
fmt.Sprintf("Are you sure you want to deregister from '%s'?", profile.Name),
|
||||
func(confirm bool) {
|
||||
if !confirm {
|
||||
return
|
||||
@@ -356,8 +374,10 @@ func (s *serviceClient) handleProfileLogout(profileName string, refreshCallback
|
||||
}
|
||||
|
||||
username := currUser.Username
|
||||
// ProfileName is treated as a handle; send the ID so the
|
||||
// daemon resolves to exactly this profile.
|
||||
_, err = conn.Logout(s.ctx, &proto.LogoutRequest{
|
||||
ProfileName: &profileName,
|
||||
ProfileName: &profile.ID,
|
||||
Username: &username,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -368,7 +388,7 @@ func (s *serviceClient) handleProfileLogout(profileName string, refreshCallback
|
||||
|
||||
dialog.ShowInformation(
|
||||
"Deregistered",
|
||||
fmt.Sprintf("Successfully deregistered from '%s'", profileName),
|
||||
fmt.Sprintf("Successfully deregistered from '%s'", profile.Name),
|
||||
s.wProfiles,
|
||||
)
|
||||
|
||||
@@ -461,6 +481,7 @@ func (p *profileMenu) getProfiles() ([]Profile, error) {
|
||||
|
||||
for _, profile := range profilesResp.Profiles {
|
||||
profiles = append(profiles, Profile{
|
||||
ID: profile.Id,
|
||||
Name: profile.Name,
|
||||
IsActive: profile.IsActive,
|
||||
})
|
||||
@@ -501,7 +522,7 @@ func (p *profileMenu) refresh() {
|
||||
}
|
||||
|
||||
if activeProf.ProfileName == "default" || activeProf.Username == currUser.Username {
|
||||
activeProfState, err := p.profileManager.GetProfileState(activeProf.ProfileName)
|
||||
activeProfState, err := p.profileManager.GetProfileState(profilemanager.ID(activeProf.Id))
|
||||
if err != nil {
|
||||
log.Warnf("failed to get active profile state: %v", err)
|
||||
p.emailMenuItem.Hide()
|
||||
@@ -512,7 +533,7 @@ func (p *profileMenu) refresh() {
|
||||
}
|
||||
|
||||
for _, profile := range profiles {
|
||||
item := p.profileMenuItem.AddSubMenuItem(profile.Name, "")
|
||||
item := p.profileMenuItem.AddSubMenuItem(formatProfileLabel(profile, profiles), "")
|
||||
if profile.IsActive {
|
||||
item.Check()
|
||||
}
|
||||
@@ -541,8 +562,8 @@ func (p *profileMenu) refresh() {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = conn.SwitchProfile(ctx, &proto.SwitchProfileRequest{
|
||||
ProfileName: &profile.Name,
|
||||
switchResp, err := conn.SwitchProfile(ctx, &proto.SwitchProfileRequest{
|
||||
ProfileName: &profile.ID,
|
||||
Username: &currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -552,7 +573,7 @@ func (p *profileMenu) refresh() {
|
||||
return
|
||||
}
|
||||
|
||||
err = p.profileManager.SwitchProfile(profile.Name)
|
||||
err = p.profileManager.SwitchProfile(profilemanager.ID(switchResp.Id))
|
||||
if err != nil {
|
||||
log.Errorf("failed to switch profile '%s': %v", profile.Name, err)
|
||||
return
|
||||
@@ -666,49 +687,17 @@ func (p *profileMenu) clear(profiles []Profile) {
|
||||
}
|
||||
}
|
||||
|
||||
// setEnabled greys out (Disable) the profile menu and every existing
|
||||
// sub-item when the daemon reports the kill switch active, so the user
|
||||
// sees the menu but cannot enter "Manage Profiles" or switch profile.
|
||||
// Previously this used Hide() on the parent, but Fyne's systray on
|
||||
// Windows does not propagate Hide() to a parent that already has
|
||||
// children — the submenu kept popping up and accepting clicks. Disable
|
||||
// is the reliable visual lock.
|
||||
// setEnabled enables or disables the profile menu based on the provided state
|
||||
func (p *profileMenu) setEnabled(enabled bool) {
|
||||
if p.profileMenuItem == nil {
|
||||
return
|
||||
}
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if enabled {
|
||||
p.profileMenuItem.Enable()
|
||||
p.profileMenuItem.SetTooltip("")
|
||||
} else {
|
||||
p.profileMenuItem.Disable()
|
||||
p.profileMenuItem.SetTooltip("Profiles are disabled by daemon")
|
||||
}
|
||||
|
||||
apply := func(item *systray.MenuItem) {
|
||||
if item == nil {
|
||||
return
|
||||
}
|
||||
if p.profileMenuItem != nil {
|
||||
if enabled {
|
||||
item.Enable()
|
||||
p.profileMenuItem.Enable()
|
||||
p.profileMenuItem.SetTooltip("")
|
||||
} else {
|
||||
item.Disable()
|
||||
p.profileMenuItem.Hide()
|
||||
p.profileMenuItem.SetTooltip("Profiles are disabled by daemon")
|
||||
}
|
||||
}
|
||||
for _, sub := range p.profileSubItems {
|
||||
if sub != nil {
|
||||
apply(sub.MenuItem)
|
||||
}
|
||||
}
|
||||
if p.manageProfilesSubItem != nil {
|
||||
apply(p.manageProfilesSubItem.MenuItem)
|
||||
}
|
||||
if p.logoutSubItem != nil {
|
||||
apply(p.logoutSubItem.MenuItem)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *profileMenu) updateMenu() {
|
||||
@@ -727,7 +716,10 @@ func (p *profileMenu) updateMenu() {
|
||||
}
|
||||
|
||||
sort.Slice(profiles, func(i, j int) bool {
|
||||
return profiles[i].Name < profiles[j].Name
|
||||
if profiles[i].Name != profiles[j].Name {
|
||||
return profiles[i].Name < profiles[j].Name
|
||||
}
|
||||
return profiles[i].ID < profiles[j].ID
|
||||
})
|
||||
|
||||
p.mu.Lock()
|
||||
|
||||
@@ -1,126 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<!--
|
||||
NetBird MDM preferences (macOS) — bare plist for MDM platforms that
|
||||
accept a managed-preferences plist tied to a bundle identifier
|
||||
(e.g. JumpCloud "Mac Application Custom Settings", Mosyle "Custom
|
||||
Settings", Jamf "Application & Custom Settings" → External
|
||||
Application).
|
||||
|
||||
Bundle identifier (preference domain): io.netbird.client
|
||||
|
||||
The MDM provider will wrap this plist into a Configuration Profile
|
||||
payload of type com.apple.ManagedClient.preferences and push it to
|
||||
target devices via the Apple MDM protocol. The OS materializes the
|
||||
final file at:
|
||||
/Library/Managed Preferences/io.netbird.client.plist
|
||||
which is what the NetBird daemon's client/mdm/policy_darwin.go
|
||||
loader reads on every 1-minute MDM reload tick.
|
||||
|
||||
For MDM platforms that expect a full Configuration Profile instead
|
||||
of a bare plist (Custom Configuration Profile / .mobileconfig upload),
|
||||
use docs/netbird-macos.mobileconfig — same keys, additional Payload*
|
||||
envelope.
|
||||
|
||||
Editing this file:
|
||||
- Remove or comment out any key you do NOT want to enforce. The
|
||||
daemon treats an absent key as "no enforcement" for that field.
|
||||
- Keep the document well-formed XML. Validate locally with:
|
||||
plutil -lint docs/io.netbird.client.plist
|
||||
- Keys are camelCase; values are typed (<string>, <true/>, <false/>,
|
||||
<integer>). See docs/src/pages/client/mdm-integration.mdx (the
|
||||
public docs page) for the full reference.
|
||||
|
||||
Persistence caveat:
|
||||
macOS wipes /Library/Managed Preferences/ at every boot on
|
||||
devices that are NOT MDM-enrolled. This plist only sticks across
|
||||
reboots when delivered through a real MDM channel. For local
|
||||
testing on an un-enrolled host, write the file manually as root
|
||||
and accept it will not survive the next boot.
|
||||
-->
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
|
||||
<!-- ===== Identity / auth ===== -->
|
||||
<key>managementURL</key>
|
||||
<string>https://api.netbird.io:443</string>
|
||||
|
||||
<!--
|
||||
Pre-shared key: secret. Remove the entry entirely when not used;
|
||||
do NOT leave an empty <string></string>, which the daemon would
|
||||
otherwise treat as a deliberate empty-PSK enforcement.
|
||||
-->
|
||||
<!--
|
||||
<key>preSharedKey</key>
|
||||
<string>REPLACE_ME</string>
|
||||
-->
|
||||
|
||||
<!-- ===== Engine / runtime behavior =====
|
||||
Each key is optional. Remove or comment out to leave the
|
||||
field unmanaged on the client. -->
|
||||
|
||||
<key>allowServerSSH</key>
|
||||
<true/>
|
||||
|
||||
<!--
|
||||
<key>disableAutoConnect</key>
|
||||
<false/>
|
||||
|
||||
<key>disableClientRoutes</key>
|
||||
<false/>
|
||||
|
||||
<key>disableServerRoutes</key>
|
||||
<false/>
|
||||
|
||||
<key>blockInbound</key>
|
||||
<false/>
|
||||
|
||||
<key>rosenpassEnabled</key>
|
||||
<true/>
|
||||
|
||||
<key>rosenpassPermissive</key>
|
||||
<false/>
|
||||
-->
|
||||
|
||||
<!-- ===== WireGuard UDP port =====
|
||||
Range 1-65535. Omit to keep the daemon default. -->
|
||||
<!--
|
||||
<key>wireguardPort</key>
|
||||
<integer>51820</integer>
|
||||
-->
|
||||
|
||||
<!-- ===== UI / lockdown kill switches =====
|
||||
disableUpdateSettings : block every config change from UI and CLI
|
||||
on this device (Settings view stays
|
||||
readable but read-only).
|
||||
disableProfiles : hide the profile menu, reject profile CRUD.
|
||||
disableNetworks : hide the Networks / Exit Node menus,
|
||||
reject the related RPCs.
|
||||
disableMetricsCollection: opt out of anonymous usage telemetry. -->
|
||||
<!--
|
||||
<key>disableUpdateSettings</key>
|
||||
<true/>
|
||||
|
||||
<key>disableProfiles</key>
|
||||
<true/>
|
||||
|
||||
<key>disableNetworks</key>
|
||||
<true/>
|
||||
|
||||
<key>disableMetricsCollection</key>
|
||||
<false/>
|
||||
-->
|
||||
|
||||
<!-- ===== Split tunnel =====
|
||||
Android-only at the client level. Safe to ship on macOS for
|
||||
mixed-platform fleets; the macOS daemon parses and ignores. -->
|
||||
<!--
|
||||
<key>splitTunnelMode</key>
|
||||
<string>allow</string>
|
||||
|
||||
<key>splitTunnelApps</key>
|
||||
<string>com.acme.app1,com.acme.app2</string>
|
||||
-->
|
||||
|
||||
</dict>
|
||||
</plist>
|
||||
@@ -1,159 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<!--
|
||||
NetBird MDM configuration profile (macOS).
|
||||
|
||||
Wraps a `com.apple.ManagedClient.preferences` payload that pushes the
|
||||
NetBird MDM policy into:
|
||||
/Library/Managed Preferences/io.netbird.client.plist
|
||||
|
||||
Read at runtime by the netbird daemon's macOS loader
|
||||
(client/mdm/policy_darwin.go — Phase 2). Key names match the canonical
|
||||
lowerCamelCase form used in docs/netbird.admx and the mdm.Key*
|
||||
constants in client/mdm/policy.go.
|
||||
|
||||
Bundle identifier: io.netbird.client
|
||||
(confirm against the signed pkg before fleet roll-out)
|
||||
|
||||
Distribution:
|
||||
- sign with `productsign --sign "Developer ID Installer: ..." ...`
|
||||
before fleet roll-out (Apple-Configurator-2 won't install an
|
||||
unsigned profile on Sonoma+ without user override).
|
||||
- For local dev install: `sudo profiles install -path netbird-macos.mobileconfig`.
|
||||
- For MDM (Jamf/Kandji/Mosyle/Intune): upload as a Custom Profile.
|
||||
|
||||
Editing:
|
||||
- Replace UUID placeholders below with fresh UUIDs (`uuidgen` on
|
||||
macOS) when forking this template for a real fleet — each
|
||||
deployment should have unique UUIDs so the OS treats it as a
|
||||
distinct profile.
|
||||
- Tune the PayloadContent values to the policy you want to enforce.
|
||||
- Remove any key you do NOT want to enforce (the daemon treats an
|
||||
absent key as "no enforcement" for that field).
|
||||
|
||||
iOS note:
|
||||
This file is macOS-specific. iOS uses managed app config via
|
||||
UserDefaults[com.apple.configuration.managed] under a different
|
||||
payload type (com.apple.app.configuration.managed); the wrapper
|
||||
structure is the same but the inner payload dictionary differs.
|
||||
See docs/netbird-ios.mobileconfig (Phase 5) when shipped.
|
||||
-->
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<!-- Outer profile envelope -->
|
||||
<key>PayloadType</key>
|
||||
<string>Configuration</string>
|
||||
<key>PayloadVersion</key>
|
||||
<integer>1</integer>
|
||||
<key>PayloadIdentifier</key>
|
||||
<string>io.netbird.client.mdm</string>
|
||||
<key>PayloadUUID</key>
|
||||
<string>11111111-1111-1111-1111-111111111111</string>
|
||||
<key>PayloadDisplayName</key>
|
||||
<string>NetBird MDM Policy</string>
|
||||
<key>PayloadDescription</key>
|
||||
<string>Enforces NetBird client configuration. Values written here override any local user / CLI / on-disk setting and are re-applied at every daemon boot and on every 1-minute MDM reload tick.</string>
|
||||
<key>PayloadOrganization</key>
|
||||
<string>NetBird</string>
|
||||
<key>PayloadScope</key>
|
||||
<string>System</string>
|
||||
<key>PayloadRemovalDisallowed</key>
|
||||
<false/>
|
||||
|
||||
<key>PayloadContent</key>
|
||||
<array>
|
||||
<dict>
|
||||
<!-- Managed preferences payload: writes /Library/Managed Preferences/io.netbird.client.plist -->
|
||||
<key>PayloadType</key>
|
||||
<string>com.apple.ManagedClient.preferences</string>
|
||||
<key>PayloadVersion</key>
|
||||
<integer>1</integer>
|
||||
<key>PayloadIdentifier</key>
|
||||
<string>io.netbird.client.mdm.preferences</string>
|
||||
<key>PayloadUUID</key>
|
||||
<string>22222222-2222-2222-2222-222222222222</string>
|
||||
<key>PayloadDisplayName</key>
|
||||
<string>NetBird Managed Preferences</string>
|
||||
<key>PayloadEnabled</key>
|
||||
<true/>
|
||||
|
||||
<key>PayloadContent</key>
|
||||
<dict>
|
||||
<key>io.netbird.client</key>
|
||||
<dict>
|
||||
<key>Forced</key>
|
||||
<array>
|
||||
<dict>
|
||||
<key>mcx_preference_settings</key>
|
||||
<dict>
|
||||
|
||||
<!-- ===== Identity / auth (strings) ===== -->
|
||||
<key>managementURL</key>
|
||||
<string>https://api.netbird.io:443</string>
|
||||
|
||||
<!-- Pre-shared key: secret. Remove the entry entirely
|
||||
when not used; do NOT leave an empty string. -->
|
||||
<!--
|
||||
<key>preSharedKey</key>
|
||||
<string>REPLACE_ME</string>
|
||||
-->
|
||||
|
||||
<!-- ===== Engine / runtime behavior (bool) =====
|
||||
Remove any key to leave the field unmanaged. -->
|
||||
<!--
|
||||
<key>disableAutoConnect</key>
|
||||
<false/>
|
||||
<key>disableClientRoutes</key>
|
||||
<false/>
|
||||
<key>disableServerRoutes</key>
|
||||
<false/>
|
||||
<key>blockInbound</key>
|
||||
<false/>
|
||||
-->
|
||||
<key>allowServerSSH</key>
|
||||
<true/>
|
||||
<!--
|
||||
<key>rosenpassEnabled</key>
|
||||
<true/>
|
||||
<key>rosenpassPermissive</key>
|
||||
<false/>
|
||||
-->
|
||||
|
||||
<!-- ===== WireGuard UDP port (int) =====
|
||||
Range 1-65535. Omit to keep the default. -->
|
||||
<!--
|
||||
<key>wireguardPort</key>
|
||||
<integer>51820</integer>
|
||||
-->
|
||||
|
||||
<!-- ===== Split tunnel (Android-only at the daemon level)
|
||||
Pushed harmlessly on macOS for fleets with mixed
|
||||
desktop+mobile devices; the macOS daemon ignores it. -->
|
||||
<!--
|
||||
<key>splitTunnelMode</key>
|
||||
<string>allow</string>
|
||||
<key>splitTunnelApps</key>
|
||||
<string>com.acme.app1,com.acme.app2</string>
|
||||
-->
|
||||
|
||||
<!-- ===== UI / kill switches (bool) ===== -->
|
||||
<!--
|
||||
<key>disableUpdateSettings</key>
|
||||
<true/>
|
||||
<key>disableProfiles</key>
|
||||
<true/>
|
||||
<key>disableNetworks</key>
|
||||
<true/>
|
||||
<key>disableMetricsCollection</key>
|
||||
<false/>
|
||||
-->
|
||||
|
||||
</dict>
|
||||
</dict>
|
||||
</array>
|
||||
</dict>
|
||||
</dict>
|
||||
</dict>
|
||||
</array>
|
||||
</dict>
|
||||
</plist>
|
||||
@@ -1,189 +0,0 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# SYNOPSIS
|
||||
# Push the NetBird MDM policy to a macOS device via JumpCloud Commands.
|
||||
#
|
||||
# DESCRIPTION
|
||||
# This is the macOS counterpart of docs/netbird-policy.reg.ps1.
|
||||
# It writes the values declared in the "POLICY VALUES" block below to
|
||||
# the managed-preferences plist that the NetBird daemon's
|
||||
# client/mdm/policy_darwin.go loader reads on every 1-minute MDM
|
||||
# reload tick:
|
||||
#
|
||||
# /Library/Managed Preferences/io.netbird.client.plist
|
||||
#
|
||||
# Once the plist lands, the daemon picks up the new values without
|
||||
# restart (the ticker calls Config.apply() → applyMDMPolicy() and
|
||||
# restarts the engine on diff).
|
||||
#
|
||||
# DEPLOYMENT (JumpCloud)
|
||||
# 1. Admin Console -> Device Management -> Commands -> +.
|
||||
# 2. Type: Mac, Shell, Run as: root.
|
||||
# 3. Paste this file verbatim into the command body.
|
||||
# 4. Bind to the target system group, save, run.
|
||||
#
|
||||
# IMPORTANT: PERSISTENCE
|
||||
# macOS wipes /Library/Managed Preferences/ at every boot on devices
|
||||
# that are NOT MDM-enrolled. For a persistent fleet rollout, push the
|
||||
# companion docs/netbird-macos.mobileconfig as a Custom Configuration
|
||||
# Profile (Admin Console -> MDM -> Mac Custom Configuration Profiles)
|
||||
# instead of this script. Use this script when:
|
||||
# - the device is MDM-enrolled (file survives reboots), or
|
||||
# - you need a one-shot test push before reboot, or
|
||||
# - you orchestrate via JumpCloud Commands and want the same
|
||||
# variable-driven workflow as the Windows .ps1 sibling.
|
||||
#
|
||||
# IDEMPOTENCY: re-running with the same values is a no-op from the
|
||||
# daemon's point of view (the 1-minute reload ticker diff returns empty).
|
||||
#
|
||||
# SECURITY: PreSharedKey is redacted in this script's log output.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
### POLICY VALUES — EDIT THIS BLOCK ###########################################
|
||||
#
|
||||
# Set each variable below to the desired value. Set to empty string ""
|
||||
# or to NULL to omit a key entirely (the daemon treats an absent key
|
||||
# as "no enforcement" for that field). Booleans use "true"/"false"
|
||||
# (lowercase). Integers as decimal.
|
||||
#
|
||||
# Reference for key names + accepted values:
|
||||
# client/mdm/policy.go (Key* constants)
|
||||
# docs/netbird-macos.mobileconfig (sample profile)
|
||||
# docs/netbird.admx + .adml (Windows ADMX schema)
|
||||
#
|
||||
NULL='__UNSET__'
|
||||
managementURL='https://api.netbird.io:443'
|
||||
preSharedKey="$NULL" # secret; redacted in log
|
||||
allowServerSSH='true'
|
||||
blockInbound="$NULL"
|
||||
disableAutoConnect="$NULL"
|
||||
disableClientRoutes="$NULL"
|
||||
disableServerRoutes="$NULL"
|
||||
disableMetricsCollection="$NULL"
|
||||
disableUpdateSettings="$NULL"
|
||||
disableProfiles="$NULL"
|
||||
disableNetworks="$NULL"
|
||||
rosenpassEnabled="$NULL"
|
||||
rosenpassPermissive="$NULL"
|
||||
wireguardPort='51820'
|
||||
splitTunnelMode="$NULL" # "allow" or "disallow", Android-only at the daemon level
|
||||
splitTunnelApps="$NULL" # comma-separated app IDs, Android-only
|
||||
##############################################################################
|
||||
|
||||
readonly PLIST_DIR='/Library/Managed Preferences'
|
||||
readonly PLIST_PATH="$PLIST_DIR/io.netbird.client.plist"
|
||||
readonly LOG_TAG='netbird-mdm'
|
||||
|
||||
# log sends a message to the system logger using the configured tag and echoes the message to stdout prefixed by an ISO 8601 UTC timestamp and the tag.
|
||||
log() {
|
||||
/usr/bin/logger -t "$LOG_TAG" "$*"
|
||||
printf '%s [%s] %s\n' "$(date -u '+%Y-%m-%dT%H:%M:%SZ')" "$LOG_TAG" "$*"
|
||||
}
|
||||
|
||||
# is_set returns success if the provided value is non-empty and is not equal to the special NULL marker.
|
||||
is_set() {
|
||||
local value="$1"
|
||||
[[ -n "$value" && "$value" != "$NULL" ]]
|
||||
}
|
||||
|
||||
# start_plist creates the temporary plist file at "$PLIST_PATH.tmp" containing the XML plist header and opening `<dict>` for the policy plist.
|
||||
start_plist() {
|
||||
cat > "$PLIST_PATH.tmp" <<'EOF'
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
EOF
|
||||
}
|
||||
|
||||
# end_plist appends the closing `</dict>` and `</plist>` tags to the temporary plist file.
|
||||
end_plist() {
|
||||
cat >> "$PLIST_PATH.tmp" <<'EOF'
|
||||
</dict>
|
||||
</plist>
|
||||
EOF
|
||||
}
|
||||
|
||||
# emit_string appends a plist `<key>`/`<string>` entry for the given key and value to "$PLIST_PATH.tmp", XML-escaping `&`, `<`, and `>`, and logs the assignment (masking the logged value as `********** (secret)` when the key is `preSharedKey`).
|
||||
emit_string() {
|
||||
local key="$1" value="$2" log_value="$2"
|
||||
# Escape XML entities in the value
|
||||
local escaped
|
||||
escaped="$(printf '%s' "$value" | sed -e 's/&/\&/g' -e 's/</\</g' -e 's/>/\>/g')"
|
||||
printf ' <key>%s</key>\n <string>%s</string>\n' "$key" "$escaped" >> "$PLIST_PATH.tmp"
|
||||
if [[ "$key" == "preSharedKey" ]]; then
|
||||
log_value='********** (secret)'
|
||||
fi
|
||||
log "set $key = $log_value"
|
||||
}
|
||||
|
||||
# emit_bool writes a boolean plist entry for a given key into the temporary plist file.
|
||||
# emit_bool writes a boolean plist entry for a key when the provided value matches an accepted boolean token; logs an error and skips the key on invalid input.
|
||||
emit_bool() {
|
||||
local key="$1" value="$2"
|
||||
local xml_bool
|
||||
case "$value" in
|
||||
true|True|TRUE|1|yes) xml_bool='<true/>' ; value='true' ;;
|
||||
false|False|FALSE|0|no) xml_bool='<false/>' ; value='false' ;;
|
||||
*) log "invalid boolean for $key: $value (must be true/false); skipping"; return ;;
|
||||
esac
|
||||
printf ' <key>%s</key>\n %s\n' "$key" "$xml_bool" >> "$PLIST_PATH.tmp"
|
||||
log "set $key = $value"
|
||||
}
|
||||
|
||||
# emit_int validates that VALUE contains only decimal digits and, if valid, appends an `<integer>` plist entry for KEY to the temporary plist (`$PLIST_PATH.tmp`) and logs the assignment; on invalid input it logs a skip and does not emit the key.
|
||||
emit_int() {
|
||||
local key="$1" value="$2"
|
||||
if ! [[ "$value" =~ ^[0-9]+$ ]]; then
|
||||
log "invalid integer for $key: $value (must be decimal); skipping"
|
||||
return
|
||||
fi
|
||||
printf ' <key>%s</key>\n <integer>%s</integer>\n' "$key" "$value" >> "$PLIST_PATH.tmp"
|
||||
log "set $key = $value"
|
||||
}
|
||||
|
||||
# main builds the NetBird MDM plist from configured policy variables, validates and installs it to /Library/Managed Preferences/io.netbird.client.plist (root:wheel, 644) and optionally triggers the NetBird daemon to reload.
|
||||
main() {
|
||||
log "applying NetBird MDM policy to $PLIST_PATH"
|
||||
/bin/mkdir -p "$PLIST_DIR"
|
||||
start_plist
|
||||
|
||||
is_set "$managementURL" && emit_string managementURL "$managementURL"
|
||||
is_set "$preSharedKey" && emit_string preSharedKey "$preSharedKey"
|
||||
is_set "$allowServerSSH" && emit_bool allowServerSSH "$allowServerSSH"
|
||||
is_set "$blockInbound" && emit_bool blockInbound "$blockInbound"
|
||||
is_set "$disableAutoConnect" && emit_bool disableAutoConnect "$disableAutoConnect"
|
||||
is_set "$disableClientRoutes" && emit_bool disableClientRoutes "$disableClientRoutes"
|
||||
is_set "$disableServerRoutes" && emit_bool disableServerRoutes "$disableServerRoutes"
|
||||
is_set "$disableMetricsCollection" && emit_bool disableMetricsCollection "$disableMetricsCollection"
|
||||
is_set "$disableUpdateSettings" && emit_bool disableUpdateSettings "$disableUpdateSettings"
|
||||
is_set "$disableProfiles" && emit_bool disableProfiles "$disableProfiles"
|
||||
is_set "$disableNetworks" && emit_bool disableNetworks "$disableNetworks"
|
||||
is_set "$rosenpassEnabled" && emit_bool rosenpassEnabled "$rosenpassEnabled"
|
||||
is_set "$rosenpassPermissive" && emit_bool rosenpassPermissive "$rosenpassPermissive"
|
||||
is_set "$wireguardPort" && emit_int wireguardPort "$wireguardPort"
|
||||
is_set "$splitTunnelMode" && emit_string splitTunnelMode "$splitTunnelMode"
|
||||
is_set "$splitTunnelApps" && emit_string splitTunnelApps "$splitTunnelApps"
|
||||
|
||||
end_plist
|
||||
|
||||
if ! /usr/bin/plutil -lint "$PLIST_PATH.tmp" >/dev/null 2>&1; then
|
||||
log "ERROR: generated plist failed plutil lint; not installing"
|
||||
/usr/bin/plutil -lint "$PLIST_PATH.tmp" >&2 || true
|
||||
/bin/rm -f "$PLIST_PATH.tmp"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
/bin/mv -f "$PLIST_PATH.tmp" "$PLIST_PATH"
|
||||
/usr/sbin/chown root:wheel "$PLIST_PATH"
|
||||
/bin/chmod 644 "$PLIST_PATH"
|
||||
|
||||
log "policy installed; NetBird daemon will pick it up within the next 1-minute reload tick"
|
||||
|
||||
# Optional: kick the daemon for an immediate apply. Safe — does
|
||||
# nothing on a host where NetBird is not yet installed.
|
||||
/bin/launchctl kickstart -k system/io.netbird.client 2>/dev/null || true
|
||||
}
|
||||
|
||||
main "$@"
|
||||
Binary file not shown.
@@ -1,94 +0,0 @@
|
||||
#requires -Version 5.1
|
||||
<#
|
||||
.SYNOPSIS
|
||||
Push the NetBird MDM policy to a Windows device via JumpCloud Commands
|
||||
by importing a sidecar netbird-policy.reg file.
|
||||
|
||||
.DESCRIPTION
|
||||
Windows counterpart of docs/netbird-macos.sh. Outcome:
|
||||
HKLM\Software\Policies\NetBird populated from the attached
|
||||
netbird-policy.reg file, daemon picks up the change via the
|
||||
1-minute MDM reload ticker.
|
||||
|
||||
Deployment:
|
||||
1. Admin Console -> Device Management -> Commands -> +.
|
||||
2. Type: Windows PowerShell. Run as: SYSTEM.
|
||||
3. Paste this file verbatim into the command body.
|
||||
4. In the same command, attach `netbird-policy.reg` as a file.
|
||||
JumpCloud copies attached files into the command's working
|
||||
directory before invoking the script, so `$PSScriptRoot` or
|
||||
Get-Location resolves to where the .reg lives.
|
||||
5. Bind to the target system group, save, run.
|
||||
|
||||
Producing the .reg file:
|
||||
On a reference machine, after configuring the policy values either
|
||||
via gpedit (GPO) or manual `reg add`, export with:
|
||||
|
||||
reg export "HKLM\Software\Policies\NetBird" netbird-policy.reg /y
|
||||
|
||||
Then attach the resulting file to the JumpCloud command.
|
||||
|
||||
Semantics:
|
||||
- The script nukes the existing HKLM\Software\Policies\NetBird key
|
||||
before importing the .reg, so the .reg is the SINGLE SOURCE OF
|
||||
TRUTH. Any value present in the registry but absent from the .reg
|
||||
is removed. This is what an MDM admin almost always wants.
|
||||
- Setting the .reg to an empty (header-only) file effectively unsets
|
||||
the policy.
|
||||
|
||||
Idempotency: re-running the script with the same .reg is a no-op from
|
||||
the daemon's perspective (values identical → 1-min ticker sees no
|
||||
diff → engine not restarted).
|
||||
|
||||
Exit codes: 0 = success; 1 = .reg missing or reg.exe error.
|
||||
#>
|
||||
|
||||
$ErrorActionPreference = "Stop"
|
||||
|
||||
$RegFileName = "netbird-policy.reg"
|
||||
$RegKey = "HKLM\Software\Policies\NetBird"
|
||||
|
||||
# Resolve the attached .reg file: JumpCloud copies command attachments
|
||||
# into C:\Windows\Temp\ before invoking the script. Cwd / $PSScriptRoot
|
||||
# fallbacks cover the local-dev case where you might dot-source this
|
||||
# from elsewhere.
|
||||
$candidates = @(
|
||||
(Join-Path "$env:WINDIR\Temp" $RegFileName)
|
||||
(Join-Path (Get-Location) $RegFileName)
|
||||
(Join-Path $PSScriptRoot $RegFileName)
|
||||
) | Where-Object { Test-Path $_ }
|
||||
|
||||
if ($candidates.Count -eq 0) {
|
||||
Write-Error "[netbird-mdm] $RegFileName not found in working directory or `$PSScriptRoot. Attach the file to the JumpCloud command."
|
||||
exit 1
|
||||
}
|
||||
$regFile = $candidates[0]
|
||||
Write-Host "[netbird-mdm] using $regFile"
|
||||
|
||||
# Wipe the existing policy key so the .reg is authoritative.
|
||||
$existed = Test-Path "Registry::HKEY_LOCAL_MACHINE\Software\Policies\NetBird"
|
||||
if ($existed) {
|
||||
& reg.exe delete $RegKey /f | Out-Null
|
||||
if ($LASTEXITCODE -ne 0) {
|
||||
Write-Error "[netbird-mdm] failed to clear $RegKey before import (exit $LASTEXITCODE)"
|
||||
exit 1
|
||||
}
|
||||
Write-Host "[netbird-mdm] cleared previous values under $RegKey"
|
||||
}
|
||||
|
||||
# Import. reg.exe writes both data and (re-)creates the key if needed.
|
||||
& reg.exe import $regFile
|
||||
if ($LASTEXITCODE -ne 0) {
|
||||
Write-Error "[netbird-mdm] reg import failed (exit $LASTEXITCODE)"
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Audit dump so the JumpCloud per-execution log captures the applied state.
|
||||
Write-Host "[netbird-mdm] final policy state under $RegKey :"
|
||||
& reg.exe query $RegKey /s
|
||||
|
||||
# Daemon's 1-min reload ticker picks up the change automatically.
|
||||
# Uncomment to force immediate convergence (skips the ticker wait):
|
||||
# Restart-Service netbird -Force -ErrorAction SilentlyContinue
|
||||
|
||||
exit 0
|
||||
@@ -1,95 +0,0 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<policyDefinitionResources xmlns:xsd="http://www.w3.org/2001/XMLSchema"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
revision="1.0"
|
||||
schemaVersion="1.0"
|
||||
xmlns="http://schemas.microsoft.com/GroupPolicy/2006/07/PolicyDefinitions">
|
||||
<displayName>NetBird Client Policies</displayName>
|
||||
<description>Group Policy template for NetBird client MDM-managed settings. Values are written under HKLM\Software\Policies\NetBird and consumed by the netbird daemon at startup and every 1-minute reload tick.</description>
|
||||
<resources>
|
||||
<stringTable>
|
||||
|
||||
<!-- Categories -->
|
||||
<string id="NetBird_Category">NetBird</string>
|
||||
<string id="SUPPORTED_NetBird_All">NetBird Client 0.40+</string>
|
||||
|
||||
<!-- Identity / auth -->
|
||||
<string id="ManagementURL_Name">Management URL</string>
|
||||
<string id="ManagementURL_Help">URL of the NetBird management server. Format: https://host[:port]. When set, users cannot override this value via UI or CLI.</string>
|
||||
|
||||
<string id="PreSharedKey_Name">Pre-shared key</string>
|
||||
<string id="PreSharedKey_Help">WireGuard pre-shared key used as an additional symmetric secret on every peer-to-peer tunnel. Secret value.</string>
|
||||
|
||||
<!-- Settings: engine / runtime behavior -->
|
||||
<string id="DisableAutoConnect_Name">Disable auto-connect</string>
|
||||
<string id="DisableAutoConnect_Help">When enabled, the NetBird tunnel does not auto-connect at daemon startup. Equivalent to --disable-auto-connect.</string>
|
||||
|
||||
<string id="DisableClientRoutes_Name">Disable client routes</string>
|
||||
<string id="DisableClientRoutes_Help">When enabled, this client will not consume routes advertised by routing peers. Equivalent to --disable-client-routes.</string>
|
||||
|
||||
<string id="DisableServerRoutes_Name">Disable server routes</string>
|
||||
<string id="DisableServerRoutes_Help">When enabled, this client will not act as a routing peer for other clients. Equivalent to --disable-server-routes.</string>
|
||||
|
||||
<string id="BlockInbound_Name">Block inbound</string>
|
||||
<string id="BlockInbound_Help">When enabled, the client firewall blocks all inbound peer traffic on the WireGuard interface. Equivalent to --block-inbound.</string>
|
||||
|
||||
<string id="AllowServerSSH_Name">Allow server SSH</string>
|
||||
<string id="AllowServerSSH_Help">When enabled, this client accepts incoming SSH sessions via NetBird SSH. Equivalent to --allow-server-ssh.</string>
|
||||
|
||||
<string id="RosenpassEnabled_Name">Enable Rosenpass</string>
|
||||
<string id="RosenpassEnabled_Help">Enables Rosenpass post-quantum key exchange on WireGuard tunnels. Both peers must support it.</string>
|
||||
|
||||
<string id="RosenpassPermissive_Name">Rosenpass permissive</string>
|
||||
<string id="RosenpassPermissive_Help">When enabled, the client falls back to plain WireGuard if a peer does not support Rosenpass; otherwise it refuses the connection.</string>
|
||||
|
||||
<string id="WireguardPort_Name">WireGuard port</string>
|
||||
<string id="WireguardPort_Help">UDP port used by the local WireGuard interface. Allowed range: 1-65535.</string>
|
||||
|
||||
<string id="SplitTunnel_Name">Split tunnel</string>
|
||||
<string id="SplitTunnel_Help">Restrict the NetBird tunnel to or from a chosen list of application package names. Choose either the allow mode (only the listed apps route through NetBird) or the disallow mode (the listed apps bypass NetBird; everything else routes through). The mode is mutually exclusive — only one can be active at a time. Android-only at the daemon level; Windows/macOS/iOS clients ignore this policy.</string>
|
||||
<string id="SplitTunnel_Allow">Allow only listed apps (everything else bypasses)</string>
|
||||
<string id="SplitTunnel_Disallow">Disallow listed apps (everything else routes)</string>
|
||||
|
||||
<!-- UI -->
|
||||
<string id="DisableUpdateSettings_Name">Disable update settings</string>
|
||||
<string id="DisableUpdateSettings_Help">When enabled, blocks every configuration change from the client UI and from the CLI (netbird up / login / setconfig). The Settings view stays viewable but read-only. Equivalent to --disable-update-settings.</string>
|
||||
|
||||
<string id="DisableProfiles_Name">Disable profiles</string>
|
||||
<string id="DisableProfiles_Help">When enabled, the client UI/CLI cannot list, create, switch or remove NetBird connection profiles. Equivalent to --disable-profiles.</string>
|
||||
|
||||
<string id="DisableNetworks_Name">Disable networks</string>
|
||||
<string id="DisableNetworks_Help">When enabled, the client UI/CLI cannot list, select or deselect NetBird networks (the corresponding daemon RPCs return Unavailable). Equivalent to --disable-networks.</string>
|
||||
|
||||
<string id="DisableMetricsCollection_Name">Disable metrics collection</string>
|
||||
<string id="DisableMetricsCollection_Help">When enabled, the client does not collect or report local usage metrics.</string>
|
||||
|
||||
</stringTable>
|
||||
<presentationTable>
|
||||
|
||||
<presentation id="ManagementURL_Pres">
|
||||
<textBox refId="ManagementURL_Text">
|
||||
<label>Management URL:</label>
|
||||
<defaultValue>https://api.netbird.io:443</defaultValue>
|
||||
</textBox>
|
||||
</presentation>
|
||||
|
||||
<presentation id="PreSharedKey_Pres">
|
||||
<textBox refId="PreSharedKey_Text">
|
||||
<label>Pre-shared key:</label>
|
||||
</textBox>
|
||||
</presentation>
|
||||
|
||||
<presentation id="WireguardPort_Pres">
|
||||
<decimalTextBox refId="WireguardPort_Decimal" defaultValue="51820">WireGuard UDP port:</decimalTextBox>
|
||||
</presentation>
|
||||
|
||||
<presentation id="SplitTunnel_Pres">
|
||||
<dropdownList refId="SplitTunnel_Mode" defaultItem="0">Mode:</dropdownList>
|
||||
<textBox refId="SplitTunnel_Apps">
|
||||
<label>Package names (comma-separated):</label>
|
||||
</textBox>
|
||||
</presentation>
|
||||
|
||||
</presentationTable>
|
||||
</resources>
|
||||
</policyDefinitionResources>
|
||||
@@ -1,223 +0,0 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<policyDefinitions xmlns:xsd="http://www.w3.org/2001/XMLSchema"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
revision="1.0"
|
||||
schemaVersion="1.0"
|
||||
xmlns="http://schemas.microsoft.com/GroupPolicy/2006/07/PolicyDefinitions">
|
||||
<policyNamespaces>
|
||||
<target prefix="netbird" namespace="NetBird.Policies.Client" />
|
||||
</policyNamespaces>
|
||||
<resources minRequiredRevision="1.0" />
|
||||
<supportedOn>
|
||||
<definitions>
|
||||
<definition name="SUPPORTED_NetBird_All" displayName="$(string.SUPPORTED_NetBird_All)" />
|
||||
</definitions>
|
||||
</supportedOn>
|
||||
<categories>
|
||||
<category name="NetBird" displayName="$(string.NetBird_Category)" />
|
||||
</categories>
|
||||
<policies>
|
||||
|
||||
<!-- ============================================================ -->
|
||||
<!-- TOP-LEVEL: foundational identity / authentication -->
|
||||
<!-- ============================================================ -->
|
||||
|
||||
<policy name="ManagementURL"
|
||||
class="Machine"
|
||||
displayName="$(string.ManagementURL_Name)"
|
||||
explainText="$(string.ManagementURL_Help)"
|
||||
key="Software\Policies\NetBird"
|
||||
presentation="$(presentation.ManagementURL_Pres)">
|
||||
<parentCategory ref="NetBird" />
|
||||
<supportedOn ref="SUPPORTED_NetBird_All" />
|
||||
<elements>
|
||||
<text id="ManagementURL_Text" valueName="ManagementURL" required="true" />
|
||||
</elements>
|
||||
</policy>
|
||||
|
||||
<policy name="PreSharedKey"
|
||||
class="Machine"
|
||||
displayName="$(string.PreSharedKey_Name)"
|
||||
explainText="$(string.PreSharedKey_Help)"
|
||||
key="Software\Policies\NetBird"
|
||||
presentation="$(presentation.PreSharedKey_Pres)">
|
||||
<parentCategory ref="NetBird" />
|
||||
<supportedOn ref="SUPPORTED_NetBird_All" />
|
||||
<elements>
|
||||
<text id="PreSharedKey_Text" valueName="PreSharedKey" />
|
||||
</elements>
|
||||
</policy>
|
||||
|
||||
<!-- ============================================================ -->
|
||||
<!-- SETTINGS: engine / runtime / connection behavior -->
|
||||
<!-- ============================================================ -->
|
||||
|
||||
<policy name="DisableAutoConnect"
|
||||
class="Machine"
|
||||
displayName="$(string.DisableAutoConnect_Name)"
|
||||
explainText="$(string.DisableAutoConnect_Help)"
|
||||
key="Software\Policies\NetBird"
|
||||
valueName="DisableAutoConnect">
|
||||
<parentCategory ref="NetBird" />
|
||||
<supportedOn ref="SUPPORTED_NetBird_All" />
|
||||
<enabledValue><decimal value="1" /></enabledValue>
|
||||
<disabledValue><decimal value="0" /></disabledValue>
|
||||
</policy>
|
||||
|
||||
<policy name="DisableClientRoutes"
|
||||
class="Machine"
|
||||
displayName="$(string.DisableClientRoutes_Name)"
|
||||
explainText="$(string.DisableClientRoutes_Help)"
|
||||
key="Software\Policies\NetBird"
|
||||
valueName="DisableClientRoutes">
|
||||
<parentCategory ref="NetBird" />
|
||||
<supportedOn ref="SUPPORTED_NetBird_All" />
|
||||
<enabledValue><decimal value="1" /></enabledValue>
|
||||
<disabledValue><decimal value="0" /></disabledValue>
|
||||
</policy>
|
||||
|
||||
<policy name="DisableServerRoutes"
|
||||
class="Machine"
|
||||
displayName="$(string.DisableServerRoutes_Name)"
|
||||
explainText="$(string.DisableServerRoutes_Help)"
|
||||
key="Software\Policies\NetBird"
|
||||
valueName="DisableServerRoutes">
|
||||
<parentCategory ref="NetBird" />
|
||||
<supportedOn ref="SUPPORTED_NetBird_All" />
|
||||
<enabledValue><decimal value="1" /></enabledValue>
|
||||
<disabledValue><decimal value="0" /></disabledValue>
|
||||
</policy>
|
||||
|
||||
<policy name="BlockInbound"
|
||||
class="Machine"
|
||||
displayName="$(string.BlockInbound_Name)"
|
||||
explainText="$(string.BlockInbound_Help)"
|
||||
key="Software\Policies\NetBird"
|
||||
valueName="BlockInbound">
|
||||
<parentCategory ref="NetBird" />
|
||||
<supportedOn ref="SUPPORTED_NetBird_All" />
|
||||
<enabledValue><decimal value="1" /></enabledValue>
|
||||
<disabledValue><decimal value="0" /></disabledValue>
|
||||
</policy>
|
||||
|
||||
<policy name="AllowServerSSH"
|
||||
class="Machine"
|
||||
displayName="$(string.AllowServerSSH_Name)"
|
||||
explainText="$(string.AllowServerSSH_Help)"
|
||||
key="Software\Policies\NetBird"
|
||||
valueName="AllowServerSSH">
|
||||
<parentCategory ref="NetBird" />
|
||||
<supportedOn ref="SUPPORTED_NetBird_All" />
|
||||
<enabledValue><decimal value="1" /></enabledValue>
|
||||
<disabledValue><decimal value="0" /></disabledValue>
|
||||
</policy>
|
||||
|
||||
<policy name="RosenpassEnabled"
|
||||
class="Machine"
|
||||
displayName="$(string.RosenpassEnabled_Name)"
|
||||
explainText="$(string.RosenpassEnabled_Help)"
|
||||
key="Software\Policies\NetBird"
|
||||
valueName="RosenpassEnabled">
|
||||
<parentCategory ref="NetBird" />
|
||||
<supportedOn ref="SUPPORTED_NetBird_All" />
|
||||
<enabledValue><decimal value="1" /></enabledValue>
|
||||
<disabledValue><decimal value="0" /></disabledValue>
|
||||
</policy>
|
||||
|
||||
<policy name="RosenpassPermissive"
|
||||
class="Machine"
|
||||
displayName="$(string.RosenpassPermissive_Name)"
|
||||
explainText="$(string.RosenpassPermissive_Help)"
|
||||
key="Software\Policies\NetBird"
|
||||
valueName="RosenpassPermissive">
|
||||
<parentCategory ref="NetBird" />
|
||||
<supportedOn ref="SUPPORTED_NetBird_All" />
|
||||
<enabledValue><decimal value="1" /></enabledValue>
|
||||
<disabledValue><decimal value="0" /></disabledValue>
|
||||
</policy>
|
||||
|
||||
<policy name="WireguardPort"
|
||||
class="Machine"
|
||||
displayName="$(string.WireguardPort_Name)"
|
||||
explainText="$(string.WireguardPort_Help)"
|
||||
key="Software\Policies\NetBird"
|
||||
presentation="$(presentation.WireguardPort_Pres)">
|
||||
<parentCategory ref="NetBird" />
|
||||
<supportedOn ref="SUPPORTED_NetBird_All" />
|
||||
<elements>
|
||||
<decimal id="WireguardPort_Decimal" valueName="WireguardPort"
|
||||
minValue="1" maxValue="65535" required="true" />
|
||||
</elements>
|
||||
</policy>
|
||||
|
||||
<policy name="SplitTunnel"
|
||||
class="Machine"
|
||||
displayName="$(string.SplitTunnel_Name)"
|
||||
explainText="$(string.SplitTunnel_Help)"
|
||||
key="Software\Policies\NetBird"
|
||||
presentation="$(presentation.SplitTunnel_Pres)">
|
||||
<parentCategory ref="NetBird" />
|
||||
<supportedOn ref="SUPPORTED_NetBird_All" />
|
||||
<elements>
|
||||
<enum id="SplitTunnel_Mode" valueName="SplitTunnelMode" required="true">
|
||||
<item displayName="$(string.SplitTunnel_Allow)"><value><string>allow</string></value></item>
|
||||
<item displayName="$(string.SplitTunnel_Disallow)"><value><string>disallow</string></value></item>
|
||||
</enum>
|
||||
<text id="SplitTunnel_Apps" valueName="SplitTunnelApps" required="true" />
|
||||
</elements>
|
||||
</policy>
|
||||
|
||||
<!-- ============================================================ -->
|
||||
<!-- UI: visibility / UX kill switches -->
|
||||
<!-- ============================================================ -->
|
||||
|
||||
<policy name="DisableUpdateSettings"
|
||||
class="Machine"
|
||||
displayName="$(string.DisableUpdateSettings_Name)"
|
||||
explainText="$(string.DisableUpdateSettings_Help)"
|
||||
key="Software\Policies\NetBird"
|
||||
valueName="DisableUpdateSettings">
|
||||
<parentCategory ref="NetBird" />
|
||||
<supportedOn ref="SUPPORTED_NetBird_All" />
|
||||
<enabledValue><decimal value="1" /></enabledValue>
|
||||
<disabledValue><decimal value="0" /></disabledValue>
|
||||
</policy>
|
||||
|
||||
<policy name="DisableProfiles"
|
||||
class="Machine"
|
||||
displayName="$(string.DisableProfiles_Name)"
|
||||
explainText="$(string.DisableProfiles_Help)"
|
||||
key="Software\Policies\NetBird"
|
||||
valueName="DisableProfiles">
|
||||
<parentCategory ref="NetBird" />
|
||||
<supportedOn ref="SUPPORTED_NetBird_All" />
|
||||
<enabledValue><decimal value="1" /></enabledValue>
|
||||
<disabledValue><decimal value="0" /></disabledValue>
|
||||
</policy>
|
||||
|
||||
<policy name="DisableNetworks"
|
||||
class="Machine"
|
||||
displayName="$(string.DisableNetworks_Name)"
|
||||
explainText="$(string.DisableNetworks_Help)"
|
||||
key="Software\Policies\NetBird"
|
||||
valueName="DisableNetworks">
|
||||
<parentCategory ref="NetBird" />
|
||||
<supportedOn ref="SUPPORTED_NetBird_All" />
|
||||
<enabledValue><decimal value="1" /></enabledValue>
|
||||
<disabledValue><decimal value="0" /></disabledValue>
|
||||
</policy>
|
||||
|
||||
<policy name="DisableMetricsCollection"
|
||||
class="Machine"
|
||||
displayName="$(string.DisableMetricsCollection_Name)"
|
||||
explainText="$(string.DisableMetricsCollection_Help)"
|
||||
key="Software\Policies\NetBird"
|
||||
valueName="DisableMetricsCollection">
|
||||
<parentCategory ref="NetBird" />
|
||||
<supportedOn ref="SUPPORTED_NetBird_All" />
|
||||
<enabledValue><decimal value="1" /></enabledValue>
|
||||
<disabledValue><decimal value="0" /></disabledValue>
|
||||
</policy>
|
||||
|
||||
</policies>
|
||||
</policyDefinitions>
|
||||
@@ -99,9 +99,6 @@ func addFields(entry *logrus.Entry) {
|
||||
if ctxAccountID, ok := entry.Context.Value(context.AccountIDKey).(string); ok {
|
||||
entry.Data[context.AccountIDKey] = ctxAccountID
|
||||
}
|
||||
if ctxUserAgent, ok := entry.Context.Value(context.UserAgentKey).(string); ok {
|
||||
entry.Data[context.UserAgentKey] = ctxUserAgent
|
||||
}
|
||||
if ctxInitiatorID, ok := entry.Context.Value(context.UserIDKey).(string); ok {
|
||||
entry.Data[context.UserIDKey] = ctxInitiatorID
|
||||
}
|
||||
|
||||
7
go.mod
7
go.mod
@@ -2,8 +2,6 @@ module github.com/netbirdio/netbird
|
||||
|
||||
go 1.25.5
|
||||
|
||||
toolchain go1.25.11
|
||||
|
||||
require (
|
||||
cunicu.li/go-rosenpass v0.5.42
|
||||
github.com/cenkalti/backoff/v4 v4.3.0
|
||||
@@ -56,7 +54,6 @@ require (
|
||||
github.com/fsnotify/fsnotify v1.9.0
|
||||
github.com/gliderlabs/ssh v0.3.8
|
||||
github.com/go-jose/go-jose/v4 v4.1.4
|
||||
github.com/goccy/go-yaml v1.18.0
|
||||
github.com/godbus/dbus/v5 v5.1.0
|
||||
github.com/golang-jwt/jwt/v5 v5.3.1
|
||||
github.com/golang/mock v1.6.0
|
||||
@@ -134,7 +131,6 @@ require (
|
||||
gorm.io/driver/sqlite v1.5.7
|
||||
gorm.io/gorm v1.25.12
|
||||
gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89
|
||||
howett.net/plist v1.0.1
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -215,9 +211,10 @@ require (
|
||||
github.com/go-viper/mapstructure/v2 v2.5.0 // indirect
|
||||
github.com/go-webauthn/webauthn v0.16.4 // indirect
|
||||
github.com/go-webauthn/x v0.2.3 // indirect
|
||||
github.com/goccy/go-yaml v1.18.0 // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/golang-jwt/jwt/v4 v4.5.2 // indirect
|
||||
github.com/google/btree v1.1.3 // indirect
|
||||
github.com/google/btree v1.1.2 // indirect
|
||||
github.com/google/go-querystring v1.1.0 // indirect
|
||||
github.com/google/go-tpm v0.9.8 // indirect
|
||||
github.com/google/s2a-go v0.1.9 // indirect
|
||||
|
||||
8
go.sum
8
go.sum
@@ -275,8 +275,8 @@ github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiu
|
||||
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
|
||||
github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
|
||||
github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU=
|
||||
github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
|
||||
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
|
||||
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
@@ -380,7 +380,6 @@ github.com/jcmturner/rpc/v2 v2.0.3 h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZ
|
||||
github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc=
|
||||
github.com/jeandeaual/go-locale v0.0.0-20250612000132-0ef82f21eade h1:FmusiCI1wHw+XQbvL9M+1r/C3SPqKrmBaIOYwVfQoDE=
|
||||
github.com/jeandeaual/go-locale v0.0.0-20250612000132-0ef82f21eade/go.mod h1:ZDXo8KHryOWSIqnsb/CiDq7hQUYryCgdVnxbj8tDG7o=
|
||||
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
@@ -947,7 +946,6 @@ gopkg.in/square/go-jose.v2 v2.6.0 h1:NGk74WTnPKBNUhNzQX7PYcTLUjoq7mzKk2OKbvwk2iI
|
||||
gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
|
||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
|
||||
gopkg.in/yaml.v1 v1.0.0-20140924161607-9f9df34309c0/go.mod h1:WDnlLJ4WF5VGsH/HVa3CI79GS0ol3YnhVnKP89i0kNg=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
@@ -970,7 +968,5 @@ gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU=
|
||||
gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
|
||||
gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89 h1:mGJaeA61P8dEHTqdvAgc70ZIV3QoUoJcXCRyyjO26OA=
|
||||
gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89/go.mod h1:QkHjoMIBaYtpVufgwv3keYAbln78mBoCuShZrPrer1Q=
|
||||
howett.net/plist v1.0.1 h1:37GdZ8tP09Q35o9ych3ehygcsL+HqKSwzctveSlarvM=
|
||||
howett.net/plist v1.0.1/go.mod h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g=
|
||||
rsc.io/qr v0.2.0 h1:6vBLea5/NRMVTz8V66gipeLycZMl/+UlFmk8DvqQ6WY=
|
||||
rsc.io/qr v0.2.0/go.mod h1:IF+uZjkb9fqyeF/4tlBoynqmQxUoPfWEKh921coOuXs=
|
||||
|
||||
@@ -19,46 +19,6 @@ readonly MSG_SEPARATOR="=========================================="
|
||||
# Utility Functions
|
||||
############################################
|
||||
|
||||
check_docker_sock_perms() {
|
||||
local sock="${DOCKER_HOST:-unix:///var/run/docker.sock}"
|
||||
sock="${sock#unix://}"
|
||||
|
||||
if [[ ! -S "$sock" ]]; then
|
||||
return 0
|
||||
fi
|
||||
|
||||
if [[ ! -r "$sock" ]] || [[ ! -w "$sock" ]]; then
|
||||
local group
|
||||
if [[ "${OSTYPE}" == "darwin"* ]]; then
|
||||
group="$(stat -f '%Sg' "$sock")"
|
||||
else
|
||||
group="$(stat -c '%G' "$sock")"
|
||||
fi
|
||||
|
||||
echo "Cannot access Docker socket: $sock" > /dev/stderr
|
||||
echo "" > /dev/stderr
|
||||
echo "Socket permissions:" > /dev/stderr
|
||||
ls -l "$sock" > /dev/stderr
|
||||
echo "" > /dev/stderr
|
||||
|
||||
if [[ "$group" == "docker" ]]; then
|
||||
echo "Your user may need to be added to the '$group' group:" > /dev/stderr
|
||||
echo " sudo usermod -aG $group \"$USER\"" > /dev/stderr
|
||||
echo "Then log out and back in, or run this for the current shell:" > /dev/stderr
|
||||
echo " newgrp $group" > /dev/stderr
|
||||
echo "Note: newgrp is temporary; usermod is the permanent group change." > /dev/stderr
|
||||
else
|
||||
echo "The Docker socket is owned by the '$group' group, which is not the standard 'docker' group." > /dev/stderr
|
||||
echo "For safety, this script will not suggest adding your user to '$group'." > /dev/stderr
|
||||
echo "Instead, either run this script with appropriate privileges (for example, via sudo) or follow Docker's post-install steps to configure access via the 'docker' group:" > /dev/stderr
|
||||
echo " https://docs.docker.com/engine/install/linux-postinstall/" > /dev/stderr
|
||||
fi
|
||||
|
||||
exit 1
|
||||
fi
|
||||
return 0
|
||||
}
|
||||
|
||||
check_docker_compose() {
|
||||
if command -v docker-compose &> /dev/null
|
||||
then
|
||||
@@ -621,15 +581,12 @@ start_services_and_show_instructions() {
|
||||
}
|
||||
|
||||
init_environment() {
|
||||
# Check if docker compose is installed using check_docker_compose function
|
||||
DOCKER_COMPOSE_COMMAND=$(check_docker_compose)
|
||||
check_docker_sock_perms
|
||||
|
||||
initialize_default_values
|
||||
configure_domain
|
||||
configure_reverse_proxy
|
||||
|
||||
check_jq
|
||||
DOCKER_COMPOSE_COMMAND=$(check_docker_compose)
|
||||
|
||||
check_existing_installation
|
||||
generate_configuration_files
|
||||
|
||||
@@ -488,195 +488,6 @@ func TestUpdate_AllowsPortChange(t *testing.T) {
|
||||
assert.Equal(t, uint16(54321), updated.ListenPort, "explicit port change should be applied")
|
||||
}
|
||||
|
||||
func TestUpdate_PreservesPortWhenCustomPortsNotSupported(t *testing.T) {
|
||||
mgr, testStore, _ := setupL4Test(t, boolPtr(false))
|
||||
ctx := context.Background()
|
||||
|
||||
existing := seedService(t, testStore, "tcp-svc", "tcp", testCluster, testCluster, 12345)
|
||||
|
||||
updated := &rpservice.Service{
|
||||
ID: existing.ID,
|
||||
AccountID: testAccountID,
|
||||
Name: "tcp-svc-renamed",
|
||||
Mode: "tcp",
|
||||
Domain: testCluster,
|
||||
ProxyCluster: testCluster,
|
||||
ListenPort: 0,
|
||||
Enabled: true,
|
||||
Targets: []*rpservice.Target{
|
||||
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
|
||||
require.NoError(t, err, "update must not be rejected by the custom-port capability check")
|
||||
assert.Equal(t, uint16(12345), updated.ListenPort, "existing listen port should be preserved on unsupported cluster")
|
||||
}
|
||||
|
||||
func TestUpdate_PreservesPortWhenCustomPortsUnknown(t *testing.T) {
|
||||
mgr, testStore, _ := setupL4Test(t, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
existing := seedService(t, testStore, "tcp-svc", "tcp", testCluster, testCluster, 12345)
|
||||
|
||||
updated := &rpservice.Service{
|
||||
ID: existing.ID,
|
||||
AccountID: testAccountID,
|
||||
Name: "tcp-svc-renamed",
|
||||
Mode: "tcp",
|
||||
Domain: testCluster,
|
||||
ProxyCluster: testCluster,
|
||||
ListenPort: 0,
|
||||
Enabled: true,
|
||||
Targets: []*rpservice.Target{
|
||||
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
|
||||
require.NoError(t, err, "update must not be rejected when cluster capability is unknown")
|
||||
assert.Equal(t, uint16(12345), updated.ListenPort, "existing listen port should be preserved when capability is unknown")
|
||||
}
|
||||
|
||||
func TestUpdate_RejectsPortChangeWhenCustomPortsNotSupported(t *testing.T) {
|
||||
mgr, testStore, _ := setupL4Test(t, boolPtr(false))
|
||||
ctx := context.Background()
|
||||
|
||||
existing := seedService(t, testStore, "tcp-svc", "tcp", testCluster, testCluster, 12345)
|
||||
|
||||
updated := &rpservice.Service{
|
||||
ID: existing.ID,
|
||||
AccountID: testAccountID,
|
||||
Name: "tcp-svc",
|
||||
Mode: "tcp",
|
||||
Domain: testCluster,
|
||||
ProxyCluster: testCluster,
|
||||
ListenPort: 54321,
|
||||
Enabled: true,
|
||||
Targets: []*rpservice.Target{
|
||||
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
|
||||
require.Error(t, err, "explicit port change on update must be rejected on unsupported clusters")
|
||||
assert.Contains(t, err.Error(), "custom ports not supported on target cluster")
|
||||
}
|
||||
|
||||
func TestUpdate_TLSPortChangeAllowedWhenNotSupported(t *testing.T) {
|
||||
mgr, testStore, _ := setupL4Test(t, boolPtr(false))
|
||||
ctx := context.Background()
|
||||
|
||||
existing := seedService(t, testStore, "tls-svc", "tls", "app.example.com", testCluster, 443)
|
||||
|
||||
updated := &rpservice.Service{
|
||||
ID: existing.ID,
|
||||
AccountID: testAccountID,
|
||||
Name: "tls-svc",
|
||||
Mode: "tls",
|
||||
Domain: "app.example.com",
|
||||
ProxyCluster: testCluster,
|
||||
ListenPort: 9999,
|
||||
Enabled: true,
|
||||
Targets: []*rpservice.Target{
|
||||
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8443, Enabled: true},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
|
||||
require.NoError(t, err, "TLS port change uses SNI routing and is exempt from the custom-port check")
|
||||
assert.Equal(t, uint16(9999), updated.ListenPort, "TLS port change should be applied")
|
||||
}
|
||||
|
||||
func TestValidateL4PortDiffOnClusterDiff(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mode string
|
||||
customPorts *bool
|
||||
newPort uint16
|
||||
oldPort uint16
|
||||
wantErr bool
|
||||
}{
|
||||
{"tcp port change unsupported", "tcp", boolPtr(false), 54321, 12345, true},
|
||||
{"tcp port change unknown capability", "tcp", nil, 54321, 12345, true},
|
||||
{"udp port change unsupported", "udp", boolPtr(false), 54321, 12345, true},
|
||||
{"tcp first port assignment unsupported", "tcp", boolPtr(false), 54321, 0, true},
|
||||
{"tcp port change supported", "tcp", boolPtr(true), 54321, 12345, false},
|
||||
{"tcp port unchanged unsupported", "tcp", boolPtr(false), 12345, 12345, false},
|
||||
{"tcp zero port unsupported", "tcp", boolPtr(false), 0, 12345, false},
|
||||
{"tls port change unsupported", "tls", boolPtr(false), 9999, 443, false},
|
||||
{"http mode ignored", "http", boolPtr(false), 54321, 12345, false},
|
||||
{"empty mode ignored", "", boolPtr(false), 54321, 12345, false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
newSvc := &rpservice.Service{Mode: tc.mode, ListenPort: tc.newPort, ProxyCluster: testCluster}
|
||||
oldSvc := &rpservice.Service{Mode: tc.mode, ListenPort: tc.oldPort, ProxyCluster: testCluster}
|
||||
|
||||
err := validateL4PortDiffOnClusterDiff(tc.customPorts, newSvc, oldSvc)
|
||||
if tc.wantErr {
|
||||
assert.Error(t, err, "port diff should be rejected for %s", tc.name)
|
||||
} else {
|
||||
assert.NoError(t, err, "port diff should be allowed for %s", tc.name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdate_PortConflictRejected(t *testing.T) {
|
||||
mgr, testStore, _ := setupL4Test(t, boolPtr(true))
|
||||
ctx := context.Background()
|
||||
|
||||
seedService(t, testStore, "tcp-a", "tcp", "tcp-a."+testCluster, testCluster, 5432)
|
||||
svcB := seedService(t, testStore, "tcp-b", "tcp", "tcp-b."+testCluster, testCluster, 6543)
|
||||
|
||||
updated := &rpservice.Service{
|
||||
ID: svcB.ID,
|
||||
AccountID: testAccountID,
|
||||
Name: "tcp-b",
|
||||
Mode: "tcp",
|
||||
Domain: "tcp-b." + testCluster,
|
||||
ProxyCluster: testCluster,
|
||||
ListenPort: 5432,
|
||||
Enabled: true,
|
||||
Targets: []*rpservice.Target{
|
||||
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
|
||||
require.Error(t, err, "updating to a port held by another service should be rejected")
|
||||
assert.Contains(t, err.Error(), "already in use")
|
||||
}
|
||||
|
||||
func TestUpdate_AutoAssignsWhenNoPort(t *testing.T) {
|
||||
mgr, testStore, _ := setupL4Test(t, boolPtr(false))
|
||||
ctx := context.Background()
|
||||
|
||||
existing := seedService(t, testStore, "tcp-svc", "tcp", testCluster, testCluster, 0)
|
||||
|
||||
updated := &rpservice.Service{
|
||||
ID: existing.ID,
|
||||
AccountID: testAccountID,
|
||||
Name: "tcp-svc",
|
||||
Mode: "tcp",
|
||||
Domain: testCluster,
|
||||
ProxyCluster: testCluster,
|
||||
ListenPort: 0,
|
||||
Enabled: true,
|
||||
Targets: []*rpservice.Target{
|
||||
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, updated.ListenPort >= autoAssignPortMin && updated.ListenPort <= autoAssignPortMax,
|
||||
"auto-assigned port %d should be in range [%d, %d]", updated.ListenPort, autoAssignPortMin, autoAssignPortMax)
|
||||
assert.True(t, updated.PortAutoAssigned, "PortAutoAssigned should be set when update triggers auto-assignment")
|
||||
}
|
||||
|
||||
func TestCreateServiceFromPeer_TCP(t *testing.T) {
|
||||
mgr, _, _ := setupL4Test(t, boolPtr(false))
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -338,7 +338,7 @@ func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.ensureL4Port(ctx, transaction, svc, customPorts, false); err != nil {
|
||||
if err := m.ensureL4Port(ctx, transaction, svc, customPorts); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -367,11 +367,11 @@ func (m *Manager) clusterCustomPorts(ctx context.Context, svc *service.Service)
|
||||
|
||||
// ensureL4Port auto-assigns a listen port when needed and validates cluster support.
|
||||
// customPorts must be pre-computed via clusterCustomPorts before entering a transaction.
|
||||
func (m *Manager) ensureL4Port(ctx context.Context, tx store.Store, svc *service.Service, customPorts *bool, serviceUpdate bool) error {
|
||||
func (m *Manager) ensureL4Port(ctx context.Context, tx store.Store, svc *service.Service, customPorts *bool) error {
|
||||
if !service.IsL4Protocol(svc.Mode) {
|
||||
return nil
|
||||
}
|
||||
if service.IsPortBasedProtocol(svc.Mode) && svc.ListenPort > 0 && !serviceUpdate && (customPorts == nil || !*customPorts) {
|
||||
if service.IsPortBasedProtocol(svc.Mode) && svc.ListenPort > 0 && (customPorts == nil || !*customPorts) {
|
||||
if svc.Source != service.SourceEphemeral {
|
||||
return status.Errorf(status.InvalidArgument, "custom ports not supported on cluster %s", svc.ProxyCluster)
|
||||
}
|
||||
@@ -465,7 +465,7 @@ func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, pee
|
||||
return err
|
||||
}
|
||||
|
||||
if err := m.ensureL4Port(ctx, transaction, svc, customPorts, false); err != nil {
|
||||
if err := m.ensureL4Port(ctx, transaction, svc, customPorts); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -651,22 +651,12 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
|
||||
m.preserveListenPort(service, existingService)
|
||||
updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled
|
||||
|
||||
// if the service is being updated, and we decide in the future to allow mode update,
|
||||
// we should reconsider the currently assigned port if not 0 for clusters that don't support custom ports
|
||||
if err := validateL4PortDiffOnClusterDiff(customPorts, service, existingService); err != nil {
|
||||
if err := m.ensureL4Port(ctx, transaction, service, customPorts); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := m.ensureL4Port(ctx, transaction, service, customPorts, true); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// we can try carrying the previous service port into a new cluster, if this becomes a problem for multiple users,
|
||||
// we should reconsider adding another check
|
||||
if err := m.checkPortConflict(ctx, transaction, service); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := transaction.UpdateService(ctx, service); err != nil {
|
||||
return fmt.Errorf("update service: %w", err)
|
||||
}
|
||||
@@ -674,21 +664,6 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateL4PortDiffOnClusterDiff checks if custom L4 ports are configured and validates port changes across clusters.
|
||||
// It ensures no port changes if custom ports are unsupported for a given cluster and protocol mode.
|
||||
// Returns an error if validation fails, otherwise returns nil.
|
||||
func validateL4PortDiffOnClusterDiff(customPorts *bool, newSVC, oldSVC *service.Service) error {
|
||||
if !service.IsPortBasedProtocol(newSVC.Mode) || (customPorts != nil && *customPorts) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if newSVC.ListenPort != 0 && newSVC.ListenPort != oldSVC.ListenPort {
|
||||
return status.Errorf(status.InvalidArgument, "custom ports not supported on target cluster %s", newSVC.ProxyCluster)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleDomainChange validates the new domain is free inside the transaction
|
||||
// and applies the pre-resolved cluster (computed outside the tx by
|
||||
// resolveEffectiveCluster). It must NOT call clusterDeriver here: that talks
|
||||
|
||||
@@ -8,8 +8,6 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-version"
|
||||
nbversion "github.com/netbirdio/netbird/version"
|
||||
log "github.com/sirupsen/logrus"
|
||||
goproto "google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
@@ -30,23 +28,6 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/sshauth"
|
||||
)
|
||||
|
||||
const (
|
||||
// deprecatedRemotePeersVersion is the version of Netbird that introduced the NetworkMap.RemotePeers field, deprecated in favor of RemotePeers.
|
||||
deprecatedRemotePeersVersion = "0.29.3"
|
||||
)
|
||||
|
||||
// precomputedDeprecatedRemotePeersConstraint is the parsed ">= 0.29.3" constraint,
|
||||
// built once at init since the bound is a compile-time constant.
|
||||
var precomputedDeprecatedRemotePeersConstraint version.Constraints
|
||||
|
||||
func init() {
|
||||
constraint, err := version.NewConstraint(">= " + deprecatedRemotePeersVersion)
|
||||
if err != nil {
|
||||
panic("parse deprecated remote peers version constraint: " + err.Error())
|
||||
}
|
||||
precomputedDeprecatedRemotePeersConstraint = constraint
|
||||
}
|
||||
|
||||
func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
|
||||
if config == nil {
|
||||
return nil
|
||||
@@ -174,11 +155,7 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
||||
|
||||
remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
|
||||
remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName, includeIPv6)
|
||||
|
||||
if !shouldSkipSendingDeprecatedRemotePeers(peer.Meta.WtVersion) {
|
||||
response.RemotePeers = remotePeers
|
||||
}
|
||||
|
||||
response.RemotePeers = remotePeers
|
||||
response.NetworkMap.RemotePeers = remotePeers
|
||||
response.RemotePeersIsEmpty = len(remotePeers) == 0
|
||||
response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
|
||||
@@ -269,19 +246,6 @@ func buildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]m
|
||||
return hashedUsers, machineUsers
|
||||
}
|
||||
|
||||
func shouldSkipSendingDeprecatedRemotePeers(peerVersion string) bool {
|
||||
if nbversion.IsDevelopmentVersion(peerVersion) {
|
||||
return true
|
||||
}
|
||||
|
||||
peerNBVersion, err := version.NewVersion(peerVersion)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return precomputedDeprecatedRemotePeersConstraint.Check(peerNBVersion)
|
||||
}
|
||||
|
||||
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string, includeIPv6 bool) []*proto.RemotePeerConfig {
|
||||
for _, rPeer := range peers {
|
||||
allowedIPs := []string{rPeer.IP.String() + "/32"}
|
||||
@@ -399,6 +363,7 @@ func toProtocolFirewallRules(rules []*types.FirewallRule, includeIPv6, useSource
|
||||
return result
|
||||
}
|
||||
|
||||
|
||||
// populateSourcePrefixes sets SourcePrefixes on fwRule and returns any
|
||||
// additional rules needed (e.g. a v6 wildcard clone when the peer IP is unspecified).
|
||||
func populateSourcePrefixes(fwRule *proto.FirewallRule, rule *types.FirewallRule, includeIPv6 bool) []*proto.FirewallRule {
|
||||
|
||||
@@ -202,42 +202,6 @@ func TestBuildJWTConfig_Audiences(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestShouldSkipSendingDeprecatedRemotePeers covers the version gate that
|
||||
// stops populating the deprecated top-level SyncResponse.RemotePeers field for
|
||||
// peers new enough to read RemotePeers off the NetworkMap. Development builds
|
||||
// are treated as latest and skip the field. The gate otherwise fails safe: a
|
||||
// release version older than the boundary, or one that can't be parsed (empty,
|
||||
// garbage, prereleases of the boundary) still receives the deprecated field so
|
||||
// older/unknown clients keep working.
|
||||
func TestShouldSkipSendingDeprecatedRemotePeers(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
peerVersion string
|
||||
wantSkip bool
|
||||
}{
|
||||
{"exact boundary skips", "0.29.3", true},
|
||||
{"newer patch skips", "0.29.4", true},
|
||||
{"newer minor skips", "0.30.0", true},
|
||||
{"newer major skips", "1.0.0", true},
|
||||
{"v-prefixed newer skips", "v0.30.0", true},
|
||||
{"development build skips", "development", true},
|
||||
{"development build with commit skips", "development-abc123def456-dirty", true},
|
||||
{"older patch keeps field", "0.29.2", false},
|
||||
{"older minor keeps field", "0.28.0", false},
|
||||
{"prerelease of boundary keeps field", "0.29.3-SNAPSHOT", false},
|
||||
{"tagged dev prerelease keeps field", "v0.31.1-dev", false},
|
||||
{"empty version keeps field", "", false},
|
||||
{"garbage version keeps field", "not-a-version", false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := shouldSkipSendingDeprecatedRemotePeers(tc.peerVersion)
|
||||
assert.Equal(t, tc.wantSkip, got, "skip decision for peer version %q", tc.peerVersion)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncodeSessionExpiresAt pins the wire encoding the client's
|
||||
// applySessionDeadline depends on:
|
||||
//
|
||||
|
||||
@@ -666,10 +666,8 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error)
|
||||
case resp := <-conn.sendChan:
|
||||
if err := conn.sendResponse(resp); err != nil {
|
||||
errChan <- err
|
||||
log.WithContext(conn.ctx).Tracef("Failed to send response to proxy %s: %v", conn.proxyID, err)
|
||||
return
|
||||
}
|
||||
log.WithContext(conn.ctx).Tracef("Send response to proxy %s", conn.proxyID)
|
||||
case <-conn.ctx.Done():
|
||||
return
|
||||
}
|
||||
|
||||
@@ -12,7 +12,6 @@ const (
|
||||
RoleKey = nbcontext.RoleKey
|
||||
UserIDKey = nbcontext.UserIDKey
|
||||
PeerIDKey = nbcontext.PeerIDKey
|
||||
UserAgentKey = nbcontext.UserAgentKey
|
||||
)
|
||||
|
||||
// RoleFromContext returns the role stored in ctx, or empty string and false if absent.
|
||||
|
||||
@@ -21,8 +21,6 @@ const (
|
||||
httpRequestCounterPrefix = "management.http.request.counter"
|
||||
httpResponseCounterPrefix = "management.http.response.counter"
|
||||
httpRequestDurationPrefix = "management.http.request.duration.ms"
|
||||
|
||||
RequestIDHeader = "X-Request-Id"
|
||||
)
|
||||
|
||||
// WrappedResponseWriter is a wrapper for http.ResponseWriter that allows the
|
||||
@@ -174,10 +172,6 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler {
|
||||
reqID := xid.New().String()
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.RequestIDKey, reqID)
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.UserAgentKey, r.UserAgent())
|
||||
|
||||
rw.Header().Set(RequestIDHeader, reqID)
|
||||
|
||||
log.WithContext(ctx).Tracef("HTTP request %v: %v %v", reqID, r.Method, r.URL)
|
||||
|
||||
|
||||
@@ -557,6 +557,7 @@ func (c *NetworkMapComponents) getRoutingPeerRoutes(peerID string) (enabledRoute
|
||||
return enabledRoutes, disabledRoutes
|
||||
}
|
||||
|
||||
|
||||
func (c *NetworkMapComponents) filterRoutesByGroups(routes []*route.Route, groupListMap LookupMap) []*route.Route {
|
||||
var filteredRoutes []*route.Route
|
||||
for _, r := range routes {
|
||||
@@ -627,14 +628,9 @@ func (c *NetworkMapComponents) getDefaultPermit(r *route.Route, includeIPv6 bool
|
||||
|
||||
rules := []*RouteFirewallRule{&rule}
|
||||
|
||||
isDefaultV4 := r.Network.Addr().Is4() && r.Network.Bits() == 0
|
||||
if includeIPv6 && (r.IsDynamic() || isDefaultV4) {
|
||||
if includeIPv6 && r.IsDynamic() {
|
||||
ruleV6 := rule
|
||||
ruleV6.SourceRanges = []string{"::/0"}
|
||||
if isDefaultV4 {
|
||||
ruleV6.Destination = "::/0"
|
||||
ruleV6.RouteID = r.ID + "-v6-default"
|
||||
}
|
||||
rules = append(rules, &ruleV6)
|
||||
}
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -1030,48 +1029,6 @@ func TestComponents_RouteDefaultPermit(t *testing.T) {
|
||||
assert.True(t, hasDefaultPermit, "route without ACG should have default permit rule with 0.0.0.0/0 source")
|
||||
}
|
||||
|
||||
// TestComponents_ExitNodeDefaultPermitIPv6 verifies that a default exit node route
|
||||
// (0.0.0.0/0) without AccessControlGroups also emits an IPv6 default permit rule
|
||||
// (::/0 source and destination) for peers that support IPv6, mirroring the route
|
||||
// the client installs. Without it, IPv6 traffic is routed to the exit node but
|
||||
// dropped at the forward chain.
|
||||
func TestComponents_ExitNodeDefaultPermitIPv6(t *testing.T) {
|
||||
account, validatedPeers := scalableTestAccount(20, 2)
|
||||
|
||||
routingPeerID := "peer-5"
|
||||
routingPeer := account.Peers[routingPeerID]
|
||||
routingPeer.IPv6 = netip.MustParseAddr("fd00::5")
|
||||
routingPeer.Meta.Capabilities = append(routingPeer.Meta.Capabilities, nbpeer.PeerCapabilityIPv6Overlay)
|
||||
|
||||
account.Routes["route-exit"] = &route.Route{
|
||||
ID: "route-exit", Network: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
PeerID: routingPeerID, Peer: routingPeer.Key,
|
||||
Enabled: true, Groups: []string{"group-all"}, PeerGroups: []string{"group-0"},
|
||||
AccessControlGroups: []string{},
|
||||
AccountID: "test-account",
|
||||
}
|
||||
|
||||
nm := componentsNetworkMap(account, routingPeerID, validatedPeers)
|
||||
require.NotNil(t, nm)
|
||||
|
||||
hasV4 := false
|
||||
hasV6 := false
|
||||
for _, rfr := range nm.RoutesFirewallRules {
|
||||
switch rfr.Destination {
|
||||
case "0.0.0.0/0":
|
||||
if slices.Contains(rfr.SourceRanges, "0.0.0.0/0") {
|
||||
hasV4 = true
|
||||
}
|
||||
case "::/0":
|
||||
if slices.Contains(rfr.SourceRanges, "::/0") {
|
||||
hasV6 = true
|
||||
}
|
||||
}
|
||||
}
|
||||
assert.True(t, hasV4, "exit node route should have an IPv4 default permit rule (0.0.0.0/0)")
|
||||
assert.True(t, hasV6, "exit node route should have an IPv6 default permit rule (::/0)")
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────────────────────────
|
||||
// 15. MULTIPLE ROUTERS PER NETWORK
|
||||
// ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
@@ -249,7 +249,6 @@ func runServer(cmd *cobra.Command, args []string) error {
|
||||
Private: private,
|
||||
MaxDialTimeout: maxDialTimeout,
|
||||
MaxSessionIdleTimeout: maxSessionIdleTimeout,
|
||||
MappingBatchWatchdog: envDurationOrDefault("NB_PROXY_MAPPING_BATCH_WATCHDOG", 0),
|
||||
GeoDataDir: geoDataDir,
|
||||
CrowdSecAPIURL: crowdsecAPIURL,
|
||||
CrowdSecAPIKey: crowdsecAPIKey,
|
||||
|
||||
@@ -28,10 +28,6 @@ import (
|
||||
|
||||
const deviceNamePrefix = "ingress-proxy-"
|
||||
|
||||
const clientStopTimeout = 30 * time.Second
|
||||
|
||||
const createProxyPeerTimeout = 30 * time.Second
|
||||
|
||||
// backendKey identifies a backend by its host:port from the target URL.
|
||||
type backendKey string
|
||||
|
||||
@@ -166,7 +162,6 @@ type NetBird struct {
|
||||
|
||||
clientsMux sync.RWMutex
|
||||
clients map[types.AccountID]*clientEntry
|
||||
lifecycleMu sync.Map
|
||||
initLogOnce sync.Once
|
||||
statusNotifier statusNotifier
|
||||
// readyHandler runs after the embedded client for an account reports
|
||||
@@ -182,10 +177,6 @@ type NetBird struct {
|
||||
// (i.e. when a new client was actually created, not when an existing one
|
||||
// was reused). The duration covers keygen + gRPC CreateProxyPeer + embed.New.
|
||||
OnAddPeer func(d time.Duration, err error)
|
||||
|
||||
// startClient runs the post-create client startup. Nil uses runClientStartup;
|
||||
// tests override it to avoid a real embed client.Start.
|
||||
startClient func(accountID types.AccountID, client *embed.Client)
|
||||
}
|
||||
|
||||
// ClientDebugInfo contains debug information about a client.
|
||||
@@ -209,20 +200,31 @@ type skipTLSVerifyContextKey struct{}
|
||||
func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, serviceID types.ServiceID) error {
|
||||
si := serviceInfo{serviceID: serviceID}
|
||||
|
||||
if n.registerExistingClient(accountID, key, si) {
|
||||
return nil
|
||||
}
|
||||
n.clientsMux.Lock()
|
||||
|
||||
lifecycle := n.accountLifecycle(accountID)
|
||||
lifecycle.Lock()
|
||||
transferred := false
|
||||
defer func() {
|
||||
if !transferred {
|
||||
lifecycle.Unlock()
|
||||
entry, exists := n.clients[accountID]
|
||||
if exists {
|
||||
entry.services[key] = si
|
||||
started := entry.started
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_key": key,
|
||||
}).Debug("registered service with existing client")
|
||||
|
||||
if started && n.statusNotifier != nil {
|
||||
// Use a background context, not the caller's: the management
|
||||
// connection notification must land even if the request /
|
||||
// stream that triggered this registration is cancelled.
|
||||
// Mirrors the async runClientStartup path.
|
||||
if err := n.statusNotifier.NotifyStatus(context.Background(), accountID, serviceID, true); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_key": key,
|
||||
}).WithError(err).Warn("failed to notify status for existing client")
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if n.registerExistingClient(accountID, key, si) {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -232,10 +234,10 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se
|
||||
n.OnAddPeer(time.Since(createStart), err)
|
||||
}
|
||||
if err != nil {
|
||||
n.clientsMux.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
n.clientsMux.Lock()
|
||||
n.clients[accountID] = entry
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
@@ -244,64 +246,17 @@ func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, key Se
|
||||
"service_key": key,
|
||||
}).Info("created new client for account")
|
||||
|
||||
transferred = true
|
||||
go func() {
|
||||
defer lifecycle.Unlock()
|
||||
n.startClientStartup(accountID, entry.client)
|
||||
}()
|
||||
// Attempt to start the client in the background; if this fails we will
|
||||
// retry on the first request via RoundTrip. runClientStartup uses its
|
||||
// own background context so the caller's request-scoped ctx can't
|
||||
// cancel the inbound bring-up.
|
||||
go n.runClientStartup(accountID, entry.client)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *NetBird) startClientStartup(accountID types.AccountID, client *embed.Client) {
|
||||
if n.startClient != nil {
|
||||
n.startClient(accountID, client)
|
||||
return
|
||||
}
|
||||
n.runClientStartup(accountID, client)
|
||||
}
|
||||
|
||||
// registerExistingClient registers the service against an already-present
|
||||
// client for the account and returns true when it did. It notifies management
|
||||
// of the new service when the client is already started.
|
||||
func (n *NetBird) registerExistingClient(accountID types.AccountID, key ServiceKey, si serviceInfo) bool {
|
||||
n.clientsMux.Lock()
|
||||
entry, exists := n.clients[accountID]
|
||||
if !exists {
|
||||
n.clientsMux.Unlock()
|
||||
return false
|
||||
}
|
||||
entry.services[key] = si
|
||||
started := entry.started
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_key": key,
|
||||
}).Debug("registered service with existing client")
|
||||
|
||||
if started && n.statusNotifier != nil {
|
||||
if err := n.statusNotifier.NotifyStatus(context.Background(), accountID, si.serviceID, true); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_key": key,
|
||||
}).WithError(err).Warn("failed to notify status for existing client")
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// accountLifecycle returns the per-account lifecycle mutex, serialising client
|
||||
// creation against teardown so a slow client.Stop cannot race a new
|
||||
// client.Start for the same account, without blocking clientsMux.
|
||||
func (n *NetBird) accountLifecycle(accountID types.AccountID) *sync.Mutex {
|
||||
mu, _ := n.lifecycleMu.LoadOrStore(accountID, &sync.Mutex{})
|
||||
return mu.(*sync.Mutex)
|
||||
}
|
||||
|
||||
// createClientEntry generates a WireGuard keypair, authenticates with management,
|
||||
// and creates an embedded NetBird client. Must be called with the account's
|
||||
// lifecycle mutex held.
|
||||
// and creates an embedded NetBird client. Must be called with clientsMux held.
|
||||
func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, key ServiceKey, authToken string, si serviceInfo) (*clientEntry, error) {
|
||||
serviceID := si.serviceID
|
||||
n.logger.WithFields(log.Fields{
|
||||
@@ -321,9 +276,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
|
||||
"public_key": publicKey.String(),
|
||||
}).Debug("authenticating new proxy peer with management")
|
||||
|
||||
createCtx, cancel := context.WithTimeout(ctx, createProxyPeerTimeout)
|
||||
defer cancel()
|
||||
resp, err := n.mgmtClient.CreateProxyPeer(createCtx, &proto.CreateProxyPeerRequest{
|
||||
resp, err := n.mgmtClient.CreateProxyPeer(ctx, &proto.CreateProxyPeerRequest{
|
||||
ServiceId: string(serviceID),
|
||||
AccountId: string(accountID),
|
||||
Token: authToken,
|
||||
@@ -491,15 +444,6 @@ func (n *NetBird) notifyClientReady(accountID types.AccountID, client *embed.Cli
|
||||
// RemovePeer unregisters a service from an account. The client is only stopped
|
||||
// when no services are using it anymore.
|
||||
func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key ServiceKey) error {
|
||||
lifecycle := n.accountLifecycle(accountID)
|
||||
lifecycle.Lock()
|
||||
transferred := false
|
||||
defer func() {
|
||||
if !transferred {
|
||||
lifecycle.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
n.clientsMux.Lock()
|
||||
|
||||
entry, exists := n.clients[accountID]
|
||||
@@ -522,8 +466,17 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key
|
||||
delete(entry.services, key)
|
||||
|
||||
stopClient := len(entry.services) == 0
|
||||
var client *embed.Client
|
||||
var transport, insecureTransport *http.Transport
|
||||
var inbound any
|
||||
var stopHandler func(types.AccountID, any)
|
||||
if stopClient {
|
||||
n.logger.WithField("account_id", accountID).Info("stopping client, no more services")
|
||||
client = entry.client
|
||||
transport = entry.transport
|
||||
insecureTransport = entry.insecureTransport
|
||||
inbound = entry.inbound
|
||||
stopHandler = n.stopHandler
|
||||
delete(n.clients, accountID)
|
||||
} else {
|
||||
n.logger.WithFields(log.Fields{
|
||||
@@ -537,40 +490,19 @@ func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, key
|
||||
n.notifyDisconnect(ctx, accountID, key, si.serviceID)
|
||||
|
||||
if stopClient {
|
||||
transferred = true
|
||||
go n.stopClientLocked(accountID, lifecycle, entry)
|
||||
if inbound != nil && stopHandler != nil {
|
||||
stopHandler(accountID, inbound)
|
||||
}
|
||||
transport.CloseIdleConnections()
|
||||
insecureTransport.CloseIdleConnections()
|
||||
if err := client.Stop(ctx); err != nil {
|
||||
n.logger.WithField("account_id", accountID).WithError(err).Warn("failed to stop netbird client")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// stopClientLocked releases a client's resources off the caller's goroutine so a
|
||||
// slow client.Stop cannot wedge the mapping receive loop (which calls RemovePeer
|
||||
// synchronously). It unlocks lifecycle when done so a new client.Start for the
|
||||
// same account waits for this teardown.
|
||||
func (n *NetBird) stopClientLocked(accountID types.AccountID, lifecycle *sync.Mutex, entry *clientEntry) {
|
||||
defer lifecycle.Unlock()
|
||||
|
||||
if entry.inbound != nil && n.stopHandler != nil {
|
||||
n.stopHandler(accountID, entry.inbound)
|
||||
}
|
||||
if entry.transport != nil {
|
||||
entry.transport.CloseIdleConnections()
|
||||
}
|
||||
if entry.insecureTransport != nil {
|
||||
entry.insecureTransport.CloseIdleConnections()
|
||||
}
|
||||
if entry.client == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), clientStopTimeout)
|
||||
defer cancel()
|
||||
if err := entry.client.Stop(ctx); err != nil {
|
||||
n.logger.WithField("account_id", accountID).WithError(err).Warn("failed to stop netbird client")
|
||||
}
|
||||
}
|
||||
|
||||
func (n *NetBird) notifyDisconnect(ctx context.Context, accountID types.AccountID, key ServiceKey, serviceID types.ServiceID) {
|
||||
if n.statusNotifier == nil {
|
||||
return
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"net/netip"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -23,18 +22,6 @@ func (m *mockMgmtClient) CreateProxyPeer(_ context.Context, _ *proto.CreateProxy
|
||||
return &proto.CreateProxyPeerResponse{Success: true}, nil
|
||||
}
|
||||
|
||||
// signalMgmtClient closes entered the first time CreateProxyPeer is called, so
|
||||
// tests can detect AddPeer reaching client creation.
|
||||
type signalMgmtClient struct {
|
||||
entered chan struct{}
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func (m *signalMgmtClient) CreateProxyPeer(_ context.Context, _ *proto.CreateProxyPeerRequest, _ ...grpc.CallOption) (*proto.CreateProxyPeerResponse, error) {
|
||||
m.once.Do(func() { close(m.entered) })
|
||||
return &proto.CreateProxyPeerResponse{Success: true}, nil
|
||||
}
|
||||
|
||||
type mockStatusNotifier struct {
|
||||
mu sync.Mutex
|
||||
statuses []statusCall
|
||||
@@ -65,15 +52,11 @@ func (m *mockStatusNotifier) calls() []statusCall {
|
||||
// mockNetBird creates a NetBird instance for testing without actually connecting.
|
||||
// It uses an invalid management URL to prevent real connections.
|
||||
func mockNetBird() *NetBird {
|
||||
nb := NewNetBird(context.Background(), "test-proxy", "invalid.test", ClientConfig{
|
||||
return NewNetBird(context.Background(), "test-proxy", "invalid.test", ClientConfig{
|
||||
MgmtAddr: "http://invalid.test:9999",
|
||||
WGPort: 0,
|
||||
PreSharedKey: "",
|
||||
}, nil, nil, &mockMgmtClient{})
|
||||
// Skip the real embed client.Start, which would hang against the unreachable
|
||||
// mgmt URL and (now that the lifecycle lock spans startup) serialise removes.
|
||||
nb.startClient = func(types.AccountID, *embed.Client) {}
|
||||
return nb
|
||||
}
|
||||
|
||||
func TestNetBird_AddPeer_CreatesClientForNewAccount(t *testing.T) {
|
||||
@@ -305,7 +288,6 @@ func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
|
||||
WGPort: 0,
|
||||
PreSharedKey: "",
|
||||
}, nil, notifier, &mockMgmtClient{})
|
||||
nb.startClient = func(types.AccountID, *embed.Client) {}
|
||||
accountID := types.AccountID("account-1")
|
||||
|
||||
// Add first service — creates a new client entry.
|
||||
@@ -390,117 +372,6 @@ func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) {
|
||||
assert.False(t, calls[0].connected)
|
||||
}
|
||||
|
||||
// TestNetBird_RemovePeer_TeardownIsAsync proves the fix for the receive-loop
|
||||
// stall: RemovePeer must return promptly even when the client teardown blocks,
|
||||
// because teardown runs off the caller's goroutine. The receive loop calls
|
||||
// RemovePeer synchronously, so a blocking teardown inline would wedge it.
|
||||
func TestNetBird_RemovePeer_TeardownIsAsync(t *testing.T) {
|
||||
nb := NewNetBird(context.Background(), "test-proxy", "invalid.test", ClientConfig{
|
||||
MgmtAddr: "http://invalid.test:9999",
|
||||
}, nil, &mockStatusNotifier{}, &mockMgmtClient{})
|
||||
|
||||
accountID := types.AccountID("acct-async-teardown")
|
||||
key := DomainServiceKey("svc.example")
|
||||
|
||||
teardownEntered := make(chan struct{})
|
||||
releaseTeardown := make(chan struct{})
|
||||
nb.SetClientLifecycle(nil, func(types.AccountID, any) {
|
||||
close(teardownEntered)
|
||||
<-releaseTeardown
|
||||
})
|
||||
|
||||
nb.clientsMux.Lock()
|
||||
nb.clients[accountID] = &clientEntry{
|
||||
services: map[ServiceKey]serviceInfo{key: {serviceID: types.ServiceID("svc-1")}},
|
||||
started: true,
|
||||
inbound: struct{}{},
|
||||
}
|
||||
nb.clientsMux.Unlock()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() { done <- nb.RemovePeer(context.Background(), accountID, key) }()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
require.NoError(t, err)
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("RemovePeer did not return while teardown was blocked — teardown is not async")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-teardownEntered:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("teardown never ran")
|
||||
}
|
||||
|
||||
close(releaseTeardown)
|
||||
}
|
||||
|
||||
// TestNetBird_AddPeer_WaitsForTeardown proves the lifecycle lock serialises a
|
||||
// new client bringup behind an in-flight teardown for the same account, so a
|
||||
// slow client.Stop can never race a new client.Start for that account.
|
||||
//
|
||||
// It targets the handoff race specifically: AddPeer is launched immediately
|
||||
// after RemovePeer returns, WITHOUT waiting for the teardown goroutine to start.
|
||||
// This only passes if RemovePeer acquires the lifecycle lock synchronously
|
||||
// (before returning) and hands it to the teardown goroutine — if the goroutine
|
||||
// acquired the lock itself, AddPeer could win the lock in this window and start
|
||||
// a replacement client while the old teardown is still pending.
|
||||
func TestNetBird_AddPeer_WaitsForTeardown(t *testing.T) {
|
||||
nb := NewNetBird(context.Background(), "test-proxy", "invalid.test", ClientConfig{
|
||||
MgmtAddr: "http://invalid.test:9999",
|
||||
}, nil, &mockStatusNotifier{}, &mockMgmtClient{})
|
||||
nb.startClient = func(types.AccountID, *embed.Client) {}
|
||||
|
||||
accountID := types.AccountID("acct-serialize")
|
||||
key := DomainServiceKey("svc.example")
|
||||
|
||||
addEntered := make(chan struct{})
|
||||
releaseTeardown := make(chan struct{})
|
||||
nb.SetClientLifecycle(nil, func(types.AccountID, any) {
|
||||
// Block teardown until released. If AddPeer ever reaches createClientEntry
|
||||
// (signalled via the mgmt client below) while we hold the lock, the lock
|
||||
// failed to serialise and the test fails before we release.
|
||||
<-releaseTeardown
|
||||
})
|
||||
|
||||
nb.clientsMux.Lock()
|
||||
nb.clients[accountID] = &clientEntry{
|
||||
services: map[ServiceKey]serviceInfo{key: {serviceID: types.ServiceID("svc-1")}},
|
||||
started: true,
|
||||
inbound: struct{}{},
|
||||
}
|
||||
nb.clientsMux.Unlock()
|
||||
|
||||
// createClientEntry calls CreateProxyPeer; closing addEntered there tells us
|
||||
// AddPeer got past the lifecycle lock and into client creation.
|
||||
nb.mgmtClient = &signalMgmtClient{entered: addEntered}
|
||||
|
||||
require.NoError(t, nb.RemovePeer(context.Background(), accountID, key))
|
||||
|
||||
// Launch AddPeer with NO synchronisation against the teardown goroutine.
|
||||
addReturned := make(chan struct{})
|
||||
go func() {
|
||||
_ = nb.AddPeer(context.Background(), accountID, DomainServiceKey("svc2.example"), "key-2", types.ServiceID("svc-2"))
|
||||
close(addReturned)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-addEntered:
|
||||
t.Fatal("AddPeer entered client creation while teardown held the lifecycle lock — handoff race not closed")
|
||||
case <-addReturned:
|
||||
t.Fatal("AddPeer completed while teardown held the lifecycle lock — not serialised")
|
||||
case <-time.After(300 * time.Millisecond):
|
||||
}
|
||||
|
||||
close(releaseTeardown)
|
||||
select {
|
||||
case <-addReturned:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("AddPeer never completed after teardown released the lifecycle lock")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNotifyClientReady_UsesBackgroundCtx pins the contract that the
|
||||
// post-Start hooks (readyHandler + statusNotifier.NotifyStatus) run on
|
||||
// a fresh context.Background() rather than inheriting the AddPeer
|
||||
|
||||
@@ -114,10 +114,6 @@ type Config struct {
|
||||
MaxDialTimeout time.Duration
|
||||
// MaxSessionIdleTimeout caps the per-service session idle timeout.
|
||||
MaxSessionIdleTimeout time.Duration
|
||||
// MappingBatchWatchdog bounds how long a single mapping batch may spend
|
||||
// being applied before the receive loop reconnects to resync. Zero falls
|
||||
// back to the internal default.
|
||||
MappingBatchWatchdog time.Duration
|
||||
|
||||
// GeoDataDir is the directory containing GeoLite2 MMDB files.
|
||||
GeoDataDir string
|
||||
@@ -168,7 +164,6 @@ func New(ctx context.Context, cfg Config) *Server {
|
||||
Private: cfg.Private,
|
||||
MaxDialTimeout: cfg.MaxDialTimeout,
|
||||
MaxSessionIdleTimeout: cfg.MaxSessionIdleTimeout,
|
||||
MappingBatchWatchdog: cfg.MappingBatchWatchdog,
|
||||
GeoDataDir: cfg.GeoDataDir,
|
||||
CrowdSecAPIURL: cfg.CrowdSecAPIURL,
|
||||
CrowdSecAPIKey: cfg.CrowdSecAPIKey,
|
||||
|
||||
@@ -1,282 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// blockingMgmtClient implements roundtrip's managementClient interface.
|
||||
// CreateProxyPeer parks until release is closed, signalling entry on entered.
|
||||
// This reproduces the confirmed real-world stall: createClientEntry calls
|
||||
// CreateProxyPeer synchronously while holding clientsMux, and the proxy's
|
||||
// receive loop calls that path synchronously inside processMappings.
|
||||
type blockingMgmtClient struct {
|
||||
entered chan struct{}
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func (b *blockingMgmtClient) CreateProxyPeer(ctx context.Context, _ *proto.CreateProxyPeerRequest, _ ...grpc.CallOption) (*proto.CreateProxyPeerResponse, error) {
|
||||
b.once.Do(func() { close(b.entered) })
|
||||
// Park until the caller's context is cancelled. In production this ctx is
|
||||
// the gRPC mapping-stream context with no per-call timeout, so a slow or
|
||||
// unresponsive CreateProxyPeer parks the receive loop here indefinitely.
|
||||
<-ctx.Done()
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
// gatedMappingStream is a mock GetMappingUpdate client stream that hands out a
|
||||
// pre-seeded list of messages, then records how many times Recv advanced. It
|
||||
// lets the test observe whether the single-threaded receive loop ever gets
|
||||
// past the first (blocking) batch to pull the second message.
|
||||
type gatedMappingStream struct {
|
||||
grpc.ClientStream
|
||||
messages []*proto.GetMappingUpdateResponse
|
||||
idx int32
|
||||
}
|
||||
|
||||
func (g *gatedMappingStream) Recv() (*proto.GetMappingUpdateResponse, error) {
|
||||
i := int(atomic.LoadInt32(&g.idx))
|
||||
if i >= len(g.messages) {
|
||||
// Block instead of returning EOF so the loop doesn't exit; we only
|
||||
// care whether the loop ever reaches this second Recv at all.
|
||||
select {}
|
||||
}
|
||||
msg := g.messages[i]
|
||||
atomic.AddInt32(&g.idx, 1)
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (g *gatedMappingStream) deliveredCount() int32 { return atomic.LoadInt32(&g.idx) }
|
||||
|
||||
func (g *gatedMappingStream) Header() (metadata.MD, error) { return nil, nil } //nolint:nilnil
|
||||
func (g *gatedMappingStream) Trailer() metadata.MD { return nil }
|
||||
func (g *gatedMappingStream) CloseSend() error { return nil }
|
||||
func (g *gatedMappingStream) Context() context.Context { return context.Background() }
|
||||
func (g *gatedMappingStream) SendMsg(any) error { return nil }
|
||||
func (g *gatedMappingStream) RecvMsg(any) error { return nil }
|
||||
|
||||
// noopNotifier satisfies roundtrip's statusNotifier interface.
|
||||
type noopNotifier struct{}
|
||||
|
||||
func (noopNotifier) NotifyStatus(context.Context, types.AccountID, types.ServiceID, bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// noopProxyClient is a proto.ProxyServiceClient that no-ops the one method the
|
||||
// teardown unwind reaches (SendStatusUpdate, via notifyError when the parked
|
||||
// AddPeer is cancelled). The embedded nil interface satisfies the rest at
|
||||
// compile time; none of those methods are called by this test.
|
||||
type noopProxyClient struct {
|
||||
proto.ProxyServiceClient
|
||||
}
|
||||
|
||||
func (noopProxyClient) SendStatusUpdate(context.Context, *proto.SendStatusUpdateRequest, ...grpc.CallOption) (*proto.SendStatusUpdateResponse, error) {
|
||||
return &proto.SendStatusUpdateResponse{}, nil
|
||||
}
|
||||
|
||||
// TestMappingStream_StallsWhenApplyBlocks proves the deadlock: the proxy's
|
||||
// mapping receive loop processes batches strictly serially, so when applying
|
||||
// one batch blocks (here: createClientEntry parked on a synchronous
|
||||
// CreateProxyPeer call, exactly as observed in production), the loop never
|
||||
// advances to Recv the next batch. Management can keep sending updates onto
|
||||
// the stream with no error and no channel overflow, yet the proxy applies
|
||||
// nothing further — it is stuck.
|
||||
func TestMappingStream_StallsWhenApplyBlocks(t *testing.T) {
|
||||
logger := log.New()
|
||||
logger.SetLevel(log.PanicLevel)
|
||||
|
||||
mgmt := &blockingMgmtClient{
|
||||
entered: make(chan struct{}),
|
||||
}
|
||||
|
||||
nb := roundtrip.NewNetBird(
|
||||
context.Background(),
|
||||
"proxy-test",
|
||||
"proxy.example.com",
|
||||
roundtrip.ClientConfig{},
|
||||
logger,
|
||||
noopNotifier{},
|
||||
mgmt,
|
||||
)
|
||||
|
||||
s := &Server{
|
||||
Logger: logger,
|
||||
netbird: nb,
|
||||
mgmtClient: noopProxyClient{},
|
||||
routerReady: closedChan(),
|
||||
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
|
||||
}
|
||||
|
||||
// First batch: a CREATED mapping for a brand-new account. addMapping ->
|
||||
// netbird.AddPeer -> createClientEntry -> CreateProxyPeer, which blocks.
|
||||
// Empty Path keeps setupHTTPMapping a no-op (it returns early), so the
|
||||
// ONLY blocking point is the synchronous CreateProxyPeer in AddPeer —
|
||||
// no routers/auth need wiring. The second batch exists only to detect
|
||||
// whether the loop ever advances past the blocked first batch.
|
||||
stream := &gatedMappingStream{
|
||||
messages: []*proto.GetMappingUpdateResponse{
|
||||
{
|
||||
Mapping: []*proto.ProxyMapping{
|
||||
{
|
||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
||||
Id: "svc-1",
|
||||
AccountId: "acct-1",
|
||||
AuthToken: "token-1",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Mapping: []*proto.ProxyMapping{
|
||||
{
|
||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
||||
Id: "svc-2",
|
||||
AccountId: "acct-2",
|
||||
AuthToken: "token-2",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
// Unblock the parked apply on teardown via ctx (CreateProxyPeer returns
|
||||
// ctx.Err()), so the wedged loop goroutine unwinds before embed.New —
|
||||
// avoiding any dependency on collaborators this test deliberately leaves
|
||||
// nil. The deadlock is fully proven before this fires.
|
||||
t.Cleanup(cancel)
|
||||
|
||||
loopDone := make(chan struct{})
|
||||
syncDone := false
|
||||
go func() {
|
||||
defer close(loopDone)
|
||||
_ = s.handleMappingStream(ctx, stream, &syncDone, time.Time{})
|
||||
}()
|
||||
|
||||
// The loop must reach the blocking apply for the first batch.
|
||||
select {
|
||||
case <-mgmt.entered:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("receive loop never reached CreateProxyPeer for the first batch")
|
||||
}
|
||||
|
||||
// THE DEADLOCK: while the first batch is parked in CreateProxyPeer, the
|
||||
// single-threaded loop cannot advance. The second batch is never pulled,
|
||||
// even though it is already available on the stream. Give it ample time.
|
||||
// deliveredCount is atomic; syncDone is intentionally not read here because
|
||||
// the loop goroutine owns it (reading it from the test would race).
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
assert.Equal(t, int32(1), stream.deliveredCount(),
|
||||
"loop must NOT consume the second batch while the first is blocked in apply — proxy is stuck")
|
||||
|
||||
select {
|
||||
case <-loopDone:
|
||||
t.Fatal("receive loop returned while it should be wedged in apply")
|
||||
default:
|
||||
// Still wedged, as expected.
|
||||
}
|
||||
}
|
||||
|
||||
// TestMappingStream_StallsWhenRemoveBlocks proves the deadlock for the REMOVE
|
||||
// path observed in production: a mapping remove tears down the account's last
|
||||
// embedded client via netbird.RemovePeer -> client.Stop -> Engine.Stop, whose
|
||||
// jobExecutorWG.Wait() is unbounded. Because the receive loop is single-
|
||||
// threaded, a blocked remove wedges the loop: no further mapping updates of any
|
||||
// kind (create/modify/remove) are applied, while management keeps sending them
|
||||
// successfully (no send error, no channel-full). Matches the reported symptom:
|
||||
// the last log line is a remove that stops a client, then silence.
|
||||
func TestMappingStream_StallsWhenRemoveBlocks(t *testing.T) {
|
||||
logger := log.New()
|
||||
logger.SetLevel(log.PanicLevel)
|
||||
|
||||
enteredRemove := make(chan struct{})
|
||||
blockRemove := make(chan struct{})
|
||||
var once sync.Once
|
||||
|
||||
s := &Server{
|
||||
Logger: logger,
|
||||
mgmtClient: noopProxyClient{},
|
||||
routerReady: closedChan(),
|
||||
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
|
||||
// Stand in for netbird.RemovePeer -> client.Stop hanging on
|
||||
// Engine.Stop's unbounded jobExecutorWG.Wait(). Only the first remove
|
||||
// blocks; later removes return immediately so the recovery assertion
|
||||
// can observe the loop advancing.
|
||||
removePeer: func(ctx context.Context, _ types.AccountID, _ roundtrip.ServiceKey) error {
|
||||
first := false
|
||||
once.Do(func() {
|
||||
first = true
|
||||
close(enteredRemove)
|
||||
})
|
||||
if !first {
|
||||
return nil
|
||||
}
|
||||
select {
|
||||
case <-blockRemove:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// Batch 1 removes a service (blocks in teardown). Batch 2 is a later update
|
||||
// that must never be applied while the remove is wedged.
|
||||
stream := &gatedMappingStream{
|
||||
messages: []*proto.GetMappingUpdateResponse{
|
||||
{
|
||||
Mapping: []*proto.ProxyMapping{
|
||||
{Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, Id: "svc-1", AccountId: "acct-1"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Mapping: []*proto.ProxyMapping{
|
||||
{Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED, Id: "svc-2", AccountId: "acct-1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
loopDone := make(chan struct{})
|
||||
syncDone := false
|
||||
go func() {
|
||||
defer close(loopDone)
|
||||
_ = s.handleMappingStream(context.Background(), stream, &syncDone, time.Time{})
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-enteredRemove:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("receive loop never reached the blocking remove for the first batch")
|
||||
}
|
||||
|
||||
// THE DEADLOCK: the loop is parked in the blocked remove and cannot advance.
|
||||
// syncDone is owned by the loop goroutine, so it is not read here.
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
assert.Equal(t, int32(1), stream.deliveredCount(),
|
||||
"loop must NOT consume the second batch while the first remove is blocked — proxy is stuck")
|
||||
|
||||
select {
|
||||
case <-loopDone:
|
||||
t.Fatal("receive loop returned while it should be wedged on the remove")
|
||||
default:
|
||||
}
|
||||
|
||||
// Unblock and confirm the wedge was solely the blocked remove: the loop
|
||||
// then advances and consumes the next batch.
|
||||
close(blockRemove)
|
||||
assert.Eventually(t, func() bool {
|
||||
return stream.deliveredCount() >= 2
|
||||
}, 2*time.Second, 5*time.Millisecond,
|
||||
"once the remove unblocks, the loop must advance and consume the next batch")
|
||||
}
|
||||
154
proxy/server.go
154
proxy/server.go
@@ -24,7 +24,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pires/go-proxyproto"
|
||||
prometheus2 "github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
@@ -76,30 +75,29 @@ type portRouter struct {
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
ctx context.Context
|
||||
mgmtClient proto.ProxyServiceClient
|
||||
proxy *proxy.ReverseProxy
|
||||
netbird *roundtrip.NetBird
|
||||
acme *acme.Manager
|
||||
staticCertWatcher *certwatch.Watcher
|
||||
auth *auth.Middleware
|
||||
http *http.Server
|
||||
https *http.Server
|
||||
debug *http.Server
|
||||
healthServer *health.Server
|
||||
healthChecker *health.Checker
|
||||
meter *proxymetrics.Metrics
|
||||
accessLog *accesslog.Logger
|
||||
mainRouter *nbtcp.Router
|
||||
mainPort uint16
|
||||
udpMu sync.Mutex
|
||||
udpRelays map[types.ServiceID]*udprelay.Relay
|
||||
udpRelayWg sync.WaitGroup
|
||||
portMu sync.RWMutex
|
||||
portRouters map[uint16]*portRouter
|
||||
svcPorts map[types.ServiceID][]uint16
|
||||
lastMappings map[types.ServiceID]*proto.ProxyMapping
|
||||
portRouterWg sync.WaitGroup
|
||||
ctx context.Context
|
||||
mgmtClient proto.ProxyServiceClient
|
||||
proxy *proxy.ReverseProxy
|
||||
netbird *roundtrip.NetBird
|
||||
acme *acme.Manager
|
||||
auth *auth.Middleware
|
||||
http *http.Server
|
||||
https *http.Server
|
||||
debug *http.Server
|
||||
healthServer *health.Server
|
||||
healthChecker *health.Checker
|
||||
meter *proxymetrics.Metrics
|
||||
accessLog *accesslog.Logger
|
||||
mainRouter *nbtcp.Router
|
||||
mainPort uint16
|
||||
udpMu sync.Mutex
|
||||
udpRelays map[types.ServiceID]*udprelay.Relay
|
||||
udpRelayWg sync.WaitGroup
|
||||
portMu sync.RWMutex
|
||||
portRouters map[uint16]*portRouter
|
||||
svcPorts map[types.ServiceID][]uint16
|
||||
lastMappings map[types.ServiceID]*proto.ProxyMapping
|
||||
portRouterWg sync.WaitGroup
|
||||
|
||||
// hijackTracker tracks hijacked connections (e.g. WebSocket upgrades)
|
||||
// so they can be closed during graceful shutdown, since http.Server.Shutdown
|
||||
@@ -120,9 +118,6 @@ type Server struct {
|
||||
// The mapping worker waits on this before processing updates.
|
||||
routerReady chan struct{}
|
||||
|
||||
// removePeer defaults to netbird.RemovePeer; overridable in tests.
|
||||
removePeer func(ctx context.Context, accountID types.AccountID, key roundtrip.ServiceKey) error
|
||||
|
||||
// inbound, when non-nil, manages per-account inbound listeners. Set by
|
||||
// initPrivateInbound only when Private is true so the standalone
|
||||
// proxy keeps its zero-overhead default path.
|
||||
@@ -232,10 +227,6 @@ type Server struct {
|
||||
// Zero means no cap (the proxy honors whatever management sends).
|
||||
// Set via NB_PROXY_MAX_SESSION_IDLE_TIMEOUT for shared deployments.
|
||||
MaxSessionIdleTimeout time.Duration
|
||||
// MappingBatchWatchdog bounds how long a single mapping batch may spend
|
||||
// in processMappings before the receive loop reconnects to resync.
|
||||
// Zero uses defaultMappingBatchWatchdog.
|
||||
MappingBatchWatchdog time.Duration
|
||||
}
|
||||
|
||||
// clampIdleTimeout returns d capped to MaxSessionIdleTimeout when configured.
|
||||
@@ -616,7 +607,7 @@ func (s *Server) initDefaults() {
|
||||
|
||||
// If no ID is set then one can be generated.
|
||||
if s.ID == "" {
|
||||
s.ID = fmt.Sprintf("netbird-proxy-%s", uuid.NewString())
|
||||
s.ID = "netbird-proxy-" + s.startTime.Format("20060102150405")
|
||||
}
|
||||
// Fallback version option in case it is not set.
|
||||
if s.Version == "" {
|
||||
@@ -794,7 +785,6 @@ func (s *Server) configureTLS(ctx context.Context) (*tls.Config, error) {
|
||||
return nil, fmt.Errorf("initialize certificate watcher: %w", err)
|
||||
}
|
||||
go certWatcher.Watch(ctx)
|
||||
s.staticCertWatcher = certWatcher
|
||||
tlsConfig.GetCertificate = certWatcher.GetCertificate
|
||||
return tlsConfig, nil
|
||||
}
|
||||
@@ -1182,30 +1172,24 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr
|
||||
s.healthChecker.SetManagementConnected(false)
|
||||
}
|
||||
|
||||
connected := false
|
||||
onConnected := func() { connected = true }
|
||||
|
||||
var streamErr error
|
||||
if syncSupported {
|
||||
streamErr = s.trySyncMappings(ctx, client, &initialSyncDone, onConnected)
|
||||
streamErr = s.trySyncMappings(ctx, client, &initialSyncDone)
|
||||
if isSyncUnimplemented(streamErr) {
|
||||
syncSupported = false
|
||||
s.Logger.Info("management does not support SyncMappings, falling back to GetMappingUpdate")
|
||||
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone, onConnected)
|
||||
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone)
|
||||
}
|
||||
} else {
|
||||
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone, onConnected)
|
||||
streamErr = s.tryGetMappingUpdate(ctx, client, &initialSyncDone)
|
||||
}
|
||||
|
||||
if s.healthChecker != nil {
|
||||
s.healthChecker.SetManagementConnected(false)
|
||||
}
|
||||
|
||||
// Reset backoff only when a stream actually connected, so immediate
|
||||
// connect failures still back off instead of spinning.
|
||||
if connected {
|
||||
bo.Reset()
|
||||
}
|
||||
// Stream established — reset backoff so the next failure retries quickly.
|
||||
bo.Reset()
|
||||
|
||||
if streamErr == nil {
|
||||
return fmt.Errorf("stream closed by server")
|
||||
@@ -1237,7 +1221,7 @@ func (s *Server) proxyCapabilities() *proto.ProxyCapabilities {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) tryGetMappingUpdate(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool, onConnected func()) error {
|
||||
func (s *Server) tryGetMappingUpdate(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool) error {
|
||||
connectTime := time.Now()
|
||||
mappingClient, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
|
||||
ProxyId: s.ID,
|
||||
@@ -1250,7 +1234,6 @@ func (s *Server) tryGetMappingUpdate(ctx context.Context, client proto.ProxyServ
|
||||
return fmt.Errorf("create mapping stream: %w", err)
|
||||
}
|
||||
|
||||
onConnected()
|
||||
if s.healthChecker != nil {
|
||||
s.healthChecker.SetManagementConnected(true)
|
||||
}
|
||||
@@ -1259,7 +1242,7 @@ func (s *Server) tryGetMappingUpdate(ctx context.Context, client proto.ProxyServ
|
||||
return s.handleMappingStream(ctx, mappingClient, initialSyncDone, connectTime)
|
||||
}
|
||||
|
||||
func (s *Server) trySyncMappings(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool, onConnected func()) error {
|
||||
func (s *Server) trySyncMappings(ctx context.Context, client proto.ProxyServiceClient, initialSyncDone *bool) error {
|
||||
connectTime := time.Now()
|
||||
stream, err := client.SyncMappings(ctx)
|
||||
if err != nil {
|
||||
@@ -1280,7 +1263,6 @@ func (s *Server) trySyncMappings(ctx context.Context, client proto.ProxyServiceC
|
||||
return fmt.Errorf("send sync init: %w", err)
|
||||
}
|
||||
|
||||
onConnected()
|
||||
if s.healthChecker != nil {
|
||||
s.healthChecker.SetManagementConnected(true)
|
||||
}
|
||||
@@ -1325,9 +1307,7 @@ func (s *Server) handleSyncMappingsStream(ctx context.Context, stream proto.Prox
|
||||
|
||||
batchStart := time.Now()
|
||||
s.Logger.Debug("Received mapping update, starting processing")
|
||||
if err := s.processMappingsGuarded(ctx, msg.GetMapping()); err != nil {
|
||||
return err
|
||||
}
|
||||
s.processMappings(ctx, msg.GetMapping())
|
||||
s.Logger.Debug("Processing mapping update completed")
|
||||
tracker.recordBatch(ctx, s, msg.GetMapping(), msg.GetInitialSyncComplete(), batchStart)
|
||||
|
||||
@@ -1411,9 +1391,7 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
|
||||
|
||||
batchStart := time.Now()
|
||||
s.Logger.Debug("Received mapping update, starting processing")
|
||||
if err := s.processMappingsGuarded(ctx, msg.GetMapping()); err != nil {
|
||||
return err
|
||||
}
|
||||
s.processMappings(ctx, msg.GetMapping())
|
||||
s.Logger.Debug("Processing mapping update completed")
|
||||
tracker.recordBatch(ctx, s, msg.GetMapping(), msg.GetInitialSyncComplete(), batchStart)
|
||||
}
|
||||
@@ -1478,44 +1456,6 @@ func redactMappingForLog(m *proto.ProxyMapping) *proto.ProxyMapping {
|
||||
return c
|
||||
}
|
||||
|
||||
const defaultMappingBatchWatchdog = 2 * time.Minute
|
||||
|
||||
// mappingBatchWatchdog returns the configured batch watchdog or the default.
|
||||
func (s *Server) mappingBatchWatchdog() time.Duration {
|
||||
if s.MappingBatchWatchdog > 0 {
|
||||
return s.MappingBatchWatchdog
|
||||
}
|
||||
return defaultMappingBatchWatchdog
|
||||
}
|
||||
|
||||
// processMappingsGuarded applies a batch under a watchdog, returning an error
|
||||
// if processing exceeds the watchdog so the caller reconnects and resyncs
|
||||
// instead of wedging silently.
|
||||
func (s *Server) processMappingsGuarded(ctx context.Context, mappings []*proto.ProxyMapping) error {
|
||||
batchCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
s.processMappings(batchCtx, mappings)
|
||||
}()
|
||||
|
||||
watchdog := s.mappingBatchWatchdog()
|
||||
timer := time.NewTimer(watchdog)
|
||||
defer timer.Stop()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-timer.C:
|
||||
s.Logger.Errorf("processing mapping batch exceeded %s, cancelling and reconnecting to resync", watchdog)
|
||||
return fmt.Errorf("mapping batch processing stalled after %s", watchdog)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) {
|
||||
debug := s.Logger != nil && s.Logger.IsLevelEnabled(log.DebugLevel)
|
||||
for _, mapping := range mappings {
|
||||
@@ -1626,8 +1566,6 @@ func (s *Server) setupHTTPMapping(ctx context.Context, mapping *proto.ProxyMappi
|
||||
var wildcardHit bool
|
||||
if s.acme != nil {
|
||||
wildcardHit = s.acme.AddDomain(d, accountID, svcID)
|
||||
} else {
|
||||
wildcardHit = s.staticCertCovers(d)
|
||||
}
|
||||
httpRoute := nbtcp.Route{
|
||||
Type: nbtcp.RouteHTTP,
|
||||
@@ -1652,26 +1590,6 @@ func (s *Server) setupHTTPMapping(ctx context.Context, mapping *proto.ProxyMappi
|
||||
return nil
|
||||
}
|
||||
|
||||
// staticCertCovers reports whether the static certificate loaded when ACME is
|
||||
// disabled covers the given domain, making it certificate-ready immediately —
|
||||
// the equivalent of a wildcard hit in the ACME path. Domains the certificate
|
||||
// does not cover are logged: clients connecting to them will get TLS errors.
|
||||
func (s *Server) staticCertCovers(d domain.Domain) bool {
|
||||
if s.staticCertWatcher == nil {
|
||||
return false
|
||||
}
|
||||
leaf := s.staticCertWatcher.Leaf()
|
||||
if leaf == nil {
|
||||
return false
|
||||
}
|
||||
name := d.PunycodeString()
|
||||
if err := leaf.VerifyHostname(name); err != nil {
|
||||
s.Logger.Warnf("static certificate (SANs %v) does not cover domain %q: %v", leaf.DNSNames, name, err)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// setupTCPMapping sets up a TCP port-forwarding fallback route on the listen port.
|
||||
func (s *Server) setupTCPMapping(ctx context.Context, mapping *proto.ProxyMapping) error {
|
||||
svcID := types.ServiceID(mapping.GetId())
|
||||
@@ -2033,11 +1951,7 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping)
|
||||
func (s *Server) removeMapping(ctx context.Context, mapping *proto.ProxyMapping) {
|
||||
accountID := types.AccountID(mapping.GetAccountId())
|
||||
svcKey := s.serviceKeyForMapping(mapping)
|
||||
removePeer := s.removePeer
|
||||
if removePeer == nil {
|
||||
removePeer = s.netbird.RemovePeer
|
||||
}
|
||||
if err := removePeer(ctx, accountID, svcKey); err != nil {
|
||||
if err := s.netbird.RemovePeer(ctx, accountID, svcKey); err != nil {
|
||||
s.Logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_id": mapping.GetId(),
|
||||
|
||||
@@ -1,89 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"math/big"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/certwatch"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
func generateCertWithSANs(t *testing.T, dnsNames []string) (certPEM, keyPEM []byte) {
|
||||
t.Helper()
|
||||
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{CommonName: dnsNames[0]},
|
||||
DNSNames: dnsNames,
|
||||
NotBefore: time.Now().Add(-time.Hour),
|
||||
NotAfter: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
|
||||
require.NoError(t, err)
|
||||
certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
|
||||
|
||||
keyDER, err := x509.MarshalECPrivateKey(key)
|
||||
require.NoError(t, err)
|
||||
keyPEM = pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
|
||||
|
||||
return certPEM, keyPEM
|
||||
}
|
||||
|
||||
func newStaticWatcher(t *testing.T, dnsNames []string) *certwatch.Watcher {
|
||||
t.Helper()
|
||||
|
||||
dir := t.TempDir()
|
||||
certPEM, keyPEM := generateCertWithSANs(t, dnsNames)
|
||||
certPath := filepath.Join(dir, "tls.crt")
|
||||
keyPath := filepath.Join(dir, "tls.key")
|
||||
require.NoError(t, os.WriteFile(certPath, certPEM, 0o600))
|
||||
require.NoError(t, os.WriteFile(keyPath, keyPEM, 0o600))
|
||||
|
||||
w, err := certwatch.NewWatcher(certPath, keyPath, quietLifecycleLogger())
|
||||
require.NoError(t, err)
|
||||
return w
|
||||
}
|
||||
|
||||
func TestStaticCertCovers(t *testing.T) {
|
||||
s := &Server{
|
||||
Logger: quietLifecycleLogger(),
|
||||
staticCertWatcher: newStaticWatcher(t, []string{"*.p.example.com", "exact.example.com"}),
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
domain string
|
||||
covered bool
|
||||
}{
|
||||
{"svc.p.example.com", true},
|
||||
{"exact.example.com", true},
|
||||
{"a.b.p.example.com", false}, // wildcard does not span labels
|
||||
{"p.example.com", false},
|
||||
{"other.example.com", false},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.domain, func(t *testing.T) {
|
||||
assert.Equal(t, tc.covered, s.staticCertCovers(domain.Domain(tc.domain)))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStaticCertCoversNoWatcher(t *testing.T) {
|
||||
s := &Server{Logger: quietLifecycleLogger()}
|
||||
assert.False(t, s.staticCertCovers(domain.Domain("svc.p.example.com")))
|
||||
}
|
||||
@@ -417,30 +417,15 @@ if type uname >/dev/null 2>&1; then
|
||||
# Check the availability of a compatible package manager
|
||||
if check_use_bin_variable; then
|
||||
PACKAGE_MANAGER="bin"
|
||||
elif [ -e /run/ostree-booted ]; then
|
||||
if [ -x "$(command -v rpm-ostree)" ]; then
|
||||
PACKAGE_MANAGER="rpm-ostree"
|
||||
echo "The installation will be performed using rpm-ostree package manager"
|
||||
elif [ -x "$(command -v bootc)" ]; then
|
||||
echo "Detected bootc system without rpm-ostree." >&2
|
||||
echo "NetBird cannot be installed via package manager on this system." >&2
|
||||
echo "Options:" >&2
|
||||
echo " 1. Install via Distrobox (instructions in the installation docs)" >&2
|
||||
echo " 2. Rebuild your base image with rpm-ostree included" >&2
|
||||
echo " 3. Bake NetBird into your Containerfile" >&2
|
||||
exit 1
|
||||
else
|
||||
echo "Detected ostree-booted system without rpm-ostree or bootc." >&2
|
||||
echo "NetBird cannot be installed automatically on this atomic system." >&2
|
||||
echo "Please install NetBird by rebuilding your base image or use a supported package manager." >&2
|
||||
exit 1
|
||||
fi
|
||||
elif [ -x "$(command -v apt-get)" ]; then
|
||||
PACKAGE_MANAGER="apt"
|
||||
echo "The installation will be performed using apt package manager"
|
||||
elif [ -x "$(command -v dnf)" ]; then
|
||||
PACKAGE_MANAGER="dnf"
|
||||
echo "The installation will be performed using dnf package manager"
|
||||
elif [ -x "$(command -v rpm-ostree)" ]; then
|
||||
PACKAGE_MANAGER="rpm-ostree"
|
||||
echo "The installation will be performed using rpm-ostree package manager"
|
||||
elif [ -x "$(command -v yum)" ]; then
|
||||
PACKAGE_MANAGER="yum"
|
||||
echo "The installation will be performed using yum package manager"
|
||||
|
||||
@@ -6,5 +6,4 @@ const (
|
||||
RoleKey = "role"
|
||||
UserIDKey = "userID"
|
||||
PeerIDKey = "peerID"
|
||||
UserAgentKey = "userAgent"
|
||||
)
|
||||
|
||||
@@ -322,21 +322,15 @@ func TestClient_Sync(t *testing.T) {
|
||||
if resp.GetNetbirdConfig() == nil {
|
||||
t.Error("expecting non nil NetbirdConfig got nil")
|
||||
}
|
||||
// we test network map peers from 0.29.3 and dev builds
|
||||
if len(resp.GetRemotePeers()) != 0 {
|
||||
t.Error("expecting top-level RemotePeers to be empty for v0.29.3+ clients")
|
||||
}
|
||||
networkMap := resp.GetNetworkMap()
|
||||
if len(networkMap.GetRemotePeers()) != 1 {
|
||||
t.Errorf("expecting RemotePeers size %d got %d", 1, len(networkMap.GetRemotePeers()))
|
||||
if len(resp.GetRemotePeers()) != 1 {
|
||||
t.Errorf("expecting RemotePeers size %d got %d", 1, len(resp.GetRemotePeers()))
|
||||
return
|
||||
}
|
||||
|
||||
if networkMap.GetRemotePeersIsEmpty() {
|
||||
if resp.GetRemotePeersIsEmpty() == true {
|
||||
t.Error("expecting RemotePeers property to be false, got true")
|
||||
}
|
||||
if networkMap.GetRemotePeers()[0].GetWgPubKey() != remoteKey.PublicKey().String() {
|
||||
t.Errorf("expecting RemotePeer public key %s got %s", remoteKey.PublicKey().String(), networkMap.GetRemotePeers()[0].GetWgPubKey())
|
||||
if resp.GetRemotePeers()[0].GetWgPubKey() != remoteKey.PublicKey().String() {
|
||||
t.Errorf("expecting RemotePeer public key %s got %s", remoteKey.PublicKey().String(), resp.GetRemotePeers()[0].GetWgPubKey())
|
||||
}
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Error("timeout waiting for test to finish")
|
||||
|
||||
@@ -5107,63 +5107,31 @@ components:
|
||||
responses:
|
||||
not_found:
|
||||
description: Resource not found
|
||||
headers:
|
||||
X-Request-Id:
|
||||
$ref: '#/components/headers/X-Request-Id'
|
||||
content: { }
|
||||
validation_failed_simple:
|
||||
description: Validation failed
|
||||
headers:
|
||||
X-Request-Id:
|
||||
$ref: '#/components/headers/X-Request-Id'
|
||||
content: { }
|
||||
bad_request:
|
||||
description: Bad Request
|
||||
headers:
|
||||
X-Request-Id:
|
||||
$ref: '#/components/headers/X-Request-Id'
|
||||
content: { }
|
||||
internal_error:
|
||||
description: Internal Server Error
|
||||
headers:
|
||||
X-Request-Id:
|
||||
$ref: '#/components/headers/X-Request-Id'
|
||||
content: { }
|
||||
validation_failed:
|
||||
description: Validation failed
|
||||
headers:
|
||||
X-Request-Id:
|
||||
$ref: '#/components/headers/X-Request-Id'
|
||||
content: { }
|
||||
forbidden:
|
||||
description: Forbidden
|
||||
headers:
|
||||
X-Request-Id:
|
||||
$ref: '#/components/headers/X-Request-Id'
|
||||
content: { }
|
||||
requires_authentication:
|
||||
description: Requires authentication
|
||||
headers:
|
||||
X-Request-Id:
|
||||
$ref: '#/components/headers/X-Request-Id'
|
||||
content: { }
|
||||
conflict:
|
||||
description: Conflict
|
||||
headers:
|
||||
X-Request-Id:
|
||||
$ref: '#/components/headers/X-Request-Id'
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
headers:
|
||||
X-Request-Id:
|
||||
description: |
|
||||
Unique identifier assigned to the request by the server and set on every
|
||||
response. Useful for correlating client requests with server-side logs.
|
||||
schema:
|
||||
type: string
|
||||
example: cot7r4n3l3vh3qj4qveg
|
||||
securitySchemes:
|
||||
BearerAuth:
|
||||
type: http
|
||||
|
||||
@@ -9,14 +9,12 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer"
|
||||
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
|
||||
"github.com/netbirdio/netbird/shared/relay/healthcheck"
|
||||
"github.com/netbirdio/netbird/shared/relay/messages"
|
||||
)
|
||||
@@ -174,19 +172,6 @@ type Client struct {
|
||||
stateSubscription *PeersStateSubscription
|
||||
|
||||
mtu uint16
|
||||
|
||||
// transportFallback, when set, records datagram-too-large failures so a
|
||||
// datagram-sized transport is avoided on subsequent connects. Shared via
|
||||
// the manager.
|
||||
transportFallback *transportFallback
|
||||
// datagramFallbackTriggered guards a single fallback per connection so a
|
||||
// burst of oversized datagrams triggers one reconnect, not many.
|
||||
datagramFallbackTriggered atomic.Bool
|
||||
}
|
||||
|
||||
// SetTransportFallback wires the shared datagram-transport fallback tracker.
|
||||
func (c *Client) SetTransportFallback(tf *transportFallback) {
|
||||
c.transportFallback = tf
|
||||
}
|
||||
|
||||
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
|
||||
@@ -376,13 +361,12 @@ func (c *Client) Close() error {
|
||||
}
|
||||
|
||||
func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
|
||||
mode := transportModeFromEnv()
|
||||
dialers := c.getDialers(mode)
|
||||
dialers := c.getDialers()
|
||||
|
||||
var conn net.Conn
|
||||
if c.serverIP.IsValid() {
|
||||
var err error
|
||||
conn, err = c.dialRaceDirect(ctx, mode, dialers)
|
||||
conn, err = c.dialRaceDirect(ctx, dialers)
|
||||
if err != nil {
|
||||
c.log.Infof("dial via server IP %s failed, falling back to FQDN: %v", c.serverIP, err)
|
||||
conn = nil
|
||||
@@ -391,9 +375,6 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
|
||||
|
||||
if conn == nil {
|
||||
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, dialers...)
|
||||
if mode.sequential() {
|
||||
rd.WithSequential()
|
||||
}
|
||||
var err error
|
||||
conn, err = rd.Dial(ctx)
|
||||
if err != nil {
|
||||
@@ -401,7 +382,6 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
|
||||
}
|
||||
}
|
||||
c.relayConn = conn
|
||||
c.datagramFallbackTriggered.Store(false)
|
||||
|
||||
instanceURL, err := c.handShake(ctx)
|
||||
if err != nil {
|
||||
@@ -416,7 +396,7 @@ func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
|
||||
}
|
||||
|
||||
// dialRaceDirect dials c.serverIP, preserving the original FQDN as the TLS ServerName for SNI.
|
||||
func (c *Client) dialRaceDirect(ctx context.Context, mode TransportMode, dialers []dialer.DialeFn) (net.Conn, error) {
|
||||
func (c *Client) dialRaceDirect(ctx context.Context, dialers []dialer.DialeFn) (net.Conn, error) {
|
||||
directURL, serverName, err := substituteHost(c.connectionURL, c.serverIP)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("substitute host: %w", err)
|
||||
@@ -426,9 +406,6 @@ func (c *Client) dialRaceDirect(ctx context.Context, mode TransportMode, dialers
|
||||
|
||||
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, directURL, dialers...).
|
||||
WithServerName(serverName)
|
||||
if mode.sequential() {
|
||||
rd.WithSequential()
|
||||
}
|
||||
return rd.Dial(ctx)
|
||||
}
|
||||
|
||||
@@ -654,53 +631,13 @@ func (c *Client) writeTo(containerRef *connContainer, dstID messages.PeerID, pay
|
||||
}
|
||||
|
||||
// the write always return with 0 length because the underling does not support the size feedback.
|
||||
conn := c.relayConn
|
||||
_, err = conn.Write(msg)
|
||||
_, err = c.relayConn.Write(msg)
|
||||
if err != nil {
|
||||
if errors.Is(err, netErr.ErrDatagramTooLarge) {
|
||||
c.onDatagramTooLarge(conn, err)
|
||||
} else {
|
||||
c.log.Errorf("failed to write transport message: %s", err)
|
||||
}
|
||||
c.log.Errorf("failed to write transport message: %s", err)
|
||||
}
|
||||
return len(payload), err
|
||||
}
|
||||
|
||||
// onDatagramTooLarge reacts to a datagram rejected as too large for the path.
|
||||
// When a non-datagram transport is available, it records a fallback for this
|
||||
// server and closes the connection so the reconnect avoids datagram-sized
|
||||
// transports. A single fallback is triggered per connection regardless of how
|
||||
// many oversized datagrams arrive. cause carries the datagram size and budget.
|
||||
func (c *Client) onDatagramTooLarge(conn net.Conn, cause error) {
|
||||
// Handle one oversized datagram per connection; a burst triggers a single
|
||||
// fallback (and a single log line), not many.
|
||||
if !c.datagramFallbackTriggered.CompareAndSwap(false, true) {
|
||||
return
|
||||
}
|
||||
|
||||
// If the selected mode offers no non-datagram transport (e.g. pinned to a
|
||||
// datagram-sized transport), reconnecting would just re-fail, so leave the
|
||||
// connection up rather than loop.
|
||||
if len(nonDatagramSized(c.baseDialers(transportModeFromEnv()))) == 0 {
|
||||
c.log.Warnf("%s, but no non-datagram transport is available, not falling back", cause)
|
||||
return
|
||||
}
|
||||
|
||||
// Without the shared tracker a reconnect would just select the same
|
||||
// transport again and re-fail, so leave the connection up rather than loop.
|
||||
if c.transportFallback == nil {
|
||||
c.log.Debugf("%s, but no transport fallback configured, leaving connection up", cause)
|
||||
return
|
||||
}
|
||||
|
||||
window := c.transportFallback.recordFailure(c.connectionURL)
|
||||
c.log.Warnf("%s, avoiding datagram-sized transport for %s", cause, window)
|
||||
|
||||
if err := conn.Close(); err != nil {
|
||||
c.log.Debugf("close relay connection for transport fallback: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) listenForStopEvents(ctx context.Context, hc *healthcheck.Receiver, conn net.Conn, internalStopFlag *internalStopFlag) {
|
||||
for {
|
||||
select {
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
package dialer
|
||||
|
||||
// DatagramSized is implemented by dialers whose connections carry each write in
|
||||
// a single datagram, so a write can be rejected when it exceeds the path's
|
||||
// datagram budget (e.g. QUIC). Transports without this capability (e.g.
|
||||
// WebSocket over TCP) impose no per-write size limit, so the relay client can
|
||||
// fall back to them when a datagram-sized transport rejects a write as too
|
||||
// large. The capability is advertised per dialer rather than hardcoded, so a
|
||||
// new transport only needs to declare whether it is datagram-sized.
|
||||
type DatagramSized interface {
|
||||
DatagramSized()
|
||||
}
|
||||
|
||||
// IsDatagramSized reports whether d produces datagram-sized connections.
|
||||
func IsDatagramSized(d DialeFn) bool {
|
||||
_, ok := d.(DatagramSized)
|
||||
return ok
|
||||
}
|
||||
@@ -4,9 +4,4 @@ import "errors"
|
||||
|
||||
var (
|
||||
ErrClosedByServer = errors.New("closed by server")
|
||||
|
||||
// ErrDatagramTooLarge is returned when a transport message exceeds the
|
||||
// QUIC datagram size the path to the relay can carry. The relay client
|
||||
// treats it as a signal to fall back to a non-datagram transport.
|
||||
ErrDatagramTooLarge = errors.New("datagram frame too large")
|
||||
)
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
|
||||
)
|
||||
@@ -51,8 +52,11 @@ func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
}
|
||||
|
||||
func (c *Conn) Write(b []byte) (int, error) {
|
||||
if err := c.session.SendDatagram(b); err != nil {
|
||||
return 0, c.writeErrHandling(err, len(b))
|
||||
err := c.session.SendDatagram(b)
|
||||
if err != nil {
|
||||
err = c.remoteCloseErrHandling(err)
|
||||
log.Errorf("failed to write to QUIC stream: %v", err)
|
||||
return 0, err
|
||||
}
|
||||
return len(b), nil
|
||||
}
|
||||
@@ -91,15 +95,3 @@ func (c *Conn) remoteCloseErrHandling(err error) error {
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// writeErrHandling normalizes SendDatagram errors. A datagram that exceeds the
|
||||
// path's QUIC packet budget is mapped to ErrDatagramTooLarge (annotated with the
|
||||
// datagram size and path budget) so the relay client can fall back to a
|
||||
// non-datagram transport.
|
||||
func (c *Conn) writeErrHandling(err error, size int) error {
|
||||
var tooLarge *quic.DatagramTooLargeError
|
||||
if errors.As(err, &tooLarge) {
|
||||
return fmt.Errorf("%w: %d byte datagram over path budget %d", netErr.ErrDatagramTooLarge, size, tooLarge.MaxDatagramPayloadSize)
|
||||
}
|
||||
return c.remoteCloseErrHandling(err)
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
"github.com/quic-go/quic-go/logging"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
@@ -24,12 +23,6 @@ func (d Dialer) Protocol() string {
|
||||
return Network
|
||||
}
|
||||
|
||||
// DatagramSized marks QUIC as a datagram-sized transport: relay traffic is
|
||||
// carried in QUIC DATAGRAM frames, which must fit a single packet.
|
||||
func (d Dialer) DatagramSized() {
|
||||
// Intentional marker method; presence is the capability signal.
|
||||
}
|
||||
|
||||
func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn, error) {
|
||||
quicURL, err := prepareURL(address)
|
||||
if err != nil {
|
||||
@@ -54,7 +47,6 @@ func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn,
|
||||
MaxIdleTimeout: 4 * time.Minute,
|
||||
EnableDatagrams: true,
|
||||
InitialPacketSize: nbRelay.QUICInitialPacketSize,
|
||||
Tracer: connectionTracer(quicURL),
|
||||
}
|
||||
|
||||
udpConn, err := nbnet.ListenUDP("udp", &net.UDPAddr{Port: 0})
|
||||
@@ -82,28 +74,6 @@ func (d Dialer) Dial(ctx context.Context, address, serverName string) (net.Conn,
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// connectionTracer returns a QUIC tracer that logs the DPLPMTUD result and the
|
||||
// reason a relay connection closed, so the path MTU settled on and teardown
|
||||
// cause are visible in logs. Lines carry the relay address as a structured
|
||||
// field, matching the rest of the relay client logging.
|
||||
func connectionTracer(addr string) func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
|
||||
relayLog := log.WithField("relay", addr)
|
||||
return func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
|
||||
return &logging.ConnectionTracer{
|
||||
UpdatedMTU: func(mtu logging.ByteCount, done bool) {
|
||||
if done {
|
||||
relayLog.Infof("QUIC path MTU settled at %d", mtu)
|
||||
return
|
||||
}
|
||||
relayLog.Debugf("QUIC path MTU probing at %d", mtu)
|
||||
},
|
||||
ClosedConnection: func(err error) {
|
||||
relayLog.Debugf("QUIC connection closed: %v", err)
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func prepareURL(address string) (string, error) {
|
||||
var host string
|
||||
var defaultPort string
|
||||
|
||||
@@ -32,7 +32,6 @@ type RaceDial struct {
|
||||
serverName string
|
||||
dialerFns []DialeFn
|
||||
connectionTimeout time.Duration
|
||||
sequential bool
|
||||
}
|
||||
|
||||
func NewRaceDial(log *log.Entry, connectionTimeout time.Duration, serverURL string, dialerFns ...DialeFn) *RaceDial {
|
||||
@@ -54,21 +53,7 @@ func (r *RaceDial) WithServerName(serverName string) *RaceDial {
|
||||
return r
|
||||
}
|
||||
|
||||
// WithSequential makes Dial try the dialers in order, falling back to the next
|
||||
// only when one fails to connect, instead of racing them concurrently.
|
||||
//
|
||||
// Mutates the receiver and is not safe for concurrent reconfiguration; a
|
||||
// RaceDial is intended to be constructed per dial and discarded.
|
||||
func (r *RaceDial) WithSequential() *RaceDial {
|
||||
r.sequential = true
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *RaceDial) Dial(ctx context.Context) (net.Conn, error) {
|
||||
if r.sequential {
|
||||
return r.dialSequential(ctx)
|
||||
}
|
||||
|
||||
connChan := make(chan dialResult, len(r.dialerFns))
|
||||
winnerConn := make(chan net.Conn, 1)
|
||||
abortCtx, abort := context.WithCancel(ctx)
|
||||
@@ -87,30 +72,6 @@ func (r *RaceDial) Dial(ctx context.Context) (net.Conn, error) {
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// dialSequential tries each dialer in order, returning the first connection and
|
||||
// falling back to the next on failure.
|
||||
func (r *RaceDial) dialSequential(ctx context.Context) (net.Conn, error) {
|
||||
for _, dfn := range r.dialerFns {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
attemptCtx, cancel := context.WithTimeout(ctx, r.connectionTimeout)
|
||||
r.log.Infof("dialing Relay server via %s", dfn.Protocol())
|
||||
conn, err := dfn.Dial(attemptCtx, r.serverURL, r.serverName)
|
||||
cancel()
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, err
|
||||
}
|
||||
r.log.Errorf("failed to dial via %s: %s", dfn.Protocol(), err)
|
||||
continue
|
||||
}
|
||||
r.log.Infof("successfully dialed via: %s", dfn.Protocol())
|
||||
return conn, nil
|
||||
}
|
||||
return nil, errors.New("failed to dial to Relay server on any protocol")
|
||||
}
|
||||
|
||||
func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dialResult) {
|
||||
ctx, cancel := context.WithTimeout(abortCtx, r.connectionTimeout)
|
||||
defer cancel()
|
||||
|
||||
@@ -250,66 +250,3 @@ func TestRaceDialFirstSuccessfulDialerWins(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRaceDialSequentialFallback(t *testing.T) {
|
||||
logger := logrus.NewEntry(logrus.New())
|
||||
serverURL := "test.server.com"
|
||||
|
||||
var firstDialed, secondDialed bool
|
||||
preferred := &MockDialer{
|
||||
protocolStr: "quic",
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
firstDialed = true
|
||||
return nil, errors.New("quic unreachable")
|
||||
},
|
||||
}
|
||||
fallbackConn := &MockConn{remoteAddr: &MockAddr{network: "ws"}}
|
||||
fallback := &MockDialer{
|
||||
protocolStr: "ws",
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
secondDialed = true
|
||||
return fallbackConn, nil
|
||||
},
|
||||
}
|
||||
|
||||
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, preferred, fallback).WithSequential()
|
||||
conn, err := rd.Dial(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("expected fallback to succeed, got %v", err)
|
||||
}
|
||||
if conn != fallbackConn {
|
||||
t.Errorf("expected fallback connection, got %v", conn)
|
||||
}
|
||||
if !firstDialed || !secondDialed {
|
||||
t.Errorf("expected both dialers attempted in order, first=%v second=%v", firstDialed, secondDialed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRaceDialSequentialPreferredWins(t *testing.T) {
|
||||
logger := logrus.NewEntry(logrus.New())
|
||||
serverURL := "test.server.com"
|
||||
|
||||
preferredConn := &MockConn{remoteAddr: &MockAddr{network: "quic"}}
|
||||
preferred := &MockDialer{
|
||||
protocolStr: "quic",
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
return preferredConn, nil
|
||||
},
|
||||
}
|
||||
fallback := &MockDialer{
|
||||
protocolStr: "ws",
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
t.Errorf("fallback dialer must not be tried when preferred succeeds")
|
||||
return nil, errors.New("should not happen")
|
||||
},
|
||||
}
|
||||
|
||||
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, preferred, fallback).WithSequential()
|
||||
conn, err := rd.Dial(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("expected preferred to succeed, got %v", err)
|
||||
}
|
||||
if conn != preferredConn {
|
||||
t.Errorf("expected preferred connection, got %v", conn)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,42 +9,11 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer/ws"
|
||||
)
|
||||
|
||||
// getDialers returns the ordered dialers for connecting to the relay server. It
|
||||
// applies the datagram fallback generically: if this server recently rejected a
|
||||
// datagram-sized transport, those dialers are dropped, leaving the rest.
|
||||
func (c *Client) getDialers(mode TransportMode) []dialer.DialeFn {
|
||||
dialers := c.baseDialers(mode)
|
||||
|
||||
if c.transportFallback != nil && c.transportFallback.avoidDatagramSized(c.connectionURL) {
|
||||
if filtered := nonDatagramSized(dialers); len(filtered) > 0 {
|
||||
c.log.Infof("relay recently rejected a datagram-sized transport, avoiding it")
|
||||
return filtered
|
||||
}
|
||||
}
|
||||
return dialers
|
||||
}
|
||||
|
||||
// baseDialers returns the ordered dialers for the mode, before any datagram
|
||||
// fallback filtering. For racing modes (auto) the order is irrelevant; for
|
||||
// prefer modes the first entry is tried before falling back to the second.
|
||||
func (c *Client) baseDialers(mode TransportMode) []dialer.DialeFn {
|
||||
switch mode {
|
||||
case TransportModeWS:
|
||||
c.log.Infof("%s=ws, using WebSocket transport", EnvRelayTransport)
|
||||
return []dialer.DialeFn{ws.Dialer{}}
|
||||
case TransportModeQUIC:
|
||||
c.log.Infof("%s=quic, using QUIC transport", EnvRelayTransport)
|
||||
return []dialer.DialeFn{quic.Dialer{}}
|
||||
}
|
||||
|
||||
all := []dialer.DialeFn{quic.Dialer{}, ws.Dialer{}}
|
||||
if mode == TransportModePreferWS {
|
||||
all = []dialer.DialeFn{ws.Dialer{}, quic.Dialer{}}
|
||||
}
|
||||
|
||||
// getDialers returns the list of dialers to use for connecting to the relay server.
|
||||
func (c *Client) getDialers() []dialer.DialeFn {
|
||||
if c.mtu > 0 && c.mtu > iface.DefaultMTU {
|
||||
c.log.Infof("MTU %d exceeds default (%d), avoiding datagram-sized transports", c.mtu, iface.DefaultMTU)
|
||||
return nonDatagramSized(all)
|
||||
c.log.Infof("MTU %d exceeds default (%d), forcing WebSocket transport to avoid DATAGRAM frame size issues", c.mtu, iface.DefaultMTU)
|
||||
return []dialer.DialeFn{ws.Dialer{}}
|
||||
}
|
||||
return all
|
||||
return []dialer.DialeFn{quic.Dialer{}, ws.Dialer{}}
|
||||
}
|
||||
|
||||
@@ -1,101 +0,0 @@
|
||||
//go:build !js
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer"
|
||||
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer/quic"
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer/ws"
|
||||
)
|
||||
|
||||
// TestDatagramSizedCapability locks the capability the generic fallback relies
|
||||
// on: QUIC is datagram-sized, WebSocket is not.
|
||||
func TestDatagramSizedCapability(t *testing.T) {
|
||||
assert.True(t, dialer.IsDatagramSized(quic.Dialer{}), "QUIC must advertise datagram-sized")
|
||||
assert.False(t, dialer.IsDatagramSized(ws.Dialer{}), "WebSocket must not advertise datagram-sized")
|
||||
}
|
||||
|
||||
func protocols(dialers []dialer.DialeFn) []string {
|
||||
out := make([]string, len(dialers))
|
||||
for i, d := range dialers {
|
||||
out[i] = d.Protocol()
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func TestGetDialers(t *testing.T) {
|
||||
const url = "rels://relay.example:443"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mode string
|
||||
mtu uint16
|
||||
preferWS bool
|
||||
want []string
|
||||
}{
|
||||
{name: "auto races quic and ws", mode: "auto", mtu: iface.DefaultMTU, want: []string{"quic", "WS"}},
|
||||
{name: "ws pinned", mode: "ws", mtu: iface.DefaultMTU, want: []string{"WS"}},
|
||||
{name: "quic pinned", mode: "quic", mtu: iface.DefaultMTU, want: []string{"quic"}},
|
||||
{name: "prefer-quic orders quic first", mode: "prefer-quic", mtu: iface.DefaultMTU, want: []string{"quic", "WS"}},
|
||||
{name: "prefer-ws orders ws first", mode: "prefer-ws", mtu: iface.DefaultMTU, want: []string{"WS", "quic"}},
|
||||
{name: "mtu above default forces ws", mode: "auto", mtu: iface.DefaultMTU + 100, want: []string{"WS"}},
|
||||
{name: "sticky fallback forces ws in auto", mode: "auto", mtu: iface.DefaultMTU, preferWS: true, want: []string{"WS"}},
|
||||
{name: "sticky fallback forces ws in prefer-quic", mode: "prefer-quic", mtu: iface.DefaultMTU, preferWS: true, want: []string{"WS"}},
|
||||
{name: "quic pin overrides sticky fallback", mode: "quic", mtu: iface.DefaultMTU, preferWS: true, want: []string{"quic"}},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Setenv(EnvRelayTransport, tc.mode)
|
||||
if tc.mode == "" {
|
||||
os.Unsetenv(EnvRelayTransport)
|
||||
}
|
||||
|
||||
tf := newTransportFallback()
|
||||
if tc.preferWS {
|
||||
tf.recordFailure(url)
|
||||
}
|
||||
|
||||
c := &Client{
|
||||
log: log.WithField("test", t.Name()),
|
||||
connectionURL: url,
|
||||
mtu: tc.mtu,
|
||||
transportFallback: tf,
|
||||
}
|
||||
|
||||
assert.Equal(t, tc.want, protocols(c.getDialers(transportModeFromEnv())))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestStickyFallbackAfterDatagramTooLarge verifies the full chain: an oversized
|
||||
// datagram records a fallback that makes the next dial pick WebSocket, the way a
|
||||
// reconnect would after the connection is closed.
|
||||
func TestStickyFallbackAfterDatagramTooLarge(t *testing.T) {
|
||||
const url = "rels://relay.example:443"
|
||||
t.Setenv(EnvRelayTransport, string(TransportModeAuto))
|
||||
|
||||
c := &Client{
|
||||
log: log.WithField("test", t.Name()),
|
||||
connectionURL: url,
|
||||
mtu: iface.DefaultMTU,
|
||||
transportFallback: newTransportFallback(),
|
||||
}
|
||||
|
||||
// First dial races both transports.
|
||||
assert.Equal(t, []string{"quic", "WS"}, protocols(c.getDialers(transportModeFromEnv())))
|
||||
|
||||
// An oversized datagram records the fallback for this server.
|
||||
c.onDatagramTooLarge(&closeTrackingConn{}, netErr.ErrDatagramTooLarge)
|
||||
|
||||
// The reconnect now sticks to WebSocket.
|
||||
assert.Equal(t, []string{"WS"}, protocols(c.getDialers(transportModeFromEnv())))
|
||||
}
|
||||
@@ -7,11 +7,7 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer/ws"
|
||||
)
|
||||
|
||||
func (c *Client) getDialers(_ TransportMode) []dialer.DialeFn {
|
||||
func (c *Client) getDialers() []dialer.DialeFn {
|
||||
// JS/WASM build only uses WebSocket transport
|
||||
return []dialer.DialeFn{ws.Dialer{}}
|
||||
}
|
||||
|
||||
func (c *Client) baseDialers(_ TransportMode) []dialer.DialeFn {
|
||||
return []dialer.DialeFn{ws.Dialer{}}
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user