Compare commits

..

5 Commits

Author SHA1 Message Date
Zoltán Papp
3236a4c7fd [client] Cancel ServeDNS-path mgmt resolve on server Stop
awaitPendingResolve resolved on context.Background(), so a resolve
triggered by an incoming query kept running to dnsTimeout after Stop,
making the DNS listener's ShutdownContext wait it out. Give the Resolver
the server-lifetime ctx and use it there so Stop cancels the resolve.
2026-06-22 22:34:06 +02:00
Zoltán Papp
08ac4855f6 [client] Scope mgmt cache initial resolve to server lifetime
Background initial resolves used context.Background(), so they outlived
a server Stop() and kept running (up to dnsTimeout) against a cache
that was already being discarded. Thread the server-lifetime ctx from
UpdateFromServerDomains through kickoffResolve into AddDomain so Stop()
cancels in-flight resolves. The ctx is not per-sync, so a fast-returning
sync still won't cancel them.
2026-06-22 22:19:51 +02:00
Zoltán Papp
b6c79f1f71 [client] Trim verbose comments in mgmt cache async resolve 2026-06-22 19:46:26 +02:00
Zoltán Papp
37be8811a3 [client] Move mgmt cache pending-resolve test helpers to export_test.go
WaitForPendingResolves and pendingCount existed only to let tests wait
for the background initial resolves; they had no production caller. Keep
the production API surface clean by moving them into export_test.go
(unexported, test-build only) as waitForPendingResolves/pendingCount.
2026-06-22 19:40:04 +02:00
Zoltán Papp
a7d85ff3ab [client] Resolve management cache domains asynchronously
UpdateFromServerDomains resolved the infrastructure domains (signal,
relay, STUN, TURN) synchronously, in a serial loop with a 5s timeout per
domain. It runs deep inside handleSync, under the engine syncMsgMux (and
the DNS server mutex). When DNS is unhealthy each lookup waits out its
full timeout, so the sync lock is held for many seconds — observed as a
16.6s wait for the lock by a signal handler in the field. That starves
the signal/p2p processing that shares syncMsgMux, which in turn prevents
the handshake that would make DNS healthy again: a self-reinforcing loop.

The cache cannot simply be skipped: it is a prerequisite for the relay
connection. The relay dials its server by hostname, resolved through the
NetBird DNS chain where this resolver sits on top (PriorityMgmtCache);
on a miss it falls through to upstream — the dead path the cache exists
to bypass.

Instead, record intent on sync and resolve on demand:

- UpdateFromServerDomains now marks each requested domain pending and
  kicks off resolution in the background via a dedicated singleflight
  group (resolveGroup), then returns immediately. No DNS I/O, no timeout,
  no blocking under the lock.
- ServeDNS, on a miss for a pending domain, waits on the in-flight
  resolve (joining the same singleflight flight, bounded by dnsTimeout)
  instead of falling through to the dead upstream. The wait now happens
  in the relay/DNS-serving goroutine, never under syncMsgMux.
- UpdateServerConfig registers the cache handler from the requested
  domains (RequestedDomains) rather than the already-resolved ones, so
  the resolver owns the names before resolution completes and can
  intercept the lookup.

Background resolves use context.Background() so a fast-returning sync
cannot cancel an in-flight resolve. WaitForPendingResolves lets tests
(and any warm-cache path) wait for the background work to settle.

Cache semantics are otherwise unchanged: TTL, stale-while-revalidate
refresh, backoff, and the flow-domain exclusion all stay.
2026-06-22 19:33:49 +02:00
14 changed files with 304 additions and 830 deletions

View File

@@ -0,0 +1,23 @@
package mgmt
import "time"
// pendingCount returns how many initial resolves are still in flight. Test-only.
func (m *Resolver) pendingCount() int {
m.mutex.RLock()
defer m.mutex.RUnlock()
return len(m.pending)
}
// waitForPendingResolves blocks until all pending resolves settle or the
// timeout elapses, returning true if all settled. Test-only.
func (m *Resolver) waitForPendingResolves(timeout time.Duration) bool {
deadline := time.Now().Add(timeout)
for m.pendingCount() > 0 {
if time.Now().After(deadline) {
return false
}
time.Sleep(10 * time.Millisecond)
}
return true
}

View File

