mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-16 12:59:55 +00:00
Compare commits
6 Commits
prepare-fo
...
mdm_integr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b2c5732847 | ||
|
|
0340893854 | ||
|
|
874195440c | ||
|
|
bec26d5a14 | ||
|
|
db2c9b6f49 | ||
|
|
cd777395f2 |
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/mdm"
|
||||
"github.com/netbirdio/netbird/client/net"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
@@ -75,6 +76,13 @@ type Client struct {
|
||||
connectClient *internal.ConnectClient
|
||||
config *profilemanager.Config
|
||||
cacheDir string
|
||||
|
||||
// mdmLoader holds the per-Client MDM policy source. Set by
|
||||
// SetMDMPolicyFetcher (called from the Kotlin side). Each Run
|
||||
// passes this loader to the resolved Config so applyMDMPolicy
|
||||
// picks up the active overlay. Nil means "MDM enforcement off
|
||||
// for this Client".
|
||||
mdmLoader *mdm.Loader
|
||||
}
|
||||
|
||||
func (c *Client) setState(cfg *profilemanager.Config, cacheDir string, cc *internal.ConnectClient) {
|
||||
@@ -129,6 +137,7 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.applyMDMOverlay(cfg)
|
||||
c.recorder.UpdateManagementAddress(cfg.ManagementURL.String())
|
||||
c.recorder.UpdateRosenpass(cfg.RosenpassEnabled, cfg.RosenpassPermissive)
|
||||
|
||||
@@ -173,6 +182,7 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.applyMDMOverlay(cfg)
|
||||
c.recorder.UpdateManagementAddress(cfg.ManagementURL.String())
|
||||
c.recorder.UpdateRosenpass(cfg.RosenpassEnabled, cfg.RosenpassPermissive)
|
||||
|
||||
@@ -230,6 +240,7 @@ func (c *Client) DebugBundle(platformFiles PlatformFiles, anonymize bool) (strin
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
c.applyMDMOverlay(cfg)
|
||||
cacheDir = platformFiles.CacheDir()
|
||||
}
|
||||
|
||||
|
||||
80
client/android/mdm.go
Normal file
80
client/android/mdm.go
Normal file
@@ -0,0 +1,80 @@
|
||||
//go:build android
|
||||
|
||||
package android
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/mdm"
|
||||
)
|
||||
|
||||
// PolicyFetcher is the mobile-side bridge for the MDM managed-config
|
||||
// snapshot. The native layer (Kotlin) implements this and registers
|
||||
// the instance per Client via Client.SetMDMPolicyFetcher. Every
|
||||
// invocation of fetchJSON must read the current RestrictionsManager
|
||||
// state and return the result as a JSON-encoded map[string]any string.
|
||||
//
|
||||
// JSON is used because gomobile does not support map[string]any
|
||||
// crossing the JNI boundary — the adapter on the Go side parses the
|
||||
// string back into the map[string]any expected by mdm.Loader.
|
||||
//
|
||||
// Return value contract:
|
||||
// - "" (empty) : interpreted as "no MDM source / no managed keys"
|
||||
// - "{}" : managed config explicitly empty
|
||||
// - "{...}" : JSON object with key/value pairs
|
||||
// - malformed JSON : logged and treated as empty
|
||||
type PolicyFetcher interface {
|
||||
FetchJSON() string
|
||||
}
|
||||
|
||||
// jsonFetcherAdapter wraps a gomobile-exposed PolicyFetcher into the
|
||||
// internal mdm.PolicyFetcher interface, taking care of JSON decoding
|
||||
// on every Fetch.
|
||||
type jsonFetcherAdapter struct {
|
||||
inner PolicyFetcher
|
||||
}
|
||||
|
||||
func (a *jsonFetcherAdapter) Fetch() map[string]any {
|
||||
raw := a.inner.FetchJSON()
|
||||
if raw == "" {
|
||||
return nil
|
||||
}
|
||||
var out map[string]any
|
||||
if err := json.Unmarshal([]byte(raw), &out); err != nil {
|
||||
log.Warnf("MDM mobile fetcher: invalid JSON payload from native: %v", err)
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// SetMDMPolicyFetcher registers the native-provided MDM policy fetcher
|
||||
// on this Client. Call once from the gomobile-init code (Kotlin
|
||||
// Application.onCreate or Service onCreate) before invoking Run /
|
||||
// RunWithoutLogin. Passing nil disables MDM enforcement on this
|
||||
// Client.
|
||||
//
|
||||
// The fetcher is held as a *mdm.Loader instance on the Client (no
|
||||
// package-level state) — multiple Clients in the same process get
|
||||
// independent Loaders, and tests can inject fakes per Client.
|
||||
func (c *Client) SetMDMPolicyFetcher(p PolicyFetcher) {
|
||||
if p == nil {
|
||||
c.mdmLoader = mdm.NewLoader(nil)
|
||||
return
|
||||
}
|
||||
c.mdmLoader = mdm.NewLoader(&jsonFetcherAdapter{inner: p})
|
||||
}
|
||||
|
||||
// applyMDMOverlay applies the Client-held MDM Loader's current policy
|
||||
// on top of the just-read Config. Called immediately after every
|
||||
// UpdateOrCreateConfig — profilemanager's apply() initialises the
|
||||
// policy to empty and leaves overlay responsibility to the lifecycle
|
||||
// owner. No-op when no fetcher was registered.
|
||||
func (c *Client) applyMDMOverlay(cfg *profilemanager.Config) {
|
||||
if cfg == nil || c.mdmLoader == nil {
|
||||
return
|
||||
}
|
||||
cfg.ApplyMDMPolicy(c.mdmLoader.Load())
|
||||
}
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/mdm"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
@@ -248,6 +249,11 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
|
||||
if err != nil {
|
||||
return fmt.Errorf("read config file %s: %v", configFilePath, err)
|
||||
}
|
||||
// CLI standalone login: profilemanager no longer auto-applies MDM,
|
||||
// so layer in the OS-native policy here. Desktop builds construct
|
||||
// a Loader with no fetcher — the build-tagged loadPlatform reads
|
||||
// the registry/plist directly.
|
||||
config.ApplyMDMPolicy(mdm.NewLoader(nil).Load())
|
||||
|
||||
err = foregroundLogin(ctx, cmd, config, setupKey, activeProf.Name)
|
||||
if err != nil {
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/mdm"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
@@ -187,6 +188,10 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
|
||||
if err != nil {
|
||||
return fmt.Errorf("get config file: %v", err)
|
||||
}
|
||||
// CLI foreground path runs without the daemon Server: layer in the
|
||||
// active MDM policy explicitly so a forced ManagementURL / PSK /
|
||||
// other managed key actually takes effect on this run.
|
||||
config.ApplyMDMPolicy(mdm.NewLoader(nil).Load())
|
||||
|
||||
_, _ = profilemanager.UpdateOldManagementURL(ctx, config, configFilePath)
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/mdm"
|
||||
sshcommon "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
@@ -215,6 +216,10 @@ func New(opts Options) (*Client, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create config: %w", err)
|
||||
}
|
||||
// Embedded path runs without the daemon Server: apply the active
|
||||
// MDM policy explicitly so a forced ManagementURL / PSK / other
|
||||
// managed key takes effect on this embedded engine instance.
|
||||
config.ApplyMDMPolicy(mdm.NewLoader(nil).Load())
|
||||
|
||||
if opts.PrivateKey != "" {
|
||||
config.PrivateKey = opts.PrivateKey
|
||||
|
||||
@@ -58,10 +58,6 @@ var DefaultInterfaceBlacklist = []string{
|
||||
"Tailscale", "tailscale", "docker", "veth", "br-", "lo",
|
||||
}
|
||||
|
||||
// loadMDMPolicy is the package-level indirection used by apply() to read the
|
||||
// active MDM policy. Tests override this to inject a fake policy.
|
||||
var loadMDMPolicy = mdm.LoadPolicy
|
||||
|
||||
// ConfigInput carries configuration changes to the client
|
||||
type ConfigInput struct {
|
||||
ManagementURL string
|
||||
@@ -180,14 +176,27 @@ type Config struct {
|
||||
|
||||
MTU uint16
|
||||
|
||||
// policy is the MDM policy that produced the currently-set values for
|
||||
// any MDM-enforced fields. Set by applyMDMPolicy at the tail of apply()
|
||||
// and reset on every apply() invocation. Never persisted to disk.
|
||||
// Callers query enforcement state via Policy() and the mdm.Policy API
|
||||
// (HasKey, ManagedKeys, IsEmpty).
|
||||
// policy is the MDM policy that produced the currently-set values
|
||||
// for any MDM-enforced fields. Set by ApplyMDMPolicy on every
|
||||
// invocation. Never persisted to disk. Callers query enforcement
|
||||
// state via Policy() and the mdm.Policy API (HasKey, ManagedKeys,
|
||||
// IsEmpty).
|
||||
policy *mdm.Policy `json:"-"`
|
||||
}
|
||||
|
||||
// ApplyMDMPolicy overlays the supplied MDM Policy on top of the
|
||||
// currently resolved Config values. Idempotent — pass an empty Policy
|
||||
// to clear any prior overlay. The lifecycle owner (Server.getConfig
|
||||
// on desktop, the Client.Run path on mobile) calls this with
|
||||
// loader.Load() once the per-process Loader is known; the Config
|
||||
// itself holds no reference to the Loader.
|
||||
func (config *Config) ApplyMDMPolicy(policy *mdm.Policy) {
|
||||
if config == nil {
|
||||
return
|
||||
}
|
||||
config.applyMDMPolicy(policy)
|
||||
}
|
||||
|
||||
// Policy returns the MDM policy applied to this Config. Returns a non-nil
|
||||
// empty Policy when MDM enforcement is inactive; callers can always invoke
|
||||
// HasKey / ManagedKeys / IsEmpty without a nil check.
|
||||
@@ -634,9 +643,11 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
// MDM is the last override layer: any key present in the policy
|
||||
// supersedes defaults, on-disk config, env vars and CLI input.
|
||||
config.applyMDMPolicy(loadMDMPolicy())
|
||||
// Initialise the MDM overlay to "no enforcement" so Config.Policy()
|
||||
// never returns a stale or nil policy on a freshly applied Config.
|
||||
// Lifecycle owners that want to enforce a real MDM policy invoke
|
||||
// Config.ApplyMDMPolicy(loader.Load()) after this returns.
|
||||
config.applyMDMPolicy(mdm.NewPolicy(nil))
|
||||
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
@@ -10,24 +10,58 @@ import (
|
||||
"github.com/netbirdio/netbird/client/mdm"
|
||||
)
|
||||
|
||||
// withMDMPolicy temporarily overrides the package-level loadMDMPolicy hook so
|
||||
// apply() observes the supplied Policy. The original loader is restored at
|
||||
// test cleanup.
|
||||
func withMDMPolicy(t *testing.T, policy *mdm.Policy) {
|
||||
// fakeFetcher implements mdm.PolicyFetcher returning a pre-set policy
|
||||
// map. Test helper used to construct a Loader without touching the OS
|
||||
// or any package-level state.
|
||||
type fakeFetcher struct{ values map[string]any }
|
||||
|
||||
func (f *fakeFetcher) Fetch() map[string]any { return f.values }
|
||||
|
||||
// loaderFor builds an mdm.Loader whose loadPlatform returns the
|
||||
// supplied Policy's underlying values.
|
||||
func loaderFor(policy *mdm.Policy) *mdm.Loader {
|
||||
if policy == nil || policy.IsEmpty() {
|
||||
return mdm.NewLoader(&fakeFetcher{values: nil})
|
||||
}
|
||||
values := make(map[string]any)
|
||||
for _, k := range policy.ManagedKeys() {
|
||||
if v, ok := policy.GetString(k); ok {
|
||||
values[k] = v
|
||||
continue
|
||||
}
|
||||
if v, ok := policy.GetBool(k); ok {
|
||||
values[k] = v
|
||||
continue
|
||||
}
|
||||
if v, ok := policy.GetInt(k); ok {
|
||||
values[k] = v
|
||||
continue
|
||||
}
|
||||
if v, ok := policy.GetStringSlice(k); ok {
|
||||
values[k] = v
|
||||
}
|
||||
}
|
||||
return mdm.NewLoader(&fakeFetcher{values: values})
|
||||
}
|
||||
|
||||
// configWithMDM is the test convenience that builds a Config via
|
||||
// UpdateOrCreateConfig and overlays the supplied MDM policy on top —
|
||||
// mirrors the production pattern (Server.getConfig / Client.applyMDMOverlay)
|
||||
// where the Loader lives outside Config and the apply step is driven
|
||||
// by the lifecycle owner.
|
||||
func configWithMDM(t *testing.T, input ConfigInput, policy *mdm.Policy) *Config {
|
||||
t.Helper()
|
||||
prev := loadMDMPolicy
|
||||
loadMDMPolicy = func() *mdm.Policy { return policy }
|
||||
t.Cleanup(func() { loadMDMPolicy = prev })
|
||||
cfg, err := UpdateOrCreateConfig(input)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
cfg.ApplyMDMPolicy(loaderFor(policy).Load())
|
||||
return cfg
|
||||
}
|
||||
|
||||
func TestApply_MDMEmpty_NoEnforcement(t *testing.T) {
|
||||
withMDMPolicy(t, mdm.NewPolicy(nil))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{
|
||||
cfg := configWithMDM(t, ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
}, mdm.NewPolicy(nil))
|
||||
|
||||
assert.True(t, cfg.Policy().IsEmpty(), "no MDM source ⇒ empty Policy")
|
||||
assert.False(t, cfg.Policy().HasKey(mdm.KeyManagementURL))
|
||||
@@ -39,18 +73,15 @@ func TestApply_MDMEmpty_NoEnforcement(t *testing.T) {
|
||||
|
||||
func TestApply_MDMOnly_OverridesDefaults(t *testing.T) {
|
||||
const mdmURL = "https://corp.mdm.example.com:443"
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
|
||||
cfg := configWithMDM(t, ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
}, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: mdmURL,
|
||||
mdm.KeyDisableClientRoutes: true,
|
||||
mdm.KeyBlockInbound: true,
|
||||
}))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
assert.Equal(t, mdmURL, cfg.ManagementURL.String())
|
||||
assert.True(t, cfg.DisableClientRoutes)
|
||||
assert.True(t, cfg.BlockInbound)
|
||||
@@ -65,16 +96,12 @@ func TestApply_MDMBeatsCLIInput(t *testing.T) {
|
||||
const mdmURL = "https://mdm.example.com:443"
|
||||
const cliURL = "https://cli.example.com:443"
|
||||
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: mdmURL,
|
||||
}))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{
|
||||
cfg := configWithMDM(t, ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
ManagementURL: cliURL,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
}, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: mdmURL,
|
||||
}))
|
||||
|
||||
// MDM wins over CLI-supplied management URL.
|
||||
assert.Equal(t, mdmURL, cfg.ManagementURL.String())
|
||||
@@ -82,16 +109,12 @@ func TestApply_MDMBeatsCLIInput(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestApply_MDMInvalidURL_KeepsPreviousValue(t *testing.T) {
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
cfg := configWithMDM(t, ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
}, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: "not-a-url",
|
||||
}))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
// Invalid MDM URL is logged and skipped: default URL stays in place
|
||||
// to keep the client functional.
|
||||
assert.Equal(t, DefaultManagementURL, cfg.ManagementURL.String())
|
||||
@@ -106,24 +129,20 @@ func TestApply_MDMBoolKeysOverrideOnDiskValue(t *testing.T) {
|
||||
tmp := filepath.Join(t.TempDir(), "config.json")
|
||||
|
||||
// Seed without MDM.
|
||||
withMDMPolicy(t, mdm.NewPolicy(nil))
|
||||
_, err := UpdateOrCreateConfig(ConfigInput{
|
||||
configWithMDM(t, ConfigInput{
|
||||
ConfigPath: tmp,
|
||||
DisableClientRoutes: boolPtr(false),
|
||||
RosenpassEnabled: boolPtr(false),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}, mdm.NewPolicy(nil))
|
||||
|
||||
// Now enable MDM enforcement for these keys.
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
cfg := configWithMDM(t, ConfigInput{
|
||||
ConfigPath: tmp,
|
||||
}, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyDisableClientRoutes: true,
|
||||
mdm.KeyRosenpassEnabled: true,
|
||||
}))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{ConfigPath: tmp})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
assert.True(t, cfg.DisableClientRoutes, "MDM override should flip on-disk false to true")
|
||||
assert.True(t, cfg.RosenpassEnabled)
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyDisableClientRoutes))
|
||||
@@ -133,16 +152,12 @@ func TestApply_MDMBoolKeysOverrideOnDiskValue(t *testing.T) {
|
||||
func TestApply_MDMPreSharedKeyRedactionSentinelRejected(t *testing.T) {
|
||||
const maskSentinel = "**********"
|
||||
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
cfg := configWithMDM(t, ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
}, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyPreSharedKey: maskSentinel,
|
||||
}))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
// Mask sentinel must not be persisted as the actual PSK.
|
||||
assert.NotEqual(t, maskSentinel, cfg.PreSharedKey)
|
||||
// Key still marked managed so user writes are still rejected.
|
||||
|
||||
@@ -84,16 +84,48 @@ func NewPolicy(values map[string]any) *Policy {
|
||||
return &Policy{values: values}
|
||||
}
|
||||
|
||||
// LoadPolicy reads the platform-native MDM configuration. Returns an
|
||||
// empty (but non-nil) Policy when no source is present, the source is
|
||||
// empty, or the platform is unsupported.
|
||||
// PolicyFetcher is implemented by mobile platforms (Android / iOS) that
|
||||
// push the OS-managed configuration into the Go runtime instead of
|
||||
// having Go read an on-disk source directly. Desktop platforms ignore
|
||||
// this interface — Loader.loadPlatform on windows/darwin reads the
|
||||
// registry / plist on its own. A Loader constructed with a non-nil
|
||||
// fetcher delegates to it on mobile; passing nil disables MDM
|
||||
// enforcement (loadPlatform returns nil values).
|
||||
type PolicyFetcher interface {
|
||||
Fetch() map[string]any
|
||||
}
|
||||
|
||||
// Loader is the DI-friendly entry point for reading the active MDM
|
||||
// policy. Construct one at the daemon's lifecycle owner (Server on
|
||||
// desktop, gomobile-exposed bridge on mobile) and pass it to anything
|
||||
// that needs to read MDM state (the reload ticker, profilemanager's
|
||||
// Config). Each callsite has the Loader handed in instead of looking
|
||||
// up package-level state.
|
||||
type Loader struct {
|
||||
fetcher PolicyFetcher
|
||||
}
|
||||
|
||||
// NewLoader constructs a Loader. The fetcher is consulted only on
|
||||
// mobile builds (ios || android); on desktop it is unused but accepted
|
||||
// to keep a single constructor signature across platforms — pass nil
|
||||
// on desktop.
|
||||
func NewLoader(f PolicyFetcher) *Loader {
|
||||
return &Loader{fetcher: f}
|
||||
}
|
||||
|
||||
// Load reads the platform-native MDM configuration and returns a
|
||||
// Policy. Returns an empty (but non-nil) Policy when no source is
|
||||
// present, the source is empty, or the platform is unsupported.
|
||||
//
|
||||
// Diagnostic logging differentiates the three states:
|
||||
// - source absent / unsupported platform: trace log only
|
||||
// - source present, zero keys: info "MDM enrolled (no managed keys)"
|
||||
// - source present, N keys: info "MDM enrolled with N managed keys: [...]"
|
||||
func LoadPolicy() *Policy {
|
||||
values, err := loadPlatformPolicy()
|
||||
func (l *Loader) Load() *Policy {
|
||||
if l == nil {
|
||||
return &Policy{values: map[string]any{}}
|
||||
}
|
||||
values, err := l.loadPlatform()
|
||||
if err != nil {
|
||||
log.Tracef("MDM policy load: %v", err)
|
||||
return &Policy{values: map[string]any{}}
|
||||
|
||||
@@ -25,8 +25,10 @@ import (
|
||||
// writable plist, as a defense against tampered installs.
|
||||
const policyPlistPath = "/Library/Managed Preferences/io.netbird.client.plist"
|
||||
|
||||
// loadPlatformPolicy reads the MDM-managed configuration from the macOS
|
||||
// managed-preferences plist at policyPlistPath. Returns:
|
||||
// loadPlatform reads the MDM-managed configuration from the macOS
|
||||
// managed-preferences plist at policyPlistPath. The Loader's fetcher
|
||||
// field is unused on this platform — the plist is the authoritative
|
||||
// source. Returns:
|
||||
// - (nil, nil) when the plist is absent (device not MDM-enrolled for
|
||||
// NetBird, or admin has not yet pushed a payload)
|
||||
// - (map, nil) with N entries when N managed values are present
|
||||
@@ -39,7 +41,13 @@ const policyPlistPath = "/Library/Managed Preferences/io.netbird.client.plist"
|
||||
// skipped so a stray entry in the payload does not block startup.
|
||||
// Native plist value types map naturally onto the Policy accessor
|
||||
// expectations (GetString / GetBool / GetInt / GetStringSlice).
|
||||
func loadPlatformPolicy() (map[string]any, error) {
|
||||
func (l *Loader) loadPlatform() (map[string]any, error) {
|
||||
// Honour the injected fetcher when present so tests (and any
|
||||
// future non-macOS MDM channel) can short-circuit the plist read
|
||||
// with a scripted policy.
|
||||
if l != nil && l.fetcher != nil {
|
||||
return l.fetcher.Fetch(), nil
|
||||
}
|
||||
f, err := os.Open(policyPlistPath)
|
||||
if err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
|
||||
@@ -2,13 +2,14 @@
|
||||
|
||||
package mdm
|
||||
|
||||
// loadPlatformPolicy is unused on mobile: the native layer (Swift on iOS,
|
||||
// Kotlin/Java on Android) reads the OS managed-config store and pushes the
|
||||
// resulting dictionary in-process via a gomobile entry point that lands in
|
||||
// Phase 5 / Phase 6. The stub keeps the package compilable for mobile
|
||||
// builds and returns (nil, nil) — the platform-absent sentinel that
|
||||
// LoadPolicy in policy.go treats as "no MDM source present".
|
||||
func loadPlatformPolicy() (map[string]any, error) {
|
||||
//nolint:nilnil // (nil, nil) is the documented platform-absent sentinel; see LoadPolicy.
|
||||
return nil, nil
|
||||
// loadPlatform reads the OS-managed configuration via the native
|
||||
// PolicyFetcher injected at Loader construction. Returns
|
||||
// (nil, nil) — the platform-absent sentinel that Loader.Load treats as
|
||||
// "no MDM source present" — when no fetcher was provided.
|
||||
func (l *Loader) loadPlatform() (map[string]any, error) {
|
||||
if l == nil || l.fetcher == nil {
|
||||
//nolint:nilnil // (nil, nil) is the documented platform-absent sentinel; see Loader.Load.
|
||||
return nil, nil
|
||||
}
|
||||
return l.fetcher.Fetch(), nil
|
||||
}
|
||||
|
||||
@@ -2,13 +2,17 @@
|
||||
|
||||
package mdm
|
||||
|
||||
// loadPlatformPolicy returns no policy on platforms without an MDM channel
|
||||
// (Linux, FreeBSD). MDM enforcement is off and the client behaves as if
|
||||
// the feature did not exist. Returns (nil, nil) — the platform-absent
|
||||
// sentinel the caller (LoadPolicy in policy.go) treats as "no MDM
|
||||
// source present"; an error here would just translate to the same
|
||||
// outcome with an extra log line.
|
||||
func loadPlatformPolicy() (map[string]any, error) {
|
||||
//nolint:nilnil // (nil, nil) is the documented platform-absent sentinel; see LoadPolicy.
|
||||
// loadPlatform reads the MDM policy on platforms without a native MDM
|
||||
// channel (Linux, FreeBSD). When no fetcher was injected the policy is
|
||||
// (nil, nil) — the platform-absent sentinel that Loader.Load treats as
|
||||
// "MDM enforcement disabled". A non-nil fetcher takes precedence: it
|
||||
// is the test-seam used by unit tests to inject a scripted policy
|
||||
// without touching the OS, and the same hook supports any future
|
||||
// non-mobile OS that grows an out-of-band MDM channel.
|
||||
func (l *Loader) loadPlatform() (map[string]any, error) {
|
||||
if l != nil && l.fetcher != nil {
|
||||
return l.fetcher.Fetch(), nil
|
||||
}
|
||||
//nolint:nilnil // (nil, nil) is the documented platform-absent sentinel; see Loader.Load.
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -150,10 +150,12 @@ func TestPolicy_GetStringSlice(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoadPolicy_PlatformStubReturnsEmpty(t *testing.T) {
|
||||
// loadPlatformPolicy is a stub on every OS for Phase 1. LoadPolicy must
|
||||
// degrade gracefully and never return nil.
|
||||
p := LoadPolicy()
|
||||
func TestLoader_NilFetcherReturnsEmpty(t *testing.T) {
|
||||
// Loader.Load with no fetcher (desktop construction) must degrade
|
||||
// gracefully and never return nil; on linux loadPlatform is a stub
|
||||
// returning (nil, nil), and Load is expected to translate that
|
||||
// into a non-nil empty Policy.
|
||||
p := NewLoader(nil).Load()
|
||||
require.NotNil(t, p)
|
||||
assert.True(t, p.IsEmpty())
|
||||
assert.Empty(t, p.ManagedKeys())
|
||||
|
||||
@@ -61,8 +61,10 @@ func readRegistryValue(k registry.Key, name, canonical string, out map[string]an
|
||||
}
|
||||
}
|
||||
|
||||
// loadPlatformPolicy reads the MDM-managed configuration from the
|
||||
// Windows registry under HKLM\Software\Policies\NetBird. Returns:
|
||||
// loadPlatform reads the MDM-managed configuration from the Windows
|
||||
// registry under HKLM\Software\Policies\NetBird. The Loader's fetcher
|
||||
// field is unused on this platform — the registry is the
|
||||
// authoritative source. Returns:
|
||||
// - (nil, nil) when the key is absent (device not MDM-enrolled for NetBird)
|
||||
// - (map, nil) with N entries when N managed values are set (N may be 0)
|
||||
// - (nil, err) on open / enumerate registry errors
|
||||
@@ -70,7 +72,13 @@ func readRegistryValue(k registry.Key, name, canonical string, out map[string]an
|
||||
// Per-value type coercion + skip-on-error is delegated to
|
||||
// readRegistryValue. Unknown value names are logged and skipped so a
|
||||
// malformed deployment does not block startup.
|
||||
func loadPlatformPolicy() (map[string]any, error) {
|
||||
func (l *Loader) loadPlatform() (map[string]any, error) {
|
||||
// Honour the injected fetcher when present so tests (and any
|
||||
// future non-Windows MDM channel) can short-circuit the registry
|
||||
// read with a scripted policy.
|
||||
if l != nil && l.fetcher != nil {
|
||||
return l.fetcher.Fetch(), nil
|
||||
}
|
||||
k, err := registry.OpenKey(registry.LOCAL_MACHINE, policyRegistryPath, registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
if errors.Is(err, registry.ErrNotExist) {
|
||||
|
||||
@@ -15,33 +15,33 @@ import (
|
||||
// instead, hence anticipating the ticker mechanism entirely.
|
||||
const DefaultReloadInterval = 1 * time.Minute
|
||||
|
||||
// policyLoader is the indirection through which the ticker reads the
|
||||
// OS-native policy, both for the initial observation and on every tick.
|
||||
// Production points it at LoadPolicy; tests in this package override it to
|
||||
// feed a scripted sequence of policies without touching the real OS store.
|
||||
var policyLoader = LoadPolicy
|
||||
|
||||
// Ticker periodically re-reads the OS-native MDM policy via LoadPolicy and
|
||||
// invokes the onChange callback (supplied to Run) whenever the observed
|
||||
// Policy diverges from the last observation (added / removed / changed
|
||||
// keys). Launch with Run from a goroutine; cancel the supplied context
|
||||
// to stop.
|
||||
// Ticker periodically re-reads the OS-native MDM policy via the
|
||||
// injected Loader and invokes the onChange callback (supplied to Run)
|
||||
// whenever the observed Policy diverges from the last observation
|
||||
// (added / removed / changed keys). Launch with Run from a goroutine;
|
||||
// cancel the supplied context to stop.
|
||||
type Ticker struct {
|
||||
interval time.Duration
|
||||
loader *Loader
|
||||
prev *Policy
|
||||
}
|
||||
|
||||
// NewTicker constructs a Ticker that will re-read the OS-native policy
|
||||
// every reloadInterval once Run is called.
|
||||
// The initial snapshot is populated by calling policyLoader at
|
||||
// every reloadInterval once Run is called. The Loader is injected so
|
||||
// the ticker doesn't depend on any package-level state — production
|
||||
// passes the daemon-owned Loader, tests pass a fake Loader (built with
|
||||
// a fake PolicyFetcher).
|
||||
//
|
||||
// The initial snapshot is populated by calling loader.Load() at
|
||||
// construction time so the first tick only fires
|
||||
// onChange when the policy actually changed since boot — without
|
||||
// this baseline the first tick would report every currently-managed
|
||||
// key as "added" and trigger a spurious engine restart.
|
||||
func NewTicker(reloadInterval time.Duration) *Ticker {
|
||||
func NewTicker(reloadInterval time.Duration, loader *Loader) *Ticker {
|
||||
return &Ticker{
|
||||
interval: reloadInterval,
|
||||
prev: policyLoader(),
|
||||
loader: loader,
|
||||
prev: loader.Load(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -58,7 +58,7 @@ func (t *Ticker) Run(ctx context.Context, onChange func(prev, curr *Policy) erro
|
||||
log.Info("MDM policy reload ticker stopped")
|
||||
return
|
||||
case <-tk.C:
|
||||
curr := policyLoader()
|
||||
curr := t.loader.Load()
|
||||
if policiesEqual(t.prev, curr) {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -13,28 +13,40 @@ import (
|
||||
// testReloadInterval for speeding up the ticker cadence under `go test`
|
||||
const testReloadInterval = 1 * time.Second
|
||||
|
||||
// withPolicyLoader overrides the package-level policyLoader for the duration
|
||||
// of the test so the ticker observes a scripted policy instead of the real
|
||||
// OS-native store. The original loader is restored on cleanup.
|
||||
func withPolicyLoader(t *testing.T, fn func() *Policy) {
|
||||
t.Helper()
|
||||
prev := policyLoader
|
||||
policyLoader = fn
|
||||
t.Cleanup(func() { policyLoader = prev })
|
||||
// fakePolicyFetcher implements PolicyFetcher returning a scripted
|
||||
// policy map. Goroutine-safe so the test can mutate the script while
|
||||
// the ticker is observing it.
|
||||
type fakePolicyFetcher struct {
|
||||
mu sync.Mutex
|
||||
values map[string]any
|
||||
}
|
||||
|
||||
func (f *fakePolicyFetcher) Fetch() map[string]any {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
if f.values == nil {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]any, len(f.values))
|
||||
for k, v := range f.values {
|
||||
out[k] = v
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (f *fakePolicyFetcher) set(values map[string]any) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.values = values
|
||||
}
|
||||
|
||||
func TestTicker_FiresOnChangeWithDelta(t *testing.T) {
|
||||
var mu sync.Mutex
|
||||
current := NewPolicy(nil) // initial observation: empty (no enforcement)
|
||||
withPolicyLoader(t, func() *Policy {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return current
|
||||
})
|
||||
fetcher := &fakePolicyFetcher{} // initial observation: empty (no enforcement)
|
||||
loader := NewLoader(fetcher)
|
||||
|
||||
type change struct{ prev, curr *Policy }
|
||||
changes := make(chan change, 1)
|
||||
tk := NewTicker(testReloadInterval)
|
||||
tk := NewTicker(testReloadInterval, loader)
|
||||
require.Equal(t, testReloadInterval, tk.interval)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
@@ -49,15 +61,13 @@ func TestTicker_FiresOnChangeWithDelta(t *testing.T) {
|
||||
})
|
||||
close(done)
|
||||
}()
|
||||
// Stop Run and wait for it to exit before returning, so the policyLoader
|
||||
// restore in t.Cleanup can't race the ticker goroutine still reading it.
|
||||
// Stop Run and wait for it to exit before returning, so the test
|
||||
// goroutine doesn't race the still-running ticker.
|
||||
defer func() { cancel(); <-done }()
|
||||
|
||||
// Flip the OS-observed policy from empty to one managed key. The next
|
||||
// tick must detect the diff and invoke onChange.
|
||||
mu.Lock()
|
||||
current = NewPolicy(map[string]any{KeyManagementURL: "https://mdm.example.com:443"})
|
||||
mu.Unlock()
|
||||
// Flip the OS-observed policy from empty to one managed key. The
|
||||
// next tick must detect the diff and invoke onChange.
|
||||
fetcher.set(map[string]any{KeyManagementURL: "https://mdm.example.com:443"})
|
||||
|
||||
select {
|
||||
case c := <-changes:
|
||||
@@ -69,12 +79,11 @@ func TestTicker_FiresOnChangeWithDelta(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestTicker_NoCallbackWhenPolicyUnchanged(t *testing.T) {
|
||||
withPolicyLoader(t, func() *Policy {
|
||||
return NewPolicy(map[string]any{KeyBlockInbound: true})
|
||||
})
|
||||
fetcher := &fakePolicyFetcher{values: map[string]any{KeyBlockInbound: true}}
|
||||
loader := NewLoader(fetcher)
|
||||
|
||||
fired := make(chan struct{}, 1)
|
||||
tk := NewTicker(testReloadInterval)
|
||||
tk := NewTicker(testReloadInterval, loader)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan struct{})
|
||||
@@ -90,8 +99,8 @@ func TestTicker_NoCallbackWhenPolicyUnchanged(t *testing.T) {
|
||||
}()
|
||||
defer func() { cancel(); <-done }()
|
||||
|
||||
// Over ~2 ticks at the 1s test cadence the policy never changes, so the
|
||||
// diff guard must suppress the callback entirely.
|
||||
// Over ~2 ticks at the 1s test cadence the policy never changes,
|
||||
// so the diff guard must suppress the callback entirely.
|
||||
select {
|
||||
case <-fired:
|
||||
t.Fatal("onChange fired despite an unchanged policy")
|
||||
|
||||
@@ -20,10 +20,6 @@ import (
|
||||
// a no-op echo, never as a conflict with the policy.
|
||||
const preSharedKeyRedactedSentinel = "**********"
|
||||
|
||||
// loadMDMPolicy is the indirection used by server handlers to read the
|
||||
// active MDM policy. Tests override this to inject a fake policy.
|
||||
var loadMDMPolicy = mdm.LoadPolicy
|
||||
|
||||
// conflictCheck is a value-aware comparison between a single field in
|
||||
// the incoming request and the corresponding MDM-enforced value. It
|
||||
// runs only when the field was actually set in the request (presence
|
||||
|
||||
@@ -110,6 +110,15 @@ type Server struct {
|
||||
// stopped by the rootCtx cancellation.
|
||||
mdmTicker *mdm.Ticker
|
||||
|
||||
// mdmLoader is the daemon-owned source of the active MDM policy.
|
||||
// Constructed once during Server.Start (with a nil PolicyFetcher on
|
||||
// desktop — the build-tagged Loader.loadPlatform reads the OS
|
||||
// registry / plist directly) and injected into every consumer:
|
||||
// mdmTicker for its periodic reload, the SetConfig / Login MDM
|
||||
// gates for conflict detection, and every Config produced via
|
||||
// getConfig() so its apply() picks up the same overlay.
|
||||
mdmLoader *mdm.Loader
|
||||
|
||||
updateManager *updater.Manager
|
||||
|
||||
jwtCache *jwtCache
|
||||
@@ -173,8 +182,14 @@ func (s *Server) Start() error {
|
||||
// Runs re-resolves Config (re-running profilemanager.Config.apply which
|
||||
// applies the freshly-read MDM policy as the last layer) and brings
|
||||
// the engine back with the new values.
|
||||
if s.mdmLoader == nil {
|
||||
// Desktop builds pass a nil PolicyFetcher: the Loader's
|
||||
// build-tagged loadPlatform reads the OS source directly
|
||||
// (registry on Windows, plist on macOS, no-op elsewhere).
|
||||
s.mdmLoader = mdm.NewLoader(nil)
|
||||
}
|
||||
if s.mdmTicker == nil {
|
||||
s.mdmTicker = mdm.NewTicker(mdm.DefaultReloadInterval)
|
||||
s.mdmTicker = mdm.NewTicker(mdm.DefaultReloadInterval, s.mdmLoader)
|
||||
go s.mdmTicker.Run(s.rootCtx, s.onMDMPolicyChange)
|
||||
}
|
||||
|
||||
@@ -370,7 +385,7 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
|
||||
// by the active MDM policy. The error carries an MDMManagedFields-
|
||||
// Violation detail listing the offending key names. Non-conflicting
|
||||
// fields in the same request are not applied either.
|
||||
policy := loadMDMPolicy()
|
||||
policy := s.mdmLoader.Load()
|
||||
if err := rejectMDMManagedFieldConflicts(mdmManagedFieldConflicts(msg, policy)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -496,7 +511,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
||||
if s.checkUpdateSettingsDisabled() {
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errUpdateSettingsDisabled)
|
||||
}
|
||||
policy := loadMDMPolicy()
|
||||
policy := s.mdmLoader.Load()
|
||||
if err := rejectMDMManagedFieldConflicts(loginRequestMDMConflicts(msg, policy)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1088,6 +1103,12 @@ func (s *Server) getConfig(activeProf *profilemanager.ActiveProfileState) (*prof
|
||||
return nil, false, fmt.Errorf("failed to get config: %w", err)
|
||||
}
|
||||
|
||||
// Apply the daemon-owned MDM policy on top of the just-resolved
|
||||
// Config. profilemanager's apply() initialises the policy to
|
||||
// empty — the Loader lives outside Config, so this overlay step
|
||||
// is driven externally here.
|
||||
config.ApplyMDMPolicy(s.mdmLoader.Load())
|
||||
|
||||
return config, configExisted, nil
|
||||
}
|
||||
|
||||
@@ -1142,6 +1163,9 @@ func (s *Server) logoutFromProfile(ctx context.Context, profileName, username st
|
||||
if err != nil {
|
||||
return fmt.Errorf("profile '%s' not found", profileName)
|
||||
}
|
||||
// Honour any MDM-enforced ManagementURL when issuing the logout
|
||||
// RPC: the user-stored value may have been overridden by policy.
|
||||
config.ApplyMDMPolicy(s.mdmLoader.Load())
|
||||
|
||||
return s.sendLogoutRequestWithConfig(ctx, config)
|
||||
}
|
||||
@@ -1574,6 +1598,11 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
|
||||
log.Errorf("failed to get active profile config: %v", err)
|
||||
return nil, fmt.Errorf("failed to get active profile config: %w", err)
|
||||
}
|
||||
// Overlay the active MDM policy so the response's MDMManagedFields
|
||||
// list reflects what the GUI / CLI must render as read-only.
|
||||
// profilemanager.GetConfig itself returns a Config without the
|
||||
// overlay (Loader lives outside profilemanager).
|
||||
cfg.ApplyMDMPolicy(s.mdmLoader.Load())
|
||||
managementURL := cfg.ManagementURL
|
||||
adminURL := cfg.AdminURL
|
||||
|
||||
|
||||
@@ -16,14 +16,40 @@ import (
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
// withMDMPolicy temporarily overrides the server-package loadMDMPolicy hook
|
||||
// so SetConfig observes the supplied Policy. Restores the original loader
|
||||
// at test cleanup.
|
||||
func withMDMPolicy(t *testing.T, policy *mdm.Policy) {
|
||||
// fakeMDMFetcher implements mdm.PolicyFetcher returning a pre-set
|
||||
// policy map. Tests build one per Server instance to inject a
|
||||
// scripted MDM overlay via a Loader rather than via package-level state.
|
||||
type fakeMDMFetcher struct{ values map[string]any }
|
||||
|
||||
func (f *fakeMDMFetcher) Fetch() map[string]any { return f.values }
|
||||
|
||||
// withMDMPolicy installs an mdm.Loader on the given Server whose
|
||||
// loadPlatform returns the supplied Policy's underlying values. Use
|
||||
// after setupServerWithProfile to inject the scripted policy the
|
||||
// SetConfig / Login MDM gates will observe.
|
||||
func withMDMPolicy(t *testing.T, s *Server, policy *mdm.Policy) {
|
||||
t.Helper()
|
||||
prev := loadMDMPolicy
|
||||
loadMDMPolicy = func() *mdm.Policy { return policy }
|
||||
t.Cleanup(func() { loadMDMPolicy = prev })
|
||||
values := map[string]any{}
|
||||
if policy != nil {
|
||||
for _, k := range policy.ManagedKeys() {
|
||||
if v, ok := policy.GetString(k); ok {
|
||||
values[k] = v
|
||||
continue
|
||||
}
|
||||
if v, ok := policy.GetBool(k); ok {
|
||||
values[k] = v
|
||||
continue
|
||||
}
|
||||
if v, ok := policy.GetInt(k); ok {
|
||||
values[k] = v
|
||||
continue
|
||||
}
|
||||
if v, ok := policy.GetStringSlice(k); ok {
|
||||
values[k] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
s.mdmLoader = mdm.NewLoader(&fakeMDMFetcher{values: values})
|
||||
}
|
||||
|
||||
// setupServerWithProfile mirrors the boilerplate of TestSetConfig_AllFieldsSaved:
|
||||
@@ -89,12 +115,11 @@ func extractViolation(t *testing.T, err error) *proto.MDMManagedFieldsViolation
|
||||
}
|
||||
|
||||
func TestSetConfig_MDMReject_SingleField(t *testing.T) {
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
s, ctx, profName, username, _ := setupServerWithProfile(t)
|
||||
withMDMPolicy(t, s, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: "https://mdm.example.com:443",
|
||||
}))
|
||||
|
||||
s, ctx, profName, username, _ := setupServerWithProfile(t)
|
||||
|
||||
_, err := s.SetConfig(ctx, &proto.SetConfigRequest{
|
||||
ProfileName: profName,
|
||||
Username: username,
|
||||
@@ -106,13 +131,12 @@ func TestSetConfig_MDMReject_SingleField(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSetConfig_MDMReject_MultipleFields(t *testing.T) {
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: "https://mdm.example.com:443",
|
||||
mdm.KeyBlockInbound: true,
|
||||
mdm.KeyRosenpassEnabled: true,
|
||||
}))
|
||||
|
||||
s, ctx, profName, username, _ := setupServerWithProfile(t)
|
||||
withMDMPolicy(t, s, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: "https://mdm.example.com:443",
|
||||
mdm.KeyBlockInbound: true,
|
||||
mdm.KeyRosenpassEnabled: true,
|
||||
}))
|
||||
|
||||
blockInbound := false
|
||||
rosenpassEnabled := false
|
||||
@@ -137,12 +161,11 @@ func TestSetConfig_MDMReject_AllOrNothing(t *testing.T) {
|
||||
// enforced field AND a non-enforced field (RosenpassEnabled).
|
||||
// The whole request must be rejected — non-conflicting fields are not
|
||||
// applied either.
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
s, ctx, profName, username, cfgPath := setupServerWithProfile(t)
|
||||
withMDMPolicy(t, s, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: "https://mdm.example.com:443",
|
||||
}))
|
||||
|
||||
s, ctx, profName, username, cfgPath := setupServerWithProfile(t)
|
||||
|
||||
rosenpassEnabled := true
|
||||
_, err := s.SetConfig(ctx, &proto.SetConfigRequest{
|
||||
ProfileName: profName,
|
||||
@@ -164,12 +187,11 @@ func TestSetConfig_MDMReject_AllOrNothing(t *testing.T) {
|
||||
func TestSetConfig_MDMAllow_NonManagedFields(t *testing.T) {
|
||||
// MDM enforces ManagementURL but the user only writes RosenpassEnabled.
|
||||
// Request must succeed.
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
s, ctx, profName, username, _ := setupServerWithProfile(t)
|
||||
withMDMPolicy(t, s, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyManagementURL: "https://mdm.example.com:443",
|
||||
}))
|
||||
|
||||
s, ctx, profName, username, _ := setupServerWithProfile(t)
|
||||
|
||||
rosenpassEnabled := true
|
||||
resp, err := s.SetConfig(ctx, &proto.SetConfigRequest{
|
||||
ProfileName: profName,
|
||||
@@ -183,9 +205,8 @@ func TestSetConfig_MDMAllow_NonManagedFields(t *testing.T) {
|
||||
|
||||
func TestSetConfig_MDMEmpty_NoEnforcement(t *testing.T) {
|
||||
// No MDM policy active: any field can be written.
|
||||
withMDMPolicy(t, mdm.NewPolicy(nil))
|
||||
|
||||
s, ctx, profName, username, _ := setupServerWithProfile(t)
|
||||
withMDMPolicy(t, s, mdm.NewPolicy(nil))
|
||||
|
||||
resp, err := s.SetConfig(ctx, &proto.SetConfigRequest{
|
||||
ProfileName: profName,
|
||||
|
||||
@@ -41,6 +41,8 @@ type Config struct {
|
||||
GRPCAddr string
|
||||
}
|
||||
|
||||
const localConnectorID = "local"
|
||||
|
||||
// Provider wraps a Dex server
|
||||
type Provider struct {
|
||||
config *Config
|
||||
@@ -544,7 +546,7 @@ func (p *Provider) CreateUser(ctx context.Context, email, username, password str
|
||||
|
||||
// Encode the user ID in Dex's format: base64(protobuf{user_id, connector_id})
|
||||
// This matches the format Dex uses in JWT tokens
|
||||
encodedID := EncodeDexUserID(userID, "local")
|
||||
encodedID := EncodeDexUserID(userID, localConnectorID)
|
||||
return encodedID, nil
|
||||
}
|
||||
|
||||
@@ -619,6 +621,13 @@ func DecodeDexUserID(encodedID string) (userID, connectorID string, err error) {
|
||||
return userID, connectorID, nil
|
||||
}
|
||||
|
||||
// IsLocalUserID reports whether encodedID is a Dex subject for the built-in
|
||||
// local password connector.
|
||||
func IsLocalUserID(encodedID string) bool {
|
||||
_, connectorID, err := DecodeDexUserID(encodedID)
|
||||
return err == nil && connectorID == localConnectorID
|
||||
}
|
||||
|
||||
// GetUser returns a user by email
|
||||
func (p *Provider) GetUser(ctx context.Context, email string) (storage.Password, error) {
|
||||
return p.storage.GetPassword(ctx, email)
|
||||
|
||||
@@ -115,6 +115,26 @@ func TestDecodeDexUserID(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsLocalUserID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
encodedID string
|
||||
want bool
|
||||
}{
|
||||
{name: "local connector", encodedID: EncodeDexUserID("7aad8c05-3287-473f-b42a-365504bf25e7", "local"), want: true},
|
||||
{name: "federated connector", encodedID: EncodeDexUserID("entra-user", "entra"), want: false},
|
||||
{name: "non-dex external IdP id", encodedID: "google-oauth2|1234567890", want: false},
|
||||
{name: "invalid base64", encodedID: "not-valid-base64!!!", want: false},
|
||||
{name: "empty", encodedID: "", want: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.want, IsLocalUserID(tt.encodedID))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeDexUserID(t *testing.T) {
|
||||
userID := "7aad8c05-3287-473f-b42a-365504bf25e7"
|
||||
connectorID := "local"
|
||||
|
||||
@@ -28,6 +28,7 @@ import (
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/formatter/hook"
|
||||
"github.com/netbirdio/netbird/idp/dex"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
@@ -1588,7 +1589,10 @@ func (am *DefaultAccountManager) updateUserAuthWithSingleMode(ctx context.Contex
|
||||
// and propagates changes to peers if group propagation is enabled.
|
||||
// requires userAuth to have been ValidateAndParseToken and EnsureUserAccessByJWTGroups by the AuthManager
|
||||
func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error {
|
||||
if userAuth.IsChild || userAuth.IsPAT {
|
||||
// Child accounts and PAT-authenticated requests do not sync JWT groups.
|
||||
// Embedded-Dex local users also skip sync because local password authentication
|
||||
// does not provide external IdP group claims.
|
||||
if userAuth.IsChild || userAuth.IsPAT || dex.IsLocalUserID(userAuth.UserId) {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/idp/dex"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
@@ -723,6 +724,28 @@ func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) {
|
||||
require.Equal(t, g2.Name, "group2", "group2 name should match")
|
||||
require.Equal(t, g2.Issued, types.GroupIssuedJWT, "group2 issued should match")
|
||||
})
|
||||
t.Run("local embedded-Dex user is skipped", func(t *testing.T) {
|
||||
initAccount.Settings.JWTGroupsEnabled = true
|
||||
initAccount.Settings.JWTGroupsClaimName = "idp-groups"
|
||||
err := manager.Store.SaveAccount(context.Background(), initAccount)
|
||||
require.NoError(t, err, "save account failed")
|
||||
|
||||
localClaims := auth.UserAuth{
|
||||
AccountId: accountID,
|
||||
Domain: domain,
|
||||
UserId: dex.EncodeDexUserID("local-owner", "local"),
|
||||
Groups: []string{"group3", "group4"},
|
||||
}
|
||||
err = manager.SyncUserJWTGroups(context.Background(), localClaims)
|
||||
require.NoError(t, err, "sync should be a no-op for local users")
|
||||
|
||||
account, err := manager.Store.GetAccount(context.Background(), accountID)
|
||||
require.NoError(t, err, "get account failed")
|
||||
for _, g := range account.Groups {
|
||||
require.NotEqual(t, "group3", g.Name, "local user JWT groups must not be synced")
|
||||
require.NotEqual(t, "group4", g.Name, "local user JWT groups must not be synced")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestAccountManager_PrivateAccount(t *testing.T) {
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
|
||||
"github.com/netbirdio/netbird/base62"
|
||||
"github.com/netbirdio/netbird/idp/dex"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
@@ -74,7 +75,10 @@ func (m *manager) ValidateAndParseToken(ctx context.Context, value string) (auth
|
||||
}
|
||||
|
||||
func (m *manager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth auth.UserAuth, token *jwt.Token) (auth.UserAuth, error) {
|
||||
if userAuth.IsChild || userAuth.IsPAT {
|
||||
// Child accounts and PAT-authenticated requests do not use JWT group access checks.
|
||||
// Embedded-Dex local users also skip them because local password authentication
|
||||
// does not provide external IdP group claims.
|
||||
if userAuth.IsChild || userAuth.IsPAT || dex.IsLocalUserID(userAuth.UserId) {
|
||||
return userAuth, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/idp/dex"
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@@ -206,6 +207,43 @@ func TestAuthManager_EnsureUserAccessByJWTGroups(t *testing.T) {
|
||||
_, err = manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token)
|
||||
require.Error(t, err, "ensure user access is not in allowed groups")
|
||||
})
|
||||
|
||||
t.Run("Local embedded-Dex user is exempt from JWT allow-groups", func(t *testing.T) {
|
||||
account.Settings.JWTGroupsEnabled = true
|
||||
account.Settings.JWTGroupsClaimName = "idp-groups"
|
||||
account.Settings.JWTAllowGroups = []string{"not-a-group"}
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err, "save account failed")
|
||||
|
||||
// Local Dex users have a "local" connector encoded in their user ID.
|
||||
localUserAuth := nbauth.UserAuth{
|
||||
AccountId: account.Id,
|
||||
Domain: domain,
|
||||
UserId: dex.EncodeDexUserID("local-owner", "local"),
|
||||
}
|
||||
|
||||
localUserAuth, err = manager.EnsureUserAccessByJWTGroups(context.Background(), localUserAuth, token)
|
||||
require.NoError(t, err, "local user must not be locked out by JWT allow-groups (issue #5337)")
|
||||
require.Len(t, localUserAuth.Groups, 0, "JWT groups must not be evaluated for local users")
|
||||
})
|
||||
|
||||
t.Run("Federated embedded-Dex user is still subject to JWT allow-groups", func(t *testing.T) {
|
||||
account.Settings.JWTGroupsEnabled = true
|
||||
account.Settings.JWTGroupsClaimName = "idp-groups"
|
||||
account.Settings.JWTAllowGroups = []string{"not-a-group"}
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err, "save account failed")
|
||||
|
||||
// A federated user (non-"local" connector) must remain restricted.
|
||||
fedUserAuth := nbauth.UserAuth{
|
||||
AccountId: account.Id,
|
||||
Domain: domain,
|
||||
UserId: dex.EncodeDexUserID("entra-user", "entra"),
|
||||
}
|
||||
|
||||
_, err = manager.EnsureUserAccessByJWTGroups(context.Background(), fedUserAuth, token)
|
||||
require.Error(t, err, "federated user must still be restricted by JWT allow-groups")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthManager_ValidateAndParseToken(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user