@@ -50,17 +50,31 @@ type cachedRecord struct {
consecFailures int
}
// pendingEntry marks a domain whose initial resolve is in flight, so ServeDNS
// can wait on it instead of falling through to upstream.
type pendingEntry struct{}
// Resolver caches critical NetBird infrastructure domains.
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
// records, refreshing, pending, mgmtDomain and serverDomains are all guarded by mutex.
type Resolver struct {
// ctx is the server-lifetime context for background resolves.
ctx context.Context
records map[dns.Question]*cachedRecord
mgmtDomain *domain.Domain
serverDomains *dnsconfig.ServerDomains
mutex sync.RWMutex
// pending holds domains whose initial resolve is in flight, keyed by
// punycode FQDN (trailing dot).
pending map[string]pendingEntry
chain ChainResolver
chainMaxPriority int
refreshGroup singleflight.Group
// resolveGroup dedups initial (cold-cache) resolves; kept separate from
// refreshGroup so initial and stale-refresh flights don't collapse.
resolveGroup singleflight.Group
// refreshing tracks questions whose refresh is running via the OS
// fallback path. A ServeDNS hit for a question in this map indicates
@@ -74,10 +88,12 @@ type Resolver struct {
}
// NewResolver creates a new management domains cache resolver.
func NewResolver() *Resolver {
func NewResolver(ctx context.Context) *Resolver {
return &Resolver{
ctx: ctx,
records: make(map[dns.Question]*cachedRecord),
refreshing: make(map[dns.Question]*atomic.Bool),
pending: make(map[string]pendingEntry),
cacheTTL: resolveCacheTTL(),
}
}
@@ -117,6 +133,7 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
m.mutex.RLock()
cached, found := m.records[question]
inflight := m.refreshing[question]
_, isPending := m.pending[question.Name]
var shouldRefresh bool
if found {
stale := time.Since(cached.cachedAt) > m.cacheTTL
@@ -126,8 +143,17 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
m.mutex.RUnlock()
if !found {
m.continueToNext(w, r)
return
// Registered but not resolved yet: wait on the in-flight resolve
// rather than falling through to (possibly dead) upstream.
if isPending && m.awaitPendingResolve(question.Name) {
m.mutex.RLock()
cached, found = m.records[question]
m.mutex.RUnlock()
}
if !found {
m.continueToNext(w, r)
return
}
}
if inflight != nil && inflight.CompareAndSwap(false, true) {
@@ -467,6 +493,13 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
return nil
}
// RequestedDomains returns the cacheable infrastructure domains (signal, relay,
// STUN, TURN; flow excluded) so the cache handler can be registered for them
// before resolution completes.
func (m *Resolver) RequestedDomains(serverDomains dnsconfig.ServerDomains) domain.List {
return m.extractDomainsFromServerDomains(serverDomains)
}
// GetCachedDomains returns a list of all cached domains.
func (m *Resolver) GetCachedDomains() domain.List {
m.mutex.RLock()
@@ -489,6 +522,11 @@ func (m *Resolver) GetCachedDomains() domain.List {
// UpdateFromServerDomains updates the cache with server domains from network configuration.
// It merges new domains with existing ones, replacing entire domain types when updated.
// Empty updates are ignored to prevent clearing infrastructure domains during partial updates.
// UpdateFromServerDomains records the requested domains and kicks off their
// resolution in the background, returning without blocking on DNS so it stays
// off the engine sync lock held by the caller. ctx scopes the background
// resolves to the server lifetime: it is not per-sync, so a fast-returning
// sync won't cancel them, but a server Stop will.
func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dnsconfig.ServerDomains) (domain.List, error) {
newDomains := m.extractDomainsFromServerDomains(serverDomains)
var removedDomains domain.List
@@ -507,11 +545,95 @@ func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dn
removedDomains = m.removeStaleDomains(currentDomains, allDomains)
}
m.addNewDomains(ctx, newDomains)
m.kickoffResolve(ctx, newDomains)
return removedDomains, nil
}
// kickoffResolve marks each domain pending and starts a background resolve,
// skipping ones already fresh or in flight. Returns immediately.
func (m *Resolver) kickoffResolve(ctx context.Context, domains domain.List) {
for _, d := range domains {
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
m.mutex.Lock()
_, hasPending := m.pending[dnsName]
cached := m.hasFreshRecordLocked(dnsName)
if !hasPending && !cached {
m.pending[dnsName] = pendingEntry{}
}
m.mutex.Unlock()
if hasPending || cached {
continue
}
m.scheduleInitialResolve(ctx, d, dnsName)
}
}
// scheduleInitialResolve runs AddDomain in the background, deduped per domain
// by resolveGroup, clearing the pending marker when it finishes. ctx is the
// server-lifetime context so a Stop cancels in-flight resolves.
func (m *Resolver) scheduleInitialResolve(ctx context.Context, d domain.Domain, dnsName string) {
key := "initial|" + dnsName
_ = m.resolveGroup.DoChan(key, func() (any, error) {
defer m.clearPending(dnsName)
if err := m.AddDomain(ctx, d); err != nil {
log.Warnf("failed to add/update domain=%s: %v", d.SafeString(), err)
return struct{}{}, err
}
log.Debugf("added/updated management cache domain=%s", d.SafeString())
return struct{}{}, nil
})
}
// hasFreshRecordLocked reports whether a non-stale A or AAAA record exists for
// the name. Caller holds m.mutex.
func (m *Resolver) hasFreshRecordLocked(dnsName string) bool {
for _, qtype := range []uint16{dns.TypeA, dns.TypeAAAA} {
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
if c, ok := m.records[q]; ok && time.Since(c.cachedAt) <= m.cacheTTL {
return true
}
}
return false
}
func (m *Resolver) clearPending(dnsName string) {
m.mutex.Lock()
delete(m.pending, dnsName)
m.mutex.Unlock()
}
// awaitPendingResolve joins the in-flight resolve for dnsName (bounded by
// dnsTimeout) and reports whether a record became available.
func (m *Resolver) awaitPendingResolve(dnsName string) bool {
key := "initial|" + dnsName
d, err := domain.FromString(strings.TrimSuffix(dnsName, "."))
if err != nil {
return false
}
ch := m.resolveGroup.DoChan(key, func() (any, error) {
defer m.clearPending(dnsName)
if err := m.AddDomain(m.ctx, d); err != nil {
return struct{}{}, err
}
return struct{}{}, nil
})
select {
case <-ch:
case <-time.After(dnsTimeout):
return false
}
m.mutex.RLock()
defer m.mutex.RUnlock()
return m.hasFreshRecordLocked(dnsName)
}
// removeStaleDomains removes cached domains not present in the target domain list.
// Management domains are preserved and never removed during server domain updates.
func (m *Resolver) removeStaleDomains(currentDomains, newDomains domain.List) domain.List {
@@ -577,17 +699,6 @@ func (m *Resolver) isManagementDomain(domain domain.Domain) bool {
return m.mgmtDomain != nil && domain == *m.mgmtDomain
}
// addNewDomains resolves and caches all domains from the update
func (m *Resolver) addNewDomains(ctx context.Context, newDomains domain.List) {
for _, newDomain := range newDomains {
if err := m.AddDomain(ctx, newDomain); err != nil {
log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
} else {
log.Debugf("added/updated management cache domain=%s", newDomain.SafeString())
}
}
}
func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.ServerDomains) domain.List {
var domains domain.List

View File

@@ -130,7 +130,7 @@ func TestResolver_CacheTTLGatesRefresh(t *testing.T) {
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
t.Run("short TTL treats entry as stale and refreshes", func(t *testing.T) {
r := NewResolver()
r := NewResolver(context.Background())
r.cacheTTL = 10 * time.Millisecond
chain := newFakeChain()
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
@@ -146,7 +146,7 @@ func TestResolver_CacheTTLGatesRefresh(t *testing.T) {
})
t.Run("long TTL keeps entry fresh and skips refresh", func(t *testing.T) {
r := NewResolver()
r := NewResolver(context.Background())
r.cacheTTL = time.Hour
chain := newFakeChain()
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
@@ -162,7 +162,7 @@ func TestResolver_CacheTTLGatesRefresh(t *testing.T) {
}
func TestResolver_ServeFresh_NoRefresh(t *testing.T) {
r := NewResolver()
r := NewResolver(context.Background())
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
@@ -183,7 +183,7 @@ func TestResolver_ServeFresh_NoRefresh(t *testing.T) {
}
func TestResolver_StaleTriggersAsyncRefresh(t *testing.T) {
r := NewResolver()
r := NewResolver(context.Background())
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
@@ -213,7 +213,7 @@ func TestResolver_StaleTriggersAsyncRefresh(t *testing.T) {
}
func TestResolver_ConcurrentStaleHitsCollapseRefresh(t *testing.T) {
r := NewResolver()
r := NewResolver(context.Background())
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
@@ -262,7 +262,7 @@ func TestResolver_ConcurrentStaleHitsCollapseRefresh(t *testing.T) {
}
func TestResolver_RefreshFailureArmsBackoff(t *testing.T) {
r := NewResolver()
r := NewResolver(context.Background())
chain := newFakeChain()
chain.err = errors.New("boom")
r.SetChainResolver(chain, 50)
@@ -299,7 +299,7 @@ func TestResolver_RefreshFailureArmsBackoff(t *testing.T) {
}
func TestResolver_NoRootHandler_SkipsChain(t *testing.T) {
r := NewResolver()
r := NewResolver(context.Background())
chain := newFakeChain()
chain.hasRoot = false
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
@@ -320,7 +320,7 @@ func TestResolver_ServeDuringRefreshSetsLoopFlag(t *testing.T) {
// ServeDNS being invoked for a question while a refresh for that question
// is inflight indicates a resolver loop (OS resolver sent the recursive
// query back to us). The inflightRefresh.loopLoggedOnce flag must be set.
r := NewResolver()
r := NewResolver(context.Background())
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
@@ -346,7 +346,7 @@ func TestResolver_ServeDuringRefreshSetsLoopFlag(t *testing.T) {
}
func TestResolver_LoopFlagOnlyTrippedOncePerRefresh(t *testing.T) {
r := NewResolver()
r := NewResolver(context.Background())
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
@@ -373,7 +373,7 @@ func TestResolver_LoopFlagOnlyTrippedOncePerRefresh(t *testing.T) {
}
func TestResolver_NoLoopFlagWhenNotRefreshing(t *testing.T) {
r := NewResolver()
r := NewResolver(context.Background())
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
@@ -393,7 +393,7 @@ func TestResolver_NoLoopFlagWhenNotRefreshing(t *testing.T) {
}
func TestResolver_AddDomain_UsesChainWhenRootRegistered(t *testing.T) {
r := NewResolver()
r := NewResolver(context.Background())
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
chain.setAnswer("mgmt.example.com.", dns.TypeAAAA, "fd00::2")

View File

@@ -17,7 +17,7 @@ import (
)
func TestResolver_NewResolver(t *testing.T) {
resolver := NewResolver()
resolver := NewResolver(context.Background())
assert.NotNil(t, resolver)
assert.NotNil(t, resolver.records)
@@ -49,7 +49,7 @@ func TestResolveCacheTTL(t *testing.T) {
func TestNewResolver_CacheTTLFromEnv(t *testing.T) {
t.Setenv(envMgmtCacheTTL, "7s")
r := NewResolver()
r := NewResolver(context.Background())
assert.Equal(t, 7*time.Second, r.cacheTTL, "NewResolver should evaluate cacheTTL once from env")
}
@@ -169,7 +169,7 @@ func TestResolver_PopulateFromConfig(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
resolver := NewResolver()
resolver := NewResolver(context.Background())
// Test with IP address - should return error since IP addresses are rejected
mgmtURL, _ := url.Parse("https://127.0.0.1")
@@ -184,7 +184,7 @@ func TestResolver_PopulateFromConfig(t *testing.T) {
}
func TestResolver_ServeDNS(t *testing.T) {
resolver := NewResolver()
resolver := NewResolver(context.Background())
ctx := context.Background()
// Add a test domain to the cache - use example.org which is reserved for testing
@@ -284,7 +284,7 @@ func TestResolver_ServeDNS(t *testing.T) {
}
func TestResolver_GetCachedDomains(t *testing.T) {
resolver := NewResolver()
resolver := NewResolver(context.Background())
ctx := context.Background()
testDomain, err := domain.FromString("example.org")
@@ -304,7 +304,7 @@ func TestResolver_GetCachedDomains(t *testing.T) {
}
func TestResolver_ManagementDomainProtection(t *testing.T) {
resolver := NewResolver()
resolver := NewResolver(context.Background())
ctx := context.Background()
mgmtURL, _ := url.Parse("https://example.org")
@@ -329,6 +329,7 @@ func TestResolver_ManagementDomainProtection(t *testing.T) {
if err != nil {
t.Logf("Server domains update failed: %v", err)
}
resolver.waitForPendingResolves(10 * time.Second)
finalDomains := resolver.GetCachedDomains()
@@ -351,7 +352,7 @@ func extractDomainFromURL(u *url.URL) (domain.Domain, error) {
}
func TestResolver_EmptyUpdateDoesNotRemoveDomains(t *testing.T) {
resolver := NewResolver()
resolver := NewResolver(context.Background())
ctx := context.Background()
// Set up initial domains using resolvable domains
@@ -366,6 +367,7 @@ func TestResolver_EmptyUpdateDoesNotRemoveDomains(t *testing.T) {
if err != nil {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
resolver.waitForPendingResolves(10 * time.Second)
// Verify domains were added
cachedDomains := resolver.GetCachedDomains()
@@ -385,7 +387,7 @@ func TestResolver_EmptyUpdateDoesNotRemoveDomains(t *testing.T) {
}
func TestResolver_PartialUpdateReplacesOnlyUpdatedTypes(t *testing.T) {
resolver := NewResolver()
resolver := NewResolver(context.Background())
ctx := context.Background()
// Set up initial complete domains using resolvable domains
@@ -400,6 +402,7 @@ func TestResolver_PartialUpdateReplacesOnlyUpdatedTypes(t *testing.T) {
if err != nil {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
resolver.waitForPendingResolves(10 * time.Second)
assert.Len(t, resolver.GetCachedDomains(), 3)
// Update with partial ServerDomains (only signal domain - this should replace signal but preserve stun/turn)
@@ -410,6 +413,7 @@ func TestResolver_PartialUpdateReplacesOnlyUpdatedTypes(t *testing.T) {
if err != nil {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
resolver.waitForPendingResolves(10 * time.Second)
// Should remove only the old signal domain
assert.Len(t, removedDomains, 1, "Should remove only the old signal domain")
@@ -429,7 +433,7 @@ func TestResolver_PartialUpdateReplacesOnlyUpdatedTypes(t *testing.T) {
}
func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
resolver := NewResolver()
resolver := NewResolver(context.Background())
ctx := context.Background()
// Set up initial complete domains using resolvable domains
@@ -444,6 +448,7 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
if err != nil {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
resolver.waitForPendingResolves(10 * time.Second)
assert.Len(t, resolver.GetCachedDomains(), 3)
// Update with partial ServerDomains (only flow domain - flow is intentionally excluded from
@@ -455,6 +460,7 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
if err != nil {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
resolver.waitForPendingResolves(10 * time.Second)
assert.Len(t, removedDomains, 0, "Should not remove any domains when only flow domain is provided")

View File

@@ -282,7 +282,7 @@ func newDefaultServer(
handlerChain := NewHandlerChain()
ctx, stop := context.WithCancel(ctx)
mgmtCacheResolver := mgmt.NewResolver()
mgmtCacheResolver := mgmt.NewResolver(ctx)
mgmtCacheResolver.SetChainResolver(handlerChain, PriorityUpstream)
defaultServer := &DefaultServer{
@@ -622,7 +622,9 @@ func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) erro
s.deregisterHandler(removedDomains.ToPunycodeList(), PriorityMgmtCache)
}
newDomains := s.mgmtCacheResolver.GetCachedDomains()
// Register for the requested domains, not just resolved ones: resolution
// now runs in the background, so the cache may still be empty here.
newDomains := s.mgmtCacheResolver.RequestedDomains(domains)
if len(newDomains) > 0 {
s.registerHandler(newDomains.ToPunycodeList(), s.mgmtCacheResolver, PriorityMgmtCache)
}

View File

@@ -1,91 +0,0 @@
package cmd
import (
"context"
"fmt"
"os"
"strings"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/formatter/hook"
admincmd "github.com/netbirdio/netbird/management/cmd/admin"
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
)
// newAdminCommands creates the admin command tree with combined-specific resource openers.
func newAdminCommands() *cobra.Command {
cmd := admincmd.NewCommands(withAdminResources)
cmd.AddCommand(tokencmd.NewCommands(withAdminTokenStore))
return cmd
}
// withAdminResources loads the combined YAML config, initializes stores, and calls fn.
func withAdminResources(cmd *cobra.Command, fn func(ctx context.Context, resources admincmd.Resources) error) error {
return withAdminStore(cmd, func(ctx context.Context, managementStore store.Store, cfg *CombinedConfig) error {
mgmtConfig, err := cfg.ToManagementConfig()
if err != nil {
return fmt.Errorf("create management config: %w", err)
}
idpStorage, err := admincmd.OpenEmbeddedIDPStorage(mgmtConfig.EmbeddedIdP)
if err != nil {
return err
}
defer func() {
if err := idpStorage.Close(); err != nil {
log.Debugf("close embedded IdP storage: %v", err)
}
}()
return fn(ctx, admincmd.Resources{Store: managementStore, IDPStorage: idpStorage})
})
}
// withAdminTokenStore opens only the management store for admin token commands.
func withAdminTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
return withAdminStore(cmd, func(ctx context.Context, managementStore store.Store, _ *CombinedConfig) error {
return fn(ctx, managementStore)
})
}
func withAdminStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store, cfg *CombinedConfig) error) error {
if err := util.InitLog("error", "console"); err != nil {
return fmt.Errorf("init log: %w", err)
}
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
cfg, err := LoadConfig(configPath)
if err != nil {
return fmt.Errorf("load config: %w", err)
}
if dsn := cfg.Server.Store.DSN; dsn != "" {
switch strings.ToLower(cfg.Server.Store.Engine) {
case "postgres":
os.Setenv("NB_STORE_ENGINE_POSTGRES_DSN", dsn)
case "mysql":
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
}
}
if file := cfg.Server.Store.File; file != "" {
os.Setenv("NB_STORE_ENGINE_SQLITE_FILE", file)
}
managementStore, err := store.NewStore(ctx, types.Engine(cfg.Management.Store.Engine), cfg.Management.DataDir, nil, true)
if err != nil {
return fmt.Errorf("create store: %w", err)
}
defer func() {
if err := managementStore.Close(ctx); err != nil {
log.Debugf("close store: %v", err)
}
}()
return fn(ctx, managementStore, cfg)
}

View File

@@ -64,7 +64,7 @@ func init() {
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "path to YAML configuration file (required)")
_ = rootCmd.MarkPersistentFlagRequired("config")
rootCmd.AddCommand(newAdminCommands())
rootCmd.AddCommand(newTokenCommands())
}
func RootCmd() *cobra.Command {

63
combined/cmd/token.go Normal file
View File

@@ -0,0 +1,63 @@
package cmd
import (
"context"
"fmt"
"os"
"strings"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/formatter/hook"
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
)
// newTokenCommands creates the token command tree with combined-specific store opener.
func newTokenCommands() *cobra.Command {
return tokencmd.NewCommands(withTokenStore)
}
// withTokenStore loads the combined YAML config, initializes the store, and calls fn.
func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
if err := util.InitLog("error", "console"); err != nil {
return fmt.Errorf("init log: %w", err)
}
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
cfg, err := LoadConfig(configPath)
if err != nil {
return fmt.Errorf("load config: %w", err)
}
if dsn := cfg.Server.Store.DSN; dsn != "" {
switch strings.ToLower(cfg.Server.Store.Engine) {
case "postgres":
os.Setenv("NB_STORE_ENGINE_POSTGRES_DSN", dsn)
case "mysql":
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
}
}
if file := cfg.Server.Store.File; file != "" {
os.Setenv("NB_STORE_ENGINE_SQLITE_FILE", file)
}
datadir := cfg.Management.DataDir
engine := types.Engine(cfg.Management.Store.Engine)
s, err := store.NewStore(ctx, engine, datadir, nil, true)
if err != nil {
return fmt.Errorf("create store: %w", err)
}
defer func() {
if err := s.Close(ctx); err != nil {
log.Debugf("close store: %v", err)
}
}()
return fn(ctx, s)
}

View File

@@ -1,89 +0,0 @@
package cmd
import (
"context"
"fmt"
"path/filepath"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/formatter/hook"
admincmd "github.com/netbirdio/netbird/management/cmd/admin"
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/util"
)
var adminDatadir string
// newAdminCommands creates the admin command tree with management-specific resource openers.
func newAdminCommands() *cobra.Command {
cmd := admincmd.NewCommands(withAdminResources)
cmd.PersistentFlags().StringVar(&adminDatadir, "datadir", "", "Override the data directory from config (used for store.db and the default idp.db)")
cmd.AddCommand(tokencmd.NewCommands(withAdminTokenStore))
return cmd
}
// withAdminResources initializes logging, loads config, opens the management store
// and embedded IdP storage, and calls fn.
func withAdminResources(cmd *cobra.Command, fn func(ctx context.Context, resources admincmd.Resources) error) error {
return withAdminStore(cmd, func(ctx context.Context, managementStore store.Store, config *nbconfig.Config) error {
idpStorage, err := admincmd.OpenEmbeddedIDPStorage(config.EmbeddedIdP)
if err != nil {
return err
}
defer func() {
if err := idpStorage.Close(); err != nil {
log.Debugf("close embedded IdP storage: %v", err)
}
}()
return fn(ctx, admincmd.Resources{Store: managementStore, IDPStorage: idpStorage})
})
}
// withAdminTokenStore opens only the management store for admin token commands.
func withAdminTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
return withAdminStore(cmd, func(ctx context.Context, managementStore store.Store, _ *nbconfig.Config) error {
return fn(ctx, managementStore)
})
}
func withAdminStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store, config *nbconfig.Config) error) error {
if err := util.InitLog("error", "console"); err != nil {
return fmt.Errorf("init log: %w", err)
}
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
config, err := LoadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
if err != nil {
return fmt.Errorf("load config: %w", err)
}
datadir := config.Datadir
if adminDatadir != "" {
oldDatadir := datadir
datadir = adminDatadir
if config.EmbeddedIdP != nil && config.EmbeddedIdP.Storage.Type == "sqlite3" {
defaultIDPFile := filepath.Join(oldDatadir, "idp.db")
if config.EmbeddedIdP.Storage.Config.File == "" || config.EmbeddedIdP.Storage.Config.File == defaultIDPFile {
config.EmbeddedIdP.Storage.Config.File = filepath.Join(datadir, "idp.db")
}
}
}
managementStore, err := store.NewStore(ctx, config.StoreConfig.Engine, datadir, nil, true)
if err != nil {
return fmt.Errorf("create store: %w", err)
}
defer func() {
if err := managementStore.Close(ctx); err != nil {
log.Debugf("close store: %v", err)
}
}()
return fn(ctx, managementStore, config)
}

View File

@@ -1,441 +0,0 @@
// Package admincmd provides reusable cobra commands for self-hosted administrator helpers.
// Both the management and combined binaries use these commands, each providing
// their own opener to handle config loading and storage initialization.
package admincmd
import (
"context"
"errors"
"fmt"
"io"
"log/slog"
"os"
"strings"
"github.com/dexidp/dex/storage"
"github.com/spf13/cobra"
"golang.org/x/crypto/bcrypt"
nbdex "github.com/netbirdio/netbird/idp/dex"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
const (
localConnectorID = "local"
dashboardClientID = "netbird-dashboard"
cliClientID = "netbird-cli"
defaultTOTPAuthenticatorID = "default-totp"
)
// Resources contains the storages required by the admin commands.
type Resources struct {
Store store.Store
IDPStorage storage.Storage
}
// Opener initializes command resources from the command context and calls fn.
type Opener func(cmd *cobra.Command, fn func(ctx context.Context, resources Resources) error) error
type userSelector struct {
email string
userID string
}
func (s userSelector) normalized() userSelector {
return userSelector{
email: strings.TrimSpace(s.email),
userID: strings.TrimSpace(s.userID),
}
}
func (s userSelector) validate() error {
s = s.normalized()
if (s.email == "") == (s.userID == "") {
return fmt.Errorf("provide exactly one of --email or --user-id")
}
return nil
}
// NewCommands creates the admin command tree with the given resource opener.
func NewCommands(opener Opener) *cobra.Command {
adminCmd := &cobra.Command{
Use: "admin",
Short: "Self-hosted administrator helpers",
Long: "Administrative helpers for self-hosted deployments using the embedded identity provider.",
}
userCmd := &cobra.Command{
Use: "user",
Short: "Manage local embedded IdP users",
}
var passwordSelector userSelector
var password string
var passwordFile string
passwordCmd := &cobra.Command{
Use: "change-password (--email email | --user-id id) (--password password | --password-file path)",
Aliases: []string{"set-password"},
Short: "Change a local user's password",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, _ []string) error {
newPassword, err := resolvePasswordInput(cmd, password, passwordFile)
if err != nil {
return err
}
return opener(cmd, func(ctx context.Context, resources Resources) error {
return runChangePassword(ctx, resources.IDPStorage, cmd.OutOrStdout(), passwordSelector, newPassword)
})
},
}
addUserSelectorFlags(passwordCmd, &passwordSelector)
passwordCmd.Flags().StringVar(&password, "password", "", "New password for the user")
passwordCmd.Flags().StringVar(&passwordFile, "password-file", "", "Read new password from file ('-' for stdin)")
var resetSelector userSelector
resetMFACmd := &cobra.Command{
Use: "reset-mfa (--email email | --user-id id)",
Short: "Reset a local user's MFA enrollment",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, _ []string) error {
return opener(cmd, func(ctx context.Context, resources Resources) error {
return runResetMFA(ctx, resources.IDPStorage, cmd.OutOrStdout(), resetSelector)
})
},
}
addUserSelectorFlags(resetMFACmd, &resetSelector)
userCmd.AddCommand(passwordCmd, resetMFACmd)
mfaCmd := &cobra.Command{
Use: "mfa",
Short: "Manage local MFA for embedded IdP users",
}
enableCmd := &cobra.Command{
Use: "enable",
Short: "Enable MFA for local embedded IdP users",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, _ []string) error {
return opener(cmd, func(ctx context.Context, resources Resources) error {
return runSetMFAEnabled(ctx, resources, cmd.OutOrStdout(), true)
})
},
}
disableCmd := &cobra.Command{
Use: "disable",
Short: "Disable MFA for local embedded IdP users",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, _ []string) error {
return opener(cmd, func(ctx context.Context, resources Resources) error {
return runSetMFAEnabled(ctx, resources, cmd.OutOrStdout(), false)
})
},
}
statusCmd := &cobra.Command{
Use: "status",
Short: "Show local MFA status",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, _ []string) error {
return opener(cmd, func(ctx context.Context, resources Resources) error {
return runMFAStatus(ctx, resources, cmd.OutOrStdout())
})
},
}
mfaCmd.AddCommand(enableCmd, disableCmd, statusCmd)
adminCmd.AddCommand(userCmd, mfaCmd)
return adminCmd
}
// OpenEmbeddedIDPStorage opens the Dex storage configured for the embedded IdP.
func OpenEmbeddedIDPStorage(cfg *idp.EmbeddedIdPConfig) (storage.Storage, error) {
if cfg == nil || !cfg.Enabled {
return nil, fmt.Errorf("admin commands require the embedded IdP to be enabled")
}
yamlConfig, err := cfg.ToYAMLConfig()
if err != nil {
return nil, fmt.Errorf("build embedded IdP config: %w", err)
}
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
st, err := yamlConfig.Storage.OpenStorage(logger)
if err != nil {
return nil, fmt.Errorf("open embedded IdP storage: %w", err)
}
return st, nil
}
func addUserSelectorFlags(cmd *cobra.Command, selector *userSelector) {
cmd.Flags().StringVar(&selector.email, "email", "", "User email")
cmd.Flags().StringVar(&selector.userID, "user-id", "", "User ID")
}
func resolvePasswordInput(cmd *cobra.Command, password, passwordFile string) (string, error) {
if password != "" && passwordFile != "" {
return "", fmt.Errorf("provide only one of --password or --password-file")
}
if passwordFile == "" {
return password, nil
}
var data []byte
var err error
if passwordFile == "-" {
data, err = io.ReadAll(cmd.InOrStdin())
} else {
data, err = os.ReadFile(passwordFile)
}
if err != nil {
return "", fmt.Errorf("read password: %w", err)
}
return strings.TrimRight(string(data), "\r\n"), nil
}
func runChangePassword(ctx context.Context, idpStorage storage.Storage, w io.Writer, selector userSelector, password string) error {
if idpStorage == nil {
return fmt.Errorf("embedded IdP storage is required")
}
selector = selector.normalized()
if err := selector.validate(); err != nil {
return err
}
if password == "" {
return fmt.Errorf("password is required")
}
if err := server.ValidatePassword(password); err != nil {
return fmt.Errorf("invalid password: %w", err)
}
user, err := findLocalUser(ctx, idpStorage, selector)
if err != nil {
return err
}
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("hash password: %w", err)
}
if err := idpStorage.UpdatePassword(ctx, user.Email, func(old storage.Password) (storage.Password, error) {
old.Hash = hash
return old, nil
}); err != nil {
return fmt.Errorf("update password for %s: %w", user.Email, err)
}
if err := deleteLocalAuthSession(ctx, idpStorage, user.UserID); err != nil {
return err
}
_, _ = fmt.Fprintf(w, "Password updated for %s.\n", user.Email)
return nil
}
func runResetMFA(ctx context.Context, idpStorage storage.Storage, w io.Writer, selector userSelector) error {
if idpStorage == nil {
return fmt.Errorf("embedded IdP storage is required")
}
selector = selector.normalized()
if err := selector.validate(); err != nil {
return err
}
user, err := findLocalUser(ctx, idpStorage, selector)
if err != nil {
return err
}
reset := false
err = idpStorage.UpdateUserIdentity(ctx, user.UserID, localConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) {
reset = reset || len(old.MFASecrets) > 0 || len(old.WebAuthnCredentials) > 0
old.MFASecrets = map[string]*storage.MFASecret{}
old.WebAuthnCredentials = map[string][]storage.WebAuthnCredential{}
return old, nil
})
if errors.Is(err, storage.ErrNotFound) {
if err := deleteLocalAuthSession(ctx, idpStorage, user.UserID); err != nil {
return err
}
_, _ = fmt.Fprintf(w, "No MFA enrollment found for %s.\n", user.Email)
return nil
}
if err != nil {
return fmt.Errorf("reset MFA for %s: %w", user.Email, err)
}
if err := deleteLocalAuthSession(ctx, idpStorage, user.UserID); err != nil {
return err
}
if reset {
_, _ = fmt.Fprintf(w, "MFA reset for %s. The user will re-enroll at next login.\n", user.Email)
} else {
_, _ = fmt.Fprintf(w, "No MFA enrollment found for %s.\n", user.Email)
}
return nil
}
func runSetMFAEnabled(ctx context.Context, resources Resources, w io.Writer, enabled bool) error {
if resources.Store == nil {
return fmt.Errorf("management store is required")
}
if resources.IDPStorage == nil {
return fmt.Errorf("embedded IdP storage is required")
}
accounts := resources.Store.GetAllAccounts(ctx)
if len(accounts) != 1 {
return fmt.Errorf("expected exactly one account, got %d; local MFA is supported only in single-account embedded IdP deployments", len(accounts))
}
settings := &types.Settings{}
if accounts[0].Settings != nil {
settings = accounts[0].Settings.Copy()
}
settings.LocalMfaEnabled = enabled
if err := resources.Store.SaveAccountSettings(ctx, accounts[0].Id, settings); err != nil {
return fmt.Errorf("save local MFA account setting: %w", err)
}
if err := setIDPClientsMFA(ctx, resources.IDPStorage, enabled); err != nil {
return err
}
state := "disabled"
if enabled {
state = "enabled"
}
_, _ = fmt.Fprintf(w, "Local MFA %s.\n", state)
return nil
}
func runMFAStatus(ctx context.Context, resources Resources, w io.Writer) error {
if resources.Store == nil {
return fmt.Errorf("management store is required")
}
if resources.IDPStorage == nil {
return fmt.Errorf("embedded IdP storage is required")
}
accounts := resources.Store.GetAllAccounts(ctx)
accountStatus := "unknown"
if len(accounts) == 1 && accounts[0].Settings != nil {
accountStatus = "disabled"
if accounts[0].Settings.LocalMfaEnabled {
accountStatus = "enabled"
}
}
clientStatus, err := idpClientsMFAStatus(ctx, resources.IDPStorage)
if err != nil {
return err
}
_, _ = fmt.Fprintf(w, "Account setting: %s\n", accountStatus)
_, _ = fmt.Fprintf(w, "Embedded IdP clients: %s\n", clientStatus)
return nil
}
func findLocalUser(ctx context.Context, idpStorage storage.Storage, selector userSelector) (storage.Password, error) {
selector = selector.normalized()
if err := selector.validate(); err != nil {
return storage.Password{}, err
}
if selector.email != "" {
user, err := idpStorage.GetPassword(ctx, selector.email)
if errors.Is(err, storage.ErrNotFound) {
return storage.Password{}, fmt.Errorf("local user with email %q not found", selector.email)
}
if err != nil {
return storage.Password{}, fmt.Errorf("get local user by email %q: %w", selector.email, err)
}
return user, nil
}
rawUserID := selector.userID
if decodedUserID, _, err := nbdex.DecodeDexUserID(selector.userID); err == nil && decodedUserID != "" {
rawUserID = decodedUserID
}
users, err := idpStorage.ListPasswords(ctx)
if err != nil {
return storage.Password{}, fmt.Errorf("list local users: %w", err)
}
for _, user := range users {
if user.UserID == rawUserID || user.UserID == selector.userID {
return user, nil
}
}
return storage.Password{}, fmt.Errorf("local user with ID %q not found", selector.userID)
}
func deleteLocalAuthSession(ctx context.Context, idpStorage storage.Storage, userID string) error {
err := idpStorage.DeleteAuthSession(ctx, userID, localConnectorID)
if err == nil || errors.Is(err, storage.ErrNotFound) {
return nil
}
return fmt.Errorf("delete local auth session for user %s: %w", userID, err)
}
func setIDPClientsMFA(ctx context.Context, idpStorage storage.Storage, enabled bool) error {
var mfaChain []string
if enabled {
mfaChain = []string{defaultTOTPAuthenticatorID}
}
for _, clientID := range []string{cliClientID, dashboardClientID} {
if err := idpStorage.UpdateClient(ctx, clientID, func(old storage.Client) (storage.Client, error) {
old.MFAChain = mfaChain
return old, nil
}); err != nil {
if errors.Is(err, storage.ErrNotFound) {
return fmt.Errorf("embedded IdP client %q not found; start the management server once before toggling MFA", clientID)
}
return fmt.Errorf("update MFA chain on embedded IdP client %q: %w", clientID, err)
}
}
return nil
}
func idpClientsMFAStatus(ctx context.Context, idpStorage storage.Storage) (string, error) {
clientIDs := []string{cliClientID, dashboardClientID}
enabledCount := 0
for _, clientID := range clientIDs {
client, err := idpStorage.GetClient(ctx, clientID)
if errors.Is(err, storage.ErrNotFound) {
return "unknown", fmt.Errorf("embedded IdP client %q not found", clientID)
}
if err != nil {
return "unknown", fmt.Errorf("get embedded IdP client %q: %w", clientID, err)
}
if hasAuthenticator(client.MFAChain, defaultTOTPAuthenticatorID) {
enabledCount++
}
}
switch enabledCount {
case 0:
return "disabled", nil
case len(clientIDs):
return "enabled", nil
default:
return "partially enabled", nil
}
}
func hasAuthenticator(chain []string, authenticatorID string) bool {
for _, id := range chain {
if id == authenticatorID {
return true
}
}
return false
}

View File

@@ -1,160 +0,0 @@
package admincmd
import (
"bytes"
"context"
"io"
"log/slog"
"strings"
"testing"
"time"
"github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/memory"
"github.com/spf13/cobra"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/bcrypt"
nbdex "github.com/netbirdio/netbird/idp/dex"
)
func newTestIDPStorage(t *testing.T) storage.Storage {
t.Helper()
st := memory.New(slog.New(slog.NewTextHandler(io.Discard, nil)))
hash, err := bcrypt.GenerateFromPassword([]byte("OldPass1!"), bcrypt.DefaultCost)
require.NoError(t, err)
require.NoError(t, st.CreatePassword(context.Background(), storage.Password{
Email: "user@example.com",
Username: "User",
UserID: "user-1",
Hash: hash,
}))
require.NoError(t, st.CreateUserIdentity(context.Background(), storage.UserIdentity{
UserID: "user-1",
ConnectorID: localConnectorID,
MFASecrets: map[string]*storage.MFASecret{
defaultTOTPAuthenticatorID: {
AuthenticatorID: defaultTOTPAuthenticatorID,
Type: "TOTP",
Secret: "otpauth://totp/NetBird:user@example.com?secret=ABC",
Confirmed: true,
CreatedAt: time.Now(),
},
},
WebAuthnCredentials: map[string][]storage.WebAuthnCredential{
"webauthn": {{CredentialID: []byte("credential")}},
},
}))
require.NoError(t, st.CreateAuthSession(context.Background(), storage.AuthSession{
UserID: "user-1",
ConnectorID: localConnectorID,
Nonce: "nonce",
}))
require.NoError(t, st.CreateClient(context.Background(), storage.Client{ID: cliClientID, Name: "CLI"}))
require.NoError(t, st.CreateClient(context.Background(), storage.Client{ID: dashboardClientID, Name: "Dashboard"}))
return st
}
func TestRunChangePassword(t *testing.T) {
ctx := context.Background()
st := newTestIDPStorage(t)
var out bytes.Buffer
err := runChangePassword(ctx, st, &out, userSelector{email: "user@example.com"}, "NewPass1!")
require.NoError(t, err)
require.Contains(t, out.String(), "Password updated")
user, err := st.GetPassword(ctx, "user@example.com")
require.NoError(t, err)
require.NoError(t, bcrypt.CompareHashAndPassword(user.Hash, []byte("NewPass1!")))
_, err = st.GetAuthSession(ctx, "user-1", localConnectorID)
require.ErrorIs(t, err, storage.ErrNotFound)
}
func TestRunChangePasswordValidatesPassword(t *testing.T) {
st := newTestIDPStorage(t)
err := runChangePassword(context.Background(), st, io.Discard, userSelector{email: "user@example.com"}, "short")
require.Error(t, err)
require.Contains(t, err.Error(), "invalid password")
}
func TestRunResetMFA(t *testing.T) {
ctx := context.Background()
st := newTestIDPStorage(t)
var out bytes.Buffer
encodedUserID := nbdex.EncodeDexUserID("user-1", localConnectorID)
err := runResetMFA(ctx, st, &out, userSelector{userID: encodedUserID})
require.NoError(t, err)
require.Contains(t, out.String(), "MFA reset")
identity, err := st.GetUserIdentity(ctx, "user-1", localConnectorID)
require.NoError(t, err)
require.Empty(t, identity.MFASecrets)
require.Empty(t, identity.WebAuthnCredentials)
_, err = st.GetAuthSession(ctx, "user-1", localConnectorID)
require.ErrorIs(t, err, storage.ErrNotFound)
}
func TestRunResetMFAWithoutEnrollment(t *testing.T) {
ctx := context.Background()
st := newTestIDPStorage(t)
require.NoError(t, st.UpdateUserIdentity(ctx, "user-1", localConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) {
old.MFASecrets = nil
old.WebAuthnCredentials = nil
return old, nil
}))
var out bytes.Buffer
err := runResetMFA(ctx, st, &out, userSelector{email: "user@example.com"})
require.NoError(t, err)
require.Contains(t, out.String(), "No MFA enrollment found")
}
func TestSetIDPClientsMFA(t *testing.T) {
ctx := context.Background()
st := newTestIDPStorage(t)
require.NoError(t, setIDPClientsMFA(ctx, st, true))
status, err := idpClientsMFAStatus(ctx, st)
require.NoError(t, err)
require.Equal(t, "enabled", status)
require.NoError(t, setIDPClientsMFA(ctx, st, false))
status, err = idpClientsMFAStatus(ctx, st)
require.NoError(t, err)
require.Equal(t, "disabled", status)
}
func TestUserSelectorValidate(t *testing.T) {
require.NoError(t, userSelector{email: " user@example.com "}.validate())
require.NoError(t, userSelector{userID: "user-1"}.validate())
require.Error(t, userSelector{}.validate())
require.Error(t, userSelector{email: "user@example.com", userID: "user-1"}.validate())
}
func TestFindLocalUserNotFound(t *testing.T) {
st := newTestIDPStorage(t)
_, err := findLocalUser(context.Background(), st, userSelector{email: "missing@example.com"})
require.Error(t, err)
require.True(t, strings.Contains(err.Error(), "not found"))
}
func TestResolvePasswordInputFromStdin(t *testing.T) {
cmd := &cobra.Command{}
cmd.SetIn(strings.NewReader("NewPass1!\n"))
password, err := resolvePasswordInput(cmd, "", "-")
require.NoError(t, err)
require.Equal(t, "NewPass1!", password)
}
func TestResolvePasswordInputRejectsMultipleSources(t *testing.T) {
_, err := resolvePasswordInput(&cobra.Command{}, "NewPass1!", "-")
require.Error(t, err)
}

View File

@@ -83,7 +83,7 @@ func init() {
rootCmd.AddCommand(migrationCmd)
ac := newAdminCommands()
ac.PersistentFlags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location")
rootCmd.AddCommand(ac)
tc := newTokenCommands()
tc.PersistentFlags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location")
rootCmd.AddCommand(tc)
}

55
management/cmd/token.go Normal file
View File

@@ -0,0 +1,55 @@
package cmd
import (
"context"
"fmt"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/formatter/hook"
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/util"
)
var tokenDatadir string
// newTokenCommands creates the token command tree with management-specific store opener.
func newTokenCommands() *cobra.Command {
cmd := tokencmd.NewCommands(withTokenStore)
cmd.PersistentFlags().StringVar(&tokenDatadir, "datadir", "", "Override the data directory from config (where store.db is located)")
return cmd
}
// withTokenStore initializes logging, loads config, opens the store, and calls fn.
func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
if err := util.InitLog("error", "console"); err != nil {
return fmt.Errorf("init log: %w", err)
}
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
config, err := LoadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
if err != nil {
return fmt.Errorf("load config: %w", err)
}
datadir := config.Datadir
if tokenDatadir != "" {
datadir = tokenDatadir
}
s, err := store.NewStore(ctx, config.StoreConfig.Engine, datadir, nil, true)
if err != nil {
return fmt.Errorf("create store: %w", err)
}
defer func() {
if err := s.Close(ctx); err != nil {
log.Debugf("close store: %v", err)
}
}()
return fn(ctx, s)
}

View File

@@ -1847,17 +1847,12 @@ func (am *DefaultAccountManager) DeleteUserInvite(ctx context.Context, accountID
const minPasswordLength = 8
// validatePassword checks password strength requirements.
func validatePassword(password string) error {
return ValidatePassword(password)
}
// ValidatePassword checks password strength requirements:
// validatePassword checks password strength requirements:
// - Minimum 8 characters
// - At least 1 digit
// - At least 1 uppercase letter
// - At least 1 special character
func ValidatePassword(password string) error {
func validatePassword(password string) error {
if len(password) < minPasswordLength {
return errors.New("password must be at least 8 characters long")
}