mirror of
https://github.com/netbirdio/netbird.git
synced 2026-07-04 21:59:55 +00:00
Compare commits
9 Commits
netmap_pro
...
feat/admin
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f25011f9ca | ||
|
|
1d5d8ef1b7 | ||
|
|
1428970a24 | ||
|
|
3aa6c02b93 | ||
|
|
f6900fb07c | ||
|
|
fe3c14413c | ||
|
|
520370a8b0 | ||
|
|
b5a16a1898 | ||
|
|
449b5cbb80 |
@@ -220,12 +220,6 @@ type Engine struct {
|
||||
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
|
||||
networkSerial uint64
|
||||
|
||||
// forwardingRules holds the ingress forward rules applied for the current target.
|
||||
// Wholesale sections (incl. forward rules) run only on the first pass of a target;
|
||||
// it is stashed here so the final, peer-converged pass can build the lazy-connection
|
||||
// exclude list without recomputing them on every bounded peer pass.
|
||||
forwardingRules []firewallManager.ForwardRule
|
||||
|
||||
networkMonitor *networkmonitor.NetworkMonitor
|
||||
|
||||
sshServer sshServer
|
||||
@@ -780,15 +774,7 @@ func (e *Engine) blockLanAccess() {
|
||||
|
||||
// modifyPeers updates peers that have been modified (e.g. IP address has been changed).
|
||||
// It closes the existing connection, removes it from the peerConns map, and creates a new one.
|
||||
// maxPeersPerSyncPass is the default per-pass cap on how many peers each of
|
||||
// removePeers/modifyPeers/addNewPeers applies, so syncMsgMux is held only for a
|
||||
// batch at a time and other subsystems can interleave between passes. It is
|
||||
// passed in (not read globally) so tests can exercise the multi-pass path.
|
||||
const maxPeersPerSyncPass = 300
|
||||
|
||||
// modifyPeers re-applies up to maxBatch changed peers per call. It returns true
|
||||
// when more changed peers remained than the cap, so the caller re-runs.
|
||||
func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig, maxBatch int) (bool, error) {
|
||||
func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
||||
|
||||
// first, check if peers have been modified
|
||||
var modified []*mgmProto.RemotePeerConfig
|
||||
@@ -818,32 +804,26 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig, maxBatch
|
||||
}
|
||||
}
|
||||
|
||||
more := false
|
||||
if len(modified) > maxBatch {
|
||||
modified = modified[:maxBatch]
|
||||
more = true
|
||||
}
|
||||
|
||||
// second, close all modified connections and remove them from the state map
|
||||
for _, p := range modified {
|
||||
if err := e.removePeer(p.GetWgPubKey()); err != nil {
|
||||
return false, err
|
||||
err := e.removePeer(p.GetWgPubKey())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// third, add the peer connections again
|
||||
for _, p := range modified {
|
||||
if err := e.addNewPeer(p); err != nil {
|
||||
return false, err
|
||||
err := e.addNewPeer(p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return more, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// removePeers finds and removes peers that do not exist anymore in the network map received from the Management Service.
|
||||
// It also removes peers that have been modified (e.g. change of IP address). They will be added again in addPeers method.
|
||||
// removePeers removes up to maxBatch peers per call. It returns true when more
|
||||
// peers remained to remove than the cap, so the caller re-runs.
|
||||
func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig, maxBatch int) (bool, error) {
|
||||
func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
||||
newPeers := make([]string, 0, len(peersUpdate))
|
||||
for _, p := range peersUpdate {
|
||||
newPeers = append(newPeers, p.GetWgPubKey())
|
||||
@@ -851,19 +831,14 @@ func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig, maxBatch
|
||||
|
||||
toRemove := util.SliceDiff(e.peerStore.PeersPubKey(), newPeers)
|
||||
|
||||
more := false
|
||||
if len(toRemove) > maxBatch {
|
||||
toRemove = toRemove[:maxBatch]
|
||||
more = true
|
||||
}
|
||||
|
||||
for _, p := range toRemove {
|
||||
if err := e.removePeer(p); err != nil {
|
||||
return false, err
|
||||
err := e.removePeer(p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Infof("removed peer %s", p)
|
||||
}
|
||||
return more, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) removeAllPeers() error {
|
||||
@@ -942,17 +917,19 @@ func (e *Engine) phase(name string) func() {
|
||||
}
|
||||
}
|
||||
|
||||
// applySyncPass applies one bounded pass of the sync update under syncMsgMux and
|
||||
// returns true if more peers remained than the per-pass cap. It is driven by the
|
||||
// mapStateManager, which re-invokes it (releasing the lock between passes) until
|
||||
// the update is fully applied.
|
||||
func (e *Engine) applySyncPass(update *mgmProto.SyncResponse, firstPass bool) (bool, error) {
|
||||
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
started := time.Now()
|
||||
defer func() {
|
||||
duration := time.Since(started)
|
||||
log.Infof("sync finished in %s", duration)
|
||||
e.clientMetrics.RecordSyncDuration(e.ctx, duration)
|
||||
}()
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
// Check context INSIDE lock to ensure atomicity with shutdown
|
||||
if e.ctx.Err() != nil {
|
||||
return false, e.ctx.Err()
|
||||
return e.ctx.Err()
|
||||
}
|
||||
|
||||
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
|
||||
@@ -963,7 +940,7 @@ func (e *Engine) applySyncPass(update *mgmProto.SyncResponse, firstPass bool) (b
|
||||
err := e.updateNetbirdConfig(update.GetNetbirdConfig())
|
||||
done()
|
||||
if err != nil {
|
||||
return false, err
|
||||
return err
|
||||
}
|
||||
|
||||
// Posture checks are bound to the network map presence:
|
||||
@@ -973,25 +950,28 @@ func (e *Engine) applySyncPass(update *mgmProto.SyncResponse, firstPass bool) (b
|
||||
// leave the previously applied checks untouched
|
||||
nm := update.GetNetworkMap()
|
||||
if nm == nil {
|
||||
return false, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
done = e.phase("checks")
|
||||
err = e.updateChecksIfNew(update.Checks)
|
||||
done()
|
||||
if err != nil {
|
||||
return false, err
|
||||
return err
|
||||
}
|
||||
|
||||
done = e.phase("persist")
|
||||
e.persistSyncResponse(update)
|
||||
done()
|
||||
|
||||
// only apply new changes and ignore old ones
|
||||
more, err := e.updateNetworkMap(nm, maxPeersPerSyncPass, firstPass)
|
||||
if err != nil {
|
||||
return false, err
|
||||
if err := e.updateNetworkMap(nm); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
e.statusRecorder.PublishEvent(cProto.SystemEvent_INFO, cProto.SystemEvent_SYSTEM, "Network map updated", "", nil)
|
||||
|
||||
return more, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateNetbirdConfig applies the management-provided NetBird configuration:
|
||||
@@ -1039,13 +1019,6 @@ func (e *Engine) updateNetbirdConfig(wCfg *mgmProto.NetbirdConfig) error {
|
||||
// (not syncMsgMux) is held for the whole Set so the store cannot be cleared (disabled /
|
||||
// engine close) mid-call and have this write resurrect a file that was just removed.
|
||||
func (e *Engine) persistSyncResponse(update *mgmProto.SyncResponse) {
|
||||
// Only persist updates that carry a network map. Config-only updates (e.g. relay
|
||||
// token rotation, STUN/TURN) have a nil NetworkMap; persisting them would overwrite
|
||||
// the last full map on disk and break restore-on-restart.
|
||||
if update.GetNetworkMap() == nil {
|
||||
return
|
||||
}
|
||||
|
||||
e.syncRespMux.RLock()
|
||||
defer e.syncRespMux.RUnlock()
|
||||
|
||||
@@ -1333,24 +1306,7 @@ func (e *Engine) receiveManagementEvents() {
|
||||
}
|
||||
e.applyInfoFlags(info)
|
||||
|
||||
// The map-state manager converges the latest update in the background in
|
||||
// bounded passes; the stream callback only hands it the newest target.
|
||||
persist := func(u *mgmProto.SyncResponse) {
|
||||
done := e.phase("persist")
|
||||
e.persistSyncResponse(u)
|
||||
done()
|
||||
}
|
||||
manager := newMapStateManager(e.applySyncPass, persist, func(d time.Duration) {
|
||||
log.Infof("sync finished in %s", d)
|
||||
e.clientMetrics.RecordSyncDuration(e.ctx, d)
|
||||
})
|
||||
e.shutdownWg.Add(1)
|
||||
go func() {
|
||||
defer e.shutdownWg.Done()
|
||||
manager.run(e.ctx)
|
||||
}()
|
||||
|
||||
err := e.mgmClient.Sync(e.ctx, info, manager.SetTarget)
|
||||
err := e.mgmClient.Sync(e.ctx, info, e.handleSync)
|
||||
if err != nil {
|
||||
// happens if management is unavailable for a long time.
|
||||
// We want to cancel the operation of the whole client
|
||||
@@ -1401,107 +1357,21 @@ func (e *Engine) updateTURNs(turns []*mgmProto.ProtectedHostConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateNetworkMap applies the wholesale parts (config, routes, ACL, DNS) in full
|
||||
// and up to maxBatch peers per phase. It returns true when more peers remained
|
||||
// than the cap, so the caller re-runs until convergence.
|
||||
func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap, maxBatch int, firstPass bool) (bool, error) {
|
||||
func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
// intentionally leave it before checking serial because for now it can happen that peer IP changed but serial didn't
|
||||
if networkMap.GetPeerConfig() != nil {
|
||||
err := e.updateConfig(networkMap.GetPeerConfig())
|
||||
if err != nil {
|
||||
return false, err
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
serial := networkMap.GetSerial()
|
||||
if e.networkSerial > serial {
|
||||
log.Debugf("received outdated NetworkMap with serial %d, ignoring", serial)
|
||||
return false, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// Wholesale sections (firewall/ACL, DNS, routes, forward rules) are applied
|
||||
// up-front and only once per target: they are cheap, local, idempotent and must
|
||||
// be in place before peers come up (fail-closed). On the bounded re-runs that only
|
||||
// drain the remaining peer batches they are skipped — the applied forward rules are
|
||||
// reused from e.forwardingRules for the lazy-exclude finalize.
|
||||
if firstPass {
|
||||
e.applyWholesale(networkMap, serial)
|
||||
}
|
||||
|
||||
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
|
||||
|
||||
doneOffline := e.phase("offline_peers")
|
||||
e.updateOfflinePeers(networkMap.GetOfflinePeers())
|
||||
doneOffline()
|
||||
|
||||
// Filter out own peer from the remote peers list
|
||||
localPubKey := e.config.WgPrivateKey.PublicKey().String()
|
||||
remotePeers := make([]*mgmProto.RemotePeerConfig, 0, len(networkMap.GetRemotePeers()))
|
||||
for _, p := range networkMap.GetRemotePeers() {
|
||||
if p.GetWgPubKey() != localPubKey {
|
||||
remotePeers = append(remotePeers, p)
|
||||
}
|
||||
}
|
||||
|
||||
// No special case for cleanup: when management signals RemotePeersIsEmpty (e.g. our
|
||||
// peer was deleted), remotePeers is already empty, so the bounded diff below removes
|
||||
// every peer in batches — same path as a normal update, no unbounded removeAllPeers
|
||||
// held under syncMsgMux in one shot.
|
||||
doneRemoved := e.phase("removed_peers")
|
||||
removeMore, err := e.removePeers(remotePeers, maxBatch)
|
||||
doneRemoved()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
doneModified := e.phase("modified_peers")
|
||||
modifyMore, err := e.modifyPeers(remotePeers, maxBatch)
|
||||
doneModified()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
doneAdded := e.phase("added_peers")
|
||||
addMore, err := e.addNewPeers(remotePeers, maxBatch)
|
||||
doneAdded()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// needMore signals the caller to re-run when a peer phase hit its per-pass cap.
|
||||
needMore := removeMore || modifyMore || addMore
|
||||
|
||||
e.statusRecorder.FinishPeerListModifications()
|
||||
|
||||
e.updatePeerSSHHostKeys(remotePeers)
|
||||
|
||||
if err := e.updateSSHClientConfig(remotePeers); err != nil {
|
||||
log.Warnf("failed to update SSH client config: %v", err)
|
||||
}
|
||||
|
||||
e.updateSSHServerAuth(networkMap.GetSshAuth())
|
||||
|
||||
// Set the exclude list only once peers have fully converged (this pass added
|
||||
// the last batch). It needs all target peers present in the store, and
|
||||
// ExcludePeer has replace-semantics — a partial set mid-convergence would be wrong.
|
||||
if !needMore {
|
||||
doneLazy := e.phase("lazy_exclude")
|
||||
excludedLazyPeers := e.toExcludedLazyPeers(e.forwardingRules, remotePeers)
|
||||
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
|
||||
doneLazy()
|
||||
}
|
||||
|
||||
e.networkSerial = serial
|
||||
|
||||
return needMore, nil
|
||||
}
|
||||
|
||||
// applyWholesale applies the cheap, local, idempotent map sections — lazy feature
|
||||
// flag, firewall/legacy management, DNS, routes, ACL filtering, DNS forwarder and
|
||||
// ingress forward rules — that must be in place before peers come up. It runs once
|
||||
// per target (first pass only); the resulting forward rules are stashed in
|
||||
// e.forwardingRules for the lazy-exclude finalize on the peer-converged pass.
|
||||
func (e *Engine) applyWholesale(networkMap *mgmProto.NetworkMap, serial uint64) {
|
||||
if err := e.connMgr.UpdatedRemoteFeatureFlag(e.ctx, networkMap.GetPeerConfig().GetLazyConnectionEnabled()); err != nil {
|
||||
log.Errorf("failed to update lazy connection feature flag: %v", err)
|
||||
}
|
||||
@@ -1574,7 +1444,84 @@ func (e *Engine) applyWholesale(networkMap *mgmProto.NetworkMap, serial uint64)
|
||||
log.Errorf("failed to update forward rules, err: %v", err)
|
||||
}
|
||||
done()
|
||||
e.forwardingRules = forwardingRules
|
||||
|
||||
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
|
||||
|
||||
done = e.phase("offline_peers")
|
||||
e.updateOfflinePeers(networkMap.GetOfflinePeers())
|
||||
done()
|
||||
|
||||
remotePeers, err := e.reconcilePeers(networkMap)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
|
||||
done = e.phase("lazy_exclude")
|
||||
excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, remotePeers)
|
||||
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
|
||||
done()
|
||||
|
||||
e.networkSerial = serial
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// reconcilePeers applies the remote peer list from the network map (removing,
|
||||
// modifying and adding peers, then updating SSH config) and returns the remote
|
||||
// peers with our own peer filtered out, for use by later sync steps.
|
||||
func (e *Engine) reconcilePeers(networkMap *mgmProto.NetworkMap) ([]*mgmProto.RemotePeerConfig, error) {
|
||||
// Filter out own peer from the remote peers list
|
||||
localPubKey := e.config.WgPrivateKey.PublicKey().String()
|
||||
remotePeers := make([]*mgmProto.RemotePeerConfig, 0, len(networkMap.GetRemotePeers()))
|
||||
for _, p := range networkMap.GetRemotePeers() {
|
||||
if p.GetWgPubKey() != localPubKey {
|
||||
remotePeers = append(remotePeers, p)
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup request, most likely our peer has been deleted
|
||||
if networkMap.GetRemotePeersIsEmpty() {
|
||||
err := e.removeAllPeers()
|
||||
e.statusRecorder.FinishPeerListModifications()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return remotePeers, nil
|
||||
}
|
||||
|
||||
done := e.phase("removed_peers")
|
||||
err := e.removePeers(remotePeers)
|
||||
done()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
done = e.phase("modified_peers")
|
||||
err = e.modifyPeers(remotePeers)
|
||||
done()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
done = e.phase("added_peers")
|
||||
err = e.addNewPeers(remotePeers)
|
||||
done()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
e.statusRecorder.FinishPeerListModifications()
|
||||
|
||||
e.updatePeerSSHHostKeys(remotePeers)
|
||||
|
||||
if err := e.updateSSHClientConfig(remotePeers); err != nil {
|
||||
log.Warnf("failed to update SSH client config: %v", err)
|
||||
}
|
||||
|
||||
e.updateSSHServerAuth(networkMap.GetSshAuth())
|
||||
|
||||
return remotePeers, nil
|
||||
}
|
||||
|
||||
func toDNSFeatureFlag(networkMap *mgmProto.NetworkMap) bool {
|
||||
@@ -1754,23 +1701,14 @@ func addrToString(addr netip.Addr) string {
|
||||
}
|
||||
|
||||
// addNewPeers adds peers that were not know before but arrived from the Management service with the update
|
||||
// addNewPeers adds up to maxBatch not-yet-present peers per call. It returns true
|
||||
// when more new peers remained than the cap, so the caller re-runs.
|
||||
func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig, maxBatch int) (bool, error) {
|
||||
added := 0
|
||||
func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
||||
for _, p := range peersUpdate {
|
||||
if _, ok := e.peerStore.PeerConn(p.GetWgPubKey()); ok {
|
||||
continue // already present (cheap skip), does not count toward the cap
|
||||
err := e.addNewPeer(p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if added >= maxBatch {
|
||||
return true, nil // at least one more new peer remains
|
||||
}
|
||||
if err := e.addNewPeer(p); err != nil {
|
||||
return false, err
|
||||
}
|
||||
added++
|
||||
}
|
||||
return false, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// addNewPeer add peer if connection doesn't exist
|
||||
|
||||
@@ -124,7 +124,7 @@ func TestEngine_SSH(t *testing.T) {
|
||||
RemotePeersIsEmpty: false,
|
||||
}
|
||||
|
||||
_, err = engine.updateNetworkMap(networkMap, maxPeersPerSyncPass, true)
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Nil(t, engine.sshServer)
|
||||
@@ -146,7 +146,7 @@ func TestEngine_SSH(t *testing.T) {
|
||||
RemotePeersIsEmpty: false,
|
||||
}
|
||||
|
||||
_, err = engine.updateNetworkMap(networkMap, maxPeersPerSyncPass, true)
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
@@ -159,7 +159,7 @@ func TestEngine_SSH(t *testing.T) {
|
||||
RemotePeersIsEmpty: false,
|
||||
}
|
||||
|
||||
_, err = engine.updateNetworkMap(networkMap, maxPeersPerSyncPass, true)
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
// time.Sleep(250 * time.Millisecond)
|
||||
@@ -174,7 +174,7 @@ func TestEngine_SSH(t *testing.T) {
|
||||
RemotePeersIsEmpty: false,
|
||||
}
|
||||
|
||||
_, err = engine.updateNetworkMap(networkMap, maxPeersPerSyncPass, true)
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Nil(t, engine.sshServer)
|
||||
|
||||
@@ -437,7 +437,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
|
||||
for _, c := range []testCase{case1, case2, case3, case4, case5, case6} {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
_, err = engine.updateNetworkMap(c.networkMap, maxPeersPerSyncPass, true)
|
||||
err = engine.updateNetworkMap(c.networkMap)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -464,47 +464,6 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// chunked apply: with a per-pass cap smaller than the number of peers, a
|
||||
// single updateNetworkMap applies one batch and reports more==true; the
|
||||
// caller re-runs until convergence. (engine currently holds 0 peers.)
|
||||
t.Run("chunked add converges over multiple passes", func(t *testing.T) {
|
||||
nm := &mgmtProto.NetworkMap{
|
||||
Serial: 6,
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peer1, peer2, peer3},
|
||||
}
|
||||
|
||||
more, err := engine.updateNetworkMap(nm, 1, true)
|
||||
require.NoError(t, err)
|
||||
require.True(t, more, "pass 1 should signal more")
|
||||
require.Len(t, engine.peerStore.PeersPubKey(), 1)
|
||||
|
||||
more, err = engine.updateNetworkMap(nm, 1, false)
|
||||
require.NoError(t, err)
|
||||
require.True(t, more, "pass 2 should signal more")
|
||||
require.Len(t, engine.peerStore.PeersPubKey(), 2)
|
||||
|
||||
more, err = engine.updateNetworkMap(nm, 1, false)
|
||||
require.NoError(t, err)
|
||||
require.False(t, more, "pass 3 should converge")
|
||||
require.Len(t, engine.peerStore.PeersPubKey(), 3)
|
||||
})
|
||||
|
||||
t.Run("chunked remove converges over multiple passes", func(t *testing.T) {
|
||||
nm := &mgmtProto.NetworkMap{
|
||||
Serial: 7,
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peer1}, // remove peer2, peer3
|
||||
}
|
||||
|
||||
more, err := engine.updateNetworkMap(nm, 1, true)
|
||||
require.NoError(t, err)
|
||||
require.True(t, more, "pass 1 should signal more (2 to remove, cap 1)")
|
||||
|
||||
more, err = engine.updateNetworkMap(nm, 1, false)
|
||||
require.NoError(t, err)
|
||||
require.False(t, more, "pass 2 should converge")
|
||||
require.Len(t, engine.peerStore.PeersPubKey(), 1)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
@@ -675,7 +634,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
_, err = engine.updateNetworkMap(testCase.networkMap, maxPeersPerSyncPass, true)
|
||||
err = engine.updateNetworkMap(testCase.networkMap)
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match")
|
||||
assert.Len(t, input.clientRoutes, testCase.expectedLen, "clientRoutes len should match")
|
||||
@@ -879,7 +838,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
_, err = engine.updateNetworkMap(testCase.networkMap, maxPeersPerSyncPass, true)
|
||||
err = engine.updateNetworkMap(testCase.networkMap)
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match")
|
||||
assert.Len(t, input.inputNSGroups, testCase.expectedZonesLen, "zones len should match")
|
||||
|
||||
@@ -1,214 +0,0 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// mapStateManager is the single read/write point between the management stream
|
||||
// (writes) and the convergence loop (reads/applies).
|
||||
//
|
||||
// The stream calls SetTarget with the latest full SyncResponse — the complete
|
||||
// desired state. A single background goroutine (run) applies it to the engine in
|
||||
// bounded passes via apply() until converged, releasing syncMsgMux between passes
|
||||
// so other subsystems interleave. If a newer update arrives mid-flight, the loop
|
||||
// coalesces: it keeps converging toward the latest target and the intermediate one
|
||||
// is SKIPPED — never applied on its own (logged, no onConverged).
|
||||
//
|
||||
// Convergence is a single comparison: appliedGen == targetGen. targetGen
|
||||
// increments on every SetTarget (an internal generation counter, so it also covers
|
||||
// config-only updates that carry no network-map serial).
|
||||
//
|
||||
// onConverged fires once for each — and only each — map that is actually processed
|
||||
// (i.e. converged as the target). Skipped/superseded maps and dropped-on-error maps
|
||||
// do NOT fire it. So "sync finished in X" / RecordSyncDuration always corresponds
|
||||
// to a real, completed alignment.
|
||||
type mapStateManager struct {
|
||||
// apply performs one bounded apply pass and reports whether more passes are needed.
|
||||
// firstPass is true on the first pass of a given target, so the caller can run
|
||||
// wholesale (firewall/routes/DNS/forward-rules) once per target and skip it on the
|
||||
// re-runs that only drain the bounded peer batches. The manager owns this signal
|
||||
// because it owns the convergence boundary; the engine need not track serials for it.
|
||||
apply func(update *mgmProto.SyncResponse, firstPass bool) (bool, error)
|
||||
// onConverged is called once per processed map, with the elapsed time since that
|
||||
// map was received (for the sync-duration metric / "sync finished" log).
|
||||
onConverged func(time.Duration)
|
||||
// persist snapshots an update to disk for restore-on-restart. Called once per
|
||||
// update received from management (in SetTarget), including ones later coalesced
|
||||
// or skipped from apply, so the on-disk state mirrors what management last sent.
|
||||
// The impl skips config-only updates (nil NetworkMap). May be nil.
|
||||
persist func(*mgmProto.SyncResponse)
|
||||
|
||||
mu sync.Mutex
|
||||
target *mgmProto.SyncResponse
|
||||
targetGen uint64
|
||||
appliedGen uint64
|
||||
targetSetAt time.Time
|
||||
|
||||
wake chan struct{}
|
||||
}
|
||||
|
||||
func newMapStateManager(apply func(update *mgmProto.SyncResponse, firstPass bool) (bool, error), persist func(*mgmProto.SyncResponse), onConverged func(time.Duration)) *mapStateManager {
|
||||
return &mapStateManager{
|
||||
apply: apply,
|
||||
persist: persist,
|
||||
onConverged: onConverged,
|
||||
wake: make(chan struct{}, 1),
|
||||
}
|
||||
}
|
||||
|
||||
// SetTarget records the latest update as the desired state and wakes the loop.
|
||||
// It returns immediately; convergence happens in the background. Serial-based
|
||||
// staleness of the network map is still enforced inside apply (updateNetworkMap).
|
||||
func (m *mapStateManager) SetTarget(update *mgmProto.SyncResponse) error {
|
||||
m.mu.Lock()
|
||||
// A target that has not settled yet (targetGen > appliedGen) is being superseded
|
||||
// before it converged: we coalesce to the latest map and never apply this one on
|
||||
// its own. It is SKIPPED — logged here, and it will not fire onConverged.
|
||||
if m.target != nil && m.targetGen > m.appliedGen {
|
||||
log.Debugf("sync map (gen %d) superseded before convergence, skipping", m.targetGen)
|
||||
}
|
||||
m.target = m.mergeTarget(m.target, update)
|
||||
// Bump an internal generation counter, NOT the map serial: config-only updates
|
||||
// (relay token rotation, STUN/TURN) arrive with NetworkMap == nil and carry no
|
||||
// serial, yet must still be applied. Every SetTarget is therefore a distinct
|
||||
// target regardless of payload. Map-serial staleness is enforced separately
|
||||
// inside apply (updateNetworkMap).
|
||||
m.targetGen++
|
||||
m.targetSetAt = time.Now()
|
||||
m.mu.Unlock()
|
||||
|
||||
select {
|
||||
case m.wake <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
|
||||
// Persist every update received from management — once per update (not per apply
|
||||
// pass), and including ones that get coalesced/skipped from apply, so the on-disk
|
||||
// state always reflects the latest map management sent. Done after waking the loop
|
||||
// so convergence can start in parallel with the disk write. The persist impl skips
|
||||
// config-only updates (nil NetworkMap).
|
||||
if m.persist != nil {
|
||||
m.persist(update)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// mergeTarget combines the currently pending target with a freshly received update
|
||||
// and returns the new desired state. It is called under m.mu from SetTarget and is
|
||||
// the single seam where the replace-vs-squash decision lives.
|
||||
//
|
||||
// Today management always sends a FULL map (the complete desired state), so the
|
||||
// update simply replaces whatever was pending — prev is ignored. When management
|
||||
// starts sending incremental/delta updates, squash `update` onto `prev` here; the
|
||||
// rest of the manager (generation tracking, convergence, signaling) is unaffected
|
||||
// because it already treats target as "the complete desired state, whatever it is".
|
||||
func (m *mapStateManager) mergeTarget(prev, update *mgmProto.SyncResponse) *mgmProto.SyncResponse {
|
||||
// Nothing pending to preserve (no prev, or prev already fully applied): plain replace.
|
||||
if prev == nil || update == nil || m.targetGen == m.appliedGen {
|
||||
return update
|
||||
}
|
||||
|
||||
// prev still has unapplied state (targetGen > appliedGen). In the sync protocol a
|
||||
// nil component means "no change", so if `update` omits a component that prev
|
||||
// carried, carry prev's forward — otherwise coalescing an update that superseded a
|
||||
// not-yet-applied one would silently drop the map or config it uniquely brought.
|
||||
// A present component in `update` is newer and wins. Management may send map-only
|
||||
// updates (nil config) and config-only updates (nil map); both are handled here.
|
||||
// A nil component in `update` means "no change", so fill it in from prev — otherwise
|
||||
// coalescing an update that superseded a not-yet-applied one would drop the map or
|
||||
// config it uniquely carried. A present component in `update` is newer and wins.
|
||||
// We mutate `update` in place: it is a fresh per-message allocation from the sync
|
||||
// stream (see receiveUpdatesEvents — not reused), and persisting this squashed target
|
||||
// is correct, since it is the current full (superset) desired state.
|
||||
if update.GetNetworkMap() == nil && prev.GetNetworkMap() != nil {
|
||||
update.NetworkMap = prev.GetNetworkMap()
|
||||
update.Checks = prev.Checks // checks travel with the map
|
||||
}
|
||||
if update.GetNetbirdConfig() == nil && prev.GetNetbirdConfig() != nil {
|
||||
update.NetbirdConfig = prev.GetNetbirdConfig()
|
||||
}
|
||||
return update
|
||||
}
|
||||
|
||||
// run drives convergence until ctx is done. It is meant to run in its own goroutine.
|
||||
func (m *mapStateManager) run(ctx context.Context) {
|
||||
// passGen is the generation of the most recent apply() call (0 = none). A pass is
|
||||
// the first for its target when its generation differs from the previous one —
|
||||
// true on a fresh target and on a coalesced switch to a newer target mid-flight.
|
||||
var passGen uint64
|
||||
for {
|
||||
m.mu.Lock()
|
||||
target, tg, ag := m.target, m.targetGen, m.appliedGen
|
||||
m.mu.Unlock()
|
||||
|
||||
// Fully converged (or nothing yet): block until a new target arrives.
|
||||
if target == nil || ag == tg {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-m.wake:
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
firstPass := tg != passGen
|
||||
passGen = tg
|
||||
more, err := m.apply(target, firstPass)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
// Log and DROP this target — do not retry it. A deterministic failure
|
||||
// (e.g. a malformed peer in the map) would otherwise spin every pass
|
||||
// making no progress. Management is the source of truth and re-delivers
|
||||
// the full map on the next sync, so dropping is safe; peers already
|
||||
// applied this convergence stay (idempotent diffs) and the remainder is
|
||||
// reconciled by the next target. Mirrors the legacy handleSync path,
|
||||
// where the apply error was logged by the gRPC client and the update
|
||||
// dropped. No onConverged: this target did not converge.
|
||||
log.Errorf("apply sync pass, dropping update: %v", err)
|
||||
m.settle(tg, false)
|
||||
continue
|
||||
}
|
||||
|
||||
if more {
|
||||
// keep converging the current target; syncMsgMux was released by apply
|
||||
// between passes so other subsystems interleave.
|
||||
continue
|
||||
}
|
||||
|
||||
// This pass converged. Mark applied and signal this one map.
|
||||
m.settle(tg, true)
|
||||
// if a newer target arrived mid-pass, settle is a no-op (targetGen != tg) and
|
||||
// ag<tg next iteration -> apply it; this generation was skipped (logged in
|
||||
// SetTarget) and is not signaled.
|
||||
}
|
||||
}
|
||||
|
||||
// settle marks generation tg as processed so the loop goes idle instead of
|
||||
// re-applying the same target. It is a no-op when a newer target arrived during the
|
||||
// pass (targetGen != tg), leaving appliedGen behind so that target re-applies — the
|
||||
// just-finished generation was already counted as skipped.
|
||||
//
|
||||
// When signal is true (the pass converged) it fires onConverged once for this map;
|
||||
// when false (the target was dropped on error) it does not — the map did not converge.
|
||||
func (m *mapStateManager) settle(tg uint64, signal bool) {
|
||||
m.mu.Lock()
|
||||
if m.targetGen != tg {
|
||||
m.mu.Unlock()
|
||||
return
|
||||
}
|
||||
m.appliedGen = tg
|
||||
setAt := m.targetSetAt
|
||||
m.mu.Unlock()
|
||||
|
||||
if signal && m.onConverged != nil {
|
||||
m.onConverged(time.Since(setAt))
|
||||
}
|
||||
}
|
||||
@@ -1,281 +0,0 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// mergeTarget fills components missing from the incoming update with the pending
|
||||
// (not-yet-applied) prev's, in place, so a coalesced/superseded update does not drop
|
||||
// the map or config it uniquely carried.
|
||||
func TestMapStateManager_MergeTargetPreservesPendingState(t *testing.T) {
|
||||
m := newMapStateManager(nil, nil, nil)
|
||||
|
||||
// config-only update while a full map is still converging (targetGen > appliedGen):
|
||||
// the pending map (+ checks) is filled into the update in place
|
||||
m.targetGen, m.appliedGen = 5, 4
|
||||
prev := &mgmProto.SyncResponse{NetworkMap: &mgmProto.NetworkMap{Serial: 5}}
|
||||
update := &mgmProto.SyncResponse{NetbirdConfig: &mgmProto.NetbirdConfig{}}
|
||||
merged := m.mergeTarget(prev, update)
|
||||
require.Same(t, update, merged, "merges in place, returns the update")
|
||||
require.EqualValues(t, 5, merged.GetNetworkMap().GetSerial(), "pending map preserved")
|
||||
require.NotNil(t, merged.GetNetbirdConfig(), "new config kept")
|
||||
|
||||
// symmetric: map-only update while a config-only update is pending -> keep the config
|
||||
m.targetGen, m.appliedGen = 5, 4
|
||||
prev = &mgmProto.SyncResponse{NetbirdConfig: &mgmProto.NetbirdConfig{}}
|
||||
update = &mgmProto.SyncResponse{NetworkMap: &mgmProto.NetworkMap{Serial: 7}}
|
||||
merged = m.mergeTarget(prev, update)
|
||||
require.EqualValues(t, 7, merged.GetNetworkMap().GetSerial(), "new map kept")
|
||||
require.NotNil(t, merged.GetNetbirdConfig(), "pending config preserved")
|
||||
|
||||
// prev already applied (targetGen == appliedGen): plain replace, no fill-in
|
||||
m.targetGen, m.appliedGen = 5, 5
|
||||
prev = &mgmProto.SyncResponse{NetworkMap: &mgmProto.NetworkMap{Serial: 5}}
|
||||
update = &mgmProto.SyncResponse{NetbirdConfig: &mgmProto.NetbirdConfig{}}
|
||||
merged = m.mergeTarget(prev, update)
|
||||
require.Same(t, update, merged)
|
||||
require.Nil(t, merged.GetNetworkMap(), "no map grafted when prev already applied")
|
||||
|
||||
// nothing to carry (update has a map, prev has no config): plain replace
|
||||
m.targetGen, m.appliedGen = 5, 4
|
||||
prev = &mgmProto.SyncResponse{NetworkMap: &mgmProto.NetworkMap{Serial: 5}}
|
||||
update = &mgmProto.SyncResponse{NetworkMap: &mgmProto.NetworkMap{Serial: 6}}
|
||||
require.Same(t, update, m.mergeTarget(prev, update))
|
||||
}
|
||||
|
||||
// converges over the bounded passes (apply returns more until the 3rd pass),
|
||||
// fires onConverged exactly once, then blocks (no further apply) until a new target.
|
||||
func TestMapStateManager_ConvergesThenStops(t *testing.T) {
|
||||
var passes int32
|
||||
var firstPasses int32
|
||||
converged := make(chan struct{}, 1)
|
||||
|
||||
apply := func(_ *mgmProto.SyncResponse, firstPass bool) (bool, error) {
|
||||
n := atomic.AddInt32(&passes, 1)
|
||||
if firstPass {
|
||||
atomic.AddInt32(&firstPasses, 1)
|
||||
}
|
||||
return n < 3, nil // more on pass 1 and 2, converge on pass 3
|
||||
}
|
||||
m := newMapStateManager(apply, nil, func(time.Duration) { converged <- struct{}{} })
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
go m.run(ctx)
|
||||
|
||||
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
|
||||
|
||||
select {
|
||||
case <-converged:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("manager did not converge")
|
||||
}
|
||||
require.EqualValues(t, 3, atomic.LoadInt32(&passes))
|
||||
require.EqualValues(t, 1, atomic.LoadInt32(&firstPasses), "firstPass true only on pass 1, false on re-runs of the same target")
|
||||
|
||||
// once converged the loop blocks: no further apply calls
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
require.EqualValues(t, 3, atomic.LoadInt32(&passes), "apply must not run after convergence")
|
||||
}
|
||||
|
||||
// persist runs once per received update (not per apply pass), regardless of how many
|
||||
// bounded passes that target takes to converge.
|
||||
func TestMapStateManager_PersistsOncePerUpdate(t *testing.T) {
|
||||
var passes, persists int32
|
||||
converged := make(chan struct{}, 1)
|
||||
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
|
||||
n := atomic.AddInt32(&passes, 1)
|
||||
return n < 3, nil // 3 passes for one target
|
||||
}
|
||||
persist := func(*mgmProto.SyncResponse) { atomic.AddInt32(&persists, 1) }
|
||||
m := newMapStateManager(apply, persist, func(time.Duration) { converged <- struct{}{} })
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
go m.run(ctx)
|
||||
|
||||
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
|
||||
select {
|
||||
case <-converged:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("did not converge")
|
||||
}
|
||||
require.EqualValues(t, 3, atomic.LoadInt32(&passes))
|
||||
require.EqualValues(t, 1, atomic.LoadInt32(&persists), "persist once per update, not per pass")
|
||||
}
|
||||
|
||||
// every update received from management is persisted — even one that is coalesced /
|
||||
// skipped from apply before it ever converges.
|
||||
func TestMapStateManager_PersistsEveryUpdateIncludingSkipped(t *testing.T) {
|
||||
release := make(chan struct{})
|
||||
var persists int32
|
||||
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
|
||||
<-release // hold the first apply so the second update coalesces/skips
|
||||
return false, nil
|
||||
}
|
||||
persist := func(*mgmProto.SyncResponse) { atomic.AddInt32(&persists, 1) }
|
||||
m := newMapStateManager(apply, persist, nil)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
go m.run(ctx)
|
||||
|
||||
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{})) // map1 -> apply blocks
|
||||
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{})) // map2 supersedes map1 (skipped from apply)
|
||||
close(release)
|
||||
|
||||
// both updates persisted even though map1 is skipped from apply
|
||||
require.Eventually(t, func() bool { return atomic.LoadInt32(&persists) == 2 }, 2*time.Second, 10*time.Millisecond)
|
||||
}
|
||||
|
||||
// each map that is actually processed (converged before the next arrives) fires
|
||||
// onConverged exactly once — mirroring the legacy per-message handleSync timing.
|
||||
func TestMapStateManager_SignalsEachProcessedMap(t *testing.T) {
|
||||
converged := make(chan struct{}, 8)
|
||||
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
|
||||
return false, nil // converge in one pass
|
||||
}
|
||||
m := newMapStateManager(apply, nil, func(time.Duration) { converged <- struct{}{} })
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
go m.run(ctx)
|
||||
|
||||
const maps = 3
|
||||
for i := 0; i < maps; i++ {
|
||||
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
|
||||
select { // wait for this map to converge before sending the next (no coalescing)
|
||||
case <-converged:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatalf("map %d not signaled", i)
|
||||
}
|
||||
}
|
||||
|
||||
// no extra signals once the stream goes quiet
|
||||
select {
|
||||
case <-converged:
|
||||
t.Fatal("unexpected extra onConverged")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
// a map superseded before it converges is skipped: only the latest (processed) map
|
||||
// fires onConverged, not the skipped one.
|
||||
func TestMapStateManager_SkippedMapNotSignaled(t *testing.T) {
|
||||
release := make(chan struct{})
|
||||
var applies, converged atomic.Int32
|
||||
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
|
||||
applies.Add(1)
|
||||
<-release // hold the first apply in-flight so we can queue a newer target
|
||||
return false, nil
|
||||
}
|
||||
m := newMapStateManager(apply, nil, func(time.Duration) { converged.Add(1) })
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
go m.run(ctx)
|
||||
|
||||
// map1 is picked up; its apply blocks on release
|
||||
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
|
||||
require.Eventually(t, func() bool { return applies.Load() >= 1 }, 2*time.Second, 5*time.Millisecond)
|
||||
|
||||
// map2 supersedes map1 before it settled -> map1 is skipped
|
||||
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
|
||||
close(release) // let both applies proceed
|
||||
|
||||
// only the processed (latest) map signals; the skipped one does not
|
||||
require.Eventually(t, func() bool { return converged.Load() == 1 }, 2*time.Second, 10*time.Millisecond)
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
require.EqualValues(t, 1, converged.Load(), "skipped map must not fire onConverged")
|
||||
require.EqualValues(t, 2, applies.Load(), "both targets entered apply (map1 once, map2 once)")
|
||||
}
|
||||
|
||||
// an apply error drops the target: no retry of the same target, no onConverged,
|
||||
// the loop goes idle — and a fresh target is still applied afterwards.
|
||||
func TestMapStateManager_DropsTargetOnError(t *testing.T) {
|
||||
applied := make(chan struct{}, 8)
|
||||
var failNext atomic.Bool
|
||||
failNext.Store(true)
|
||||
|
||||
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
|
||||
applied <- struct{}{}
|
||||
if failNext.Load() {
|
||||
return false, errors.New("boom")
|
||||
}
|
||||
return false, nil // converge in one pass
|
||||
}
|
||||
var converged atomic.Int32
|
||||
m := newMapStateManager(apply, nil, func(time.Duration) { converged.Add(1) })
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
go m.run(ctx)
|
||||
|
||||
// first target errors -> applied once, then dropped (no retry, no onConverged)
|
||||
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
|
||||
select {
|
||||
case <-applied:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("errored target not applied")
|
||||
}
|
||||
select {
|
||||
case <-applied:
|
||||
t.Fatal("errored target must not be retried")
|
||||
case <-time.After(150 * time.Millisecond):
|
||||
}
|
||||
require.EqualValues(t, 0, converged.Load(), "onConverged must not fire on error")
|
||||
|
||||
// a new target is still processed normally and converges
|
||||
failNext.Store(false)
|
||||
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
|
||||
select {
|
||||
case <-applied:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("new target after error not applied")
|
||||
}
|
||||
require.Eventually(t, func() bool { return converged.Load() == 1 }, 2*time.Second, 10*time.Millisecond)
|
||||
}
|
||||
|
||||
// a new target after convergence triggers a fresh apply; an idle (converged)
|
||||
// manager does not apply on its own.
|
||||
func TestMapStateManager_ReappliesOnNewTarget(t *testing.T) {
|
||||
applied := make(chan struct{}, 8)
|
||||
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
|
||||
applied <- struct{}{}
|
||||
return false, nil // converge in one pass
|
||||
}
|
||||
m := newMapStateManager(apply, nil, nil)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
go m.run(ctx)
|
||||
|
||||
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
|
||||
select {
|
||||
case <-applied:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("first target not applied")
|
||||
}
|
||||
|
||||
// converged → must stay idle (no spurious apply)
|
||||
select {
|
||||
case <-applied:
|
||||
t.Fatal("unexpected apply while idle/converged")
|
||||
case <-time.After(150 * time.Millisecond):
|
||||
}
|
||||
|
||||
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
|
||||
select {
|
||||
case <-applied:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("new target not applied")
|
||||
}
|
||||
}
|
||||
@@ -85,7 +85,11 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
|
||||
defer g.srWatcher.RemoveListener(srReconnectedChan)
|
||||
|
||||
ticker := g.initialTicker(ctx)
|
||||
defer ticker.Stop()
|
||||
defer func() {
|
||||
// If backoff.Ticker.send is blocked, context.Done will not close the Ticker goroutine.
|
||||
// We have to explicitly call Stop, even if we use backoff.WithContext.
|
||||
ticker.Stop()
|
||||
}()
|
||||
|
||||
tickerChannel := ticker.C
|
||||
|
||||
|
||||
92
client/internal/peer/guard/guard_leak_test.go
Normal file
92
client/internal/peer/guard/guard_leak_test.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package guard
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
)
|
||||
|
||||
func newTestGuard(status connStatusFunc) *Guard {
|
||||
srw := NewSRWatcher(nil, nil, nil, ice.Config{})
|
||||
return NewGuard(log.WithField("test", "guard"), status, 50*time.Millisecond, srw)
|
||||
}
|
||||
|
||||
// countBackoffTickerGoroutines returns how many goroutines are currently sitting
|
||||
// in backoff/v4.(*Ticker).run (a ticker goroutine that has not exited).
|
||||
func countBackoffTickerGoroutines() int {
|
||||
buf := make([]byte, 1<<25) // 32MB
|
||||
n := runtime.Stack(buf, true)
|
||||
return strings.Count(string(buf[:n]), "backoff/v4.(*Ticker).run")
|
||||
}
|
||||
|
||||
// TestGuard_ReconnectTicker_NoGoroutineLeakOnShutdown reproduces a observed
|
||||
// leak: after a shutdown burst, ticker run/send goroutines stay parked
|
||||
// forever even though every reconnect loop has exited.
|
||||
func TestGuard_ReconnectTicker_NoGoroutineLeakOnShutdown(t *testing.T) {
|
||||
before := countBackoffTickerGoroutines()
|
||||
|
||||
const peers = 6000
|
||||
cancels := make([]context.CancelFunc, 0, peers)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// A status check slower than the tick cadence. This models the real
|
||||
// isConnectedOnAllWay/callback doing work: while the loop is busy in the
|
||||
// handler, the ticker fires the next tick and parks in send(), because
|
||||
// send() never selects on ctx.
|
||||
slowStatus := func() ConnStatus {
|
||||
time.Sleep(70 * time.Millisecond)
|
||||
return ConnStatusConnected
|
||||
}
|
||||
|
||||
for range peers {
|
||||
g := newTestGuard(slowStatus)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancels = append(cancels, cancel)
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
g.Start(ctx, func() {})
|
||||
}()
|
||||
// Force the live ticker to be a newReconnectTicker.
|
||||
g.SetRelayedConnDisconnected()
|
||||
}
|
||||
|
||||
// Let the replacement tickers get past their 800ms initial interval, so
|
||||
// many are parked in send() waiting on the (slow) consumer when we tear
|
||||
// everything down.
|
||||
time.Sleep(1500 * time.Millisecond)
|
||||
|
||||
// Shutdown burst: cancel every peer at once, like engine teardown.
|
||||
for _, c := range cancels {
|
||||
c()
|
||||
}
|
||||
|
||||
// Every reconnect loop must return
|
||||
waitCh := make(chan struct{})
|
||||
go func() { wg.Wait(); close(waitCh) }()
|
||||
select {
|
||||
case <-waitCh:
|
||||
case <-time.After(30 * time.Second):
|
||||
t.Fatal("not all reconnect loops returned after ctx cancel")
|
||||
}
|
||||
|
||||
// Give any correctly-stopped ticker goroutines time to unwind.
|
||||
for range 50 {
|
||||
runtime.Gosched()
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
leaked := countBackoffTickerGoroutines() - before
|
||||
t.Logf("backoff Ticker.run goroutines still parked after teardown of %d peers: %d", peers, leaked)
|
||||
if leaked > 0 {
|
||||
t.Errorf("LEAK: %d backoff ticker goroutines parked after all reconnect loops exited "+
|
||||
"(defer ticker.Stop() stops the initial ticker, not the live replacement)", leaked)
|
||||
}
|
||||
}
|
||||
191
client/internal/routemanager/exit_node_selection_test.go
Normal file
191
client/internal/routemanager/exit_node_selection_test.go
Normal file
@@ -0,0 +1,191 @@
|
||||
package routemanager
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
func newExitNodeTestManager() *DefaultManager {
|
||||
return &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
|
||||
}
|
||||
|
||||
func exitRoute(netID, peer string, skipAutoApply bool) *route.Route {
|
||||
return &route.Route{
|
||||
NetID: route.NetID(netID),
|
||||
Network: netip.MustParsePrefix("0.0.0.0/0"),
|
||||
Peer: peer,
|
||||
SkipAutoApply: skipAutoApply,
|
||||
}
|
||||
}
|
||||
|
||||
func TestPickPreferredExitNode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
info exitNodeInfo
|
||||
want route.NetID
|
||||
}{
|
||||
{
|
||||
name: "persisted user selection wins over management",
|
||||
info: exitNodeInfo{
|
||||
allIDs: []route.NetID{"a", "b", "c"},
|
||||
userSelected: []route.NetID{"b"},
|
||||
selectedByManagement: []route.NetID{"a"},
|
||||
},
|
||||
want: "b",
|
||||
},
|
||||
{
|
||||
name: "multiple user-selected self-heal to deterministic min",
|
||||
info: exitNodeInfo{
|
||||
allIDs: []route.NetID{"a", "b", "c"},
|
||||
userSelected: []route.NetID{"c", "a"},
|
||||
},
|
||||
want: "a",
|
||||
},
|
||||
{
|
||||
name: "explicit opt-out keeps none",
|
||||
info: exitNodeInfo{
|
||||
allIDs: []route.NetID{"a", "b"},
|
||||
userDeselected: []route.NetID{"a", "b"},
|
||||
},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "fresh defaults to management auto-apply pick",
|
||||
info: exitNodeInfo{
|
||||
allIDs: []route.NetID{"a", "b", "c"},
|
||||
selectedByManagement: []route.NetID{"b"},
|
||||
},
|
||||
want: "b",
|
||||
},
|
||||
{
|
||||
name: "no user pick and no management auto-apply selects none",
|
||||
info: exitNodeInfo{
|
||||
allIDs: []route.NetID{"c", "a", "b"},
|
||||
},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "user-deselect does not block a management auto-apply sibling",
|
||||
info: exitNodeInfo{
|
||||
allIDs: []route.NetID{"a", "b"},
|
||||
userDeselected: []route.NetID{"a"},
|
||||
selectedByManagement: []route.NetID{"b"},
|
||||
},
|
||||
want: "b",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.want, pickPreferredExitNode(tt.info), "preferred exit node")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnforceSingleExitNode(t *testing.T) {
|
||||
m := newExitNodeTestManager()
|
||||
all := []route.NetID{"a", "b", "c"}
|
||||
|
||||
m.enforceSingleExitNode("b", all)
|
||||
assert.False(t, m.routeSelector.IsSelected("a"), "a should be deselected")
|
||||
assert.True(t, m.routeSelector.IsSelected("b"), "b should be the only selected exit node")
|
||||
assert.False(t, m.routeSelector.IsSelected("c"), "c should be deselected")
|
||||
|
||||
// Switching the preferred node moves the single selection.
|
||||
m.enforceSingleExitNode("c", all)
|
||||
assert.False(t, m.routeSelector.IsSelected("a"), "a stays deselected")
|
||||
assert.False(t, m.routeSelector.IsSelected("b"), "b should now be deselected")
|
||||
assert.True(t, m.routeSelector.IsSelected("c"), "c should now be selected")
|
||||
|
||||
// Empty preferred turns every exit node off.
|
||||
m.enforceSingleExitNode("", all)
|
||||
for _, id := range all {
|
||||
assert.False(t, m.routeSelector.IsSelected(id), "no exit node should be selected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnforceSingleExitNode_RespectsDeselectAll(t *testing.T) {
|
||||
m := newExitNodeTestManager()
|
||||
m.routeSelector.DeselectAllRoutes()
|
||||
|
||||
m.enforceSingleExitNode("b", []route.NetID{"a", "b"})
|
||||
|
||||
assert.True(t, m.routeSelector.IsDeselectAll(), "global deselect-all must stay in effect")
|
||||
assert.False(t, m.routeSelector.IsSelected("b"), "no exit node should be forced on while deselect-all is set")
|
||||
}
|
||||
|
||||
func TestUpdateRouteSelectorFromManagement_FreshSelectsOne(t *testing.T) {
|
||||
m := newExitNodeTestManager()
|
||||
routes := route.HAMap{
|
||||
"exitA|0.0.0.0/0": {exitRoute("exitA", "p1", false)},
|
||||
"exitB|0.0.0.0/0": {exitRoute("exitB", "p2", false)},
|
||||
"lan|192.168.1.0/24": {{NetID: "lan", Network: netip.MustParsePrefix("192.168.1.0/24"), Peer: "p3"}},
|
||||
"exitC|0.0.0.0/0": {exitRoute("exitC", "p4", false)},
|
||||
}
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
// Exactly one exit node (the deterministic first) is selected.
|
||||
assert.True(t, m.routeSelector.IsSelected("exitA"), "exitA is the deterministic default")
|
||||
assert.False(t, m.routeSelector.IsSelected("exitB"), "exitB must not also be selected")
|
||||
assert.False(t, m.routeSelector.IsSelected("exitC"), "exitC must not also be selected")
|
||||
// Non-exit routes are left at their default-on state.
|
||||
assert.True(t, m.routeSelector.IsSelected("lan"), "non-exit route selection is untouched")
|
||||
}
|
||||
|
||||
func TestUpdateRouteSelectorFromManagement_HonorsPersistedPick(t *testing.T) {
|
||||
m := newExitNodeTestManager()
|
||||
routes := route.HAMap{
|
||||
"exitA|0.0.0.0/0": {exitRoute("exitA", "p1", false)},
|
||||
"exitB|0.0.0.0/0": {exitRoute("exitB", "p2", false)},
|
||||
}
|
||||
all := []route.NetID{"exitA", "exitB"}
|
||||
|
||||
// Simulate the state the runtime select path leaves behind: exactly one
|
||||
// exit node explicitly selected, its sibling deselected.
|
||||
require.NoError(t, m.routeSelector.SelectRoutes([]route.NetID{"exitB"}, true, all))
|
||||
require.NoError(t, m.routeSelector.DeselectRoutes([]route.NetID{"exitA"}, all))
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
assert.True(t, m.routeSelector.IsSelected("exitB"), "persisted pick must stay selected")
|
||||
assert.False(t, m.routeSelector.IsSelected("exitA"), "the other exit node stays deselected")
|
||||
}
|
||||
|
||||
func TestUpdateRouteSelectorFromManagement_OptOutKeepsNone(t *testing.T) {
|
||||
m := newExitNodeTestManager()
|
||||
routes := route.HAMap{
|
||||
"exitA|0.0.0.0/0": {exitRoute("exitA", "p1", false)},
|
||||
"exitB|0.0.0.0/0": {exitRoute("exitB", "p2", false)},
|
||||
}
|
||||
all := []route.NetID{"exitA", "exitB"}
|
||||
|
||||
// User deselected exit nodes and selected none.
|
||||
require.NoError(t, m.routeSelector.DeselectRoutes(all, all))
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
assert.False(t, m.routeSelector.IsSelected("exitA"), "opt-out keeps exitA off")
|
||||
assert.False(t, m.routeSelector.IsSelected("exitB"), "opt-out keeps exitB off")
|
||||
}
|
||||
|
||||
func TestUpdateRouteSelectorFromManagement_NoAutoApplySelectsNone(t *testing.T) {
|
||||
m := newExitNodeTestManager()
|
||||
// SkipAutoApply=true: management offers the exit nodes but doesn't request
|
||||
// auto-activation, so none should be selected until the user picks one.
|
||||
routes := route.HAMap{
|
||||
"exitA|0.0.0.0/0": {exitRoute("exitA", "p1", true)},
|
||||
"exitB|0.0.0.0/0": {exitRoute("exitB", "p2", true)},
|
||||
}
|
||||
|
||||
m.updateRouteSelectorFromManagement(routes)
|
||||
|
||||
assert.False(t, m.routeSelector.IsSelected("exitA"), "no auto-apply keeps exitA off")
|
||||
assert.False(t, m.routeSelector.IsSelected("exitB"), "no auto-apply keeps exitB off")
|
||||
}
|
||||
@@ -701,7 +701,13 @@ func resolveURLsToIPs(urls []string) []net.IP {
|
||||
return ips
|
||||
}
|
||||
|
||||
// updateRouteSelectorFromManagement updates the route selector based on the isSelected status from the management server
|
||||
// updateRouteSelectorFromManagement reconciles exit-node selection on every
|
||||
// network map: it keeps at most one exit node selected — the user's persisted
|
||||
// pick, else whatever management marks for auto-apply (SkipAutoApply=false),
|
||||
// else none. We never auto-activate an exit node the map doesn't request; it
|
||||
// stays off until the user picks it. Exit nodes are mutually exclusive, but the
|
||||
// RouteSelector stores routes with default-on semantics, so without this every
|
||||
// available exit node would report selected at once.
|
||||
func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HAMap) {
|
||||
m.mirrorV6ExitPairSelections(clientRoutes)
|
||||
|
||||
@@ -712,13 +718,14 @@ func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HA
|
||||
return
|
||||
}
|
||||
|
||||
exitNodeInfo := m.collectExitNodeInfo(clientRoutes)
|
||||
if len(exitNodeInfo.allIDs) == 0 {
|
||||
info := m.collectExitNodeInfo(clientRoutes)
|
||||
if len(info.allIDs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
m.updateExitNodeSelections(exitNodeInfo)
|
||||
m.logExitNodeUpdate(exitNodeInfo)
|
||||
preferred := pickPreferredExitNode(info)
|
||||
m.enforceSingleExitNode(preferred, info.allIDs)
|
||||
m.logExitNodeUpdate(info, preferred)
|
||||
}
|
||||
|
||||
// mirrorV6ExitPairSelections keeps every synthesized "-v6" exit route's selection
|
||||
@@ -746,6 +753,10 @@ type exitNodeInfo struct {
|
||||
userDeselected []route.NetID
|
||||
}
|
||||
|
||||
// collectExitNodeInfo categorises the available exit nodes by their persisted
|
||||
// selection state. It keys on the base (v4) NetID and skips the synthesized
|
||||
// "-v6" partner, which inherits its base's selection through the RouteSelector
|
||||
// — counting it separately would double-count the pair.
|
||||
func (m *DefaultManager) collectExitNodeInfo(clientRoutes route.HAMap) exitNodeInfo {
|
||||
var info exitNodeInfo
|
||||
|
||||
@@ -755,6 +766,9 @@ func (m *DefaultManager) collectExitNodeInfo(clientRoutes route.HAMap) exitNodeI
|
||||
}
|
||||
|
||||
netID := haID.NetID()
|
||||
if strings.HasSuffix(string(netID), route.V6ExitSuffix) {
|
||||
continue
|
||||
}
|
||||
info.allIDs = append(info.allIDs, netID)
|
||||
|
||||
if m.routeSelector.HasUserSelectionForRoute(netID) {
|
||||
@@ -791,45 +805,52 @@ func (m *DefaultManager) checkManagementSelection(routes []*route.Route, netID r
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DefaultManager) updateExitNodeSelections(info exitNodeInfo) {
|
||||
routesToDeselect := m.getRoutesToDeselect(info.allIDs)
|
||||
m.deselectExitNodes(routesToDeselect)
|
||||
m.selectExitNodesByManagement(info.selectedByManagement, info.allIDs)
|
||||
// pickPreferredExitNode chooses the single exit node to keep selected. In order:
|
||||
// - a persisted user selection wins (deterministic if several survive from
|
||||
// legacy state, so the set self-heals down to one);
|
||||
// - otherwise activate only what management marks for auto-apply
|
||||
// (SkipAutoApply=false); the lexicographically first if it marks several.
|
||||
//
|
||||
// Returns "" when neither holds — we never force an arbitrary exit node on. A
|
||||
// route the map doesn't auto-apply stays off until the user selects it.
|
||||
// info.userDeselected is informational only: an explicit deselect simply keeps
|
||||
// that route out of both lists above, so it can't be picked.
|
||||
func pickPreferredExitNode(info exitNodeInfo) route.NetID {
|
||||
if len(info.userSelected) > 0 {
|
||||
return minNetID(info.userSelected)
|
||||
}
|
||||
if len(info.selectedByManagement) > 0 {
|
||||
return minNetID(info.selectedByManagement)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *DefaultManager) getRoutesToDeselect(allIDs []route.NetID) []route.NetID {
|
||||
var routesToDeselect []route.NetID
|
||||
for _, netID := range allIDs {
|
||||
if !m.routeSelector.HasUserSelectionForRoute(netID) {
|
||||
routesToDeselect = append(routesToDeselect, netID)
|
||||
// enforceSingleExitNode makes preferred the only selected exit node: every other
|
||||
// available exit node is deselected and preferred (if any) is selected, without
|
||||
// disturbing non-exit route selections. The whole reconciliation runs under a
|
||||
// single RouteSelector lock (SetExclusiveExitNode) so a concurrent deselect-all
|
||||
// cannot interleave and get undone; a global deselect-all is left untouched so
|
||||
// the user's "all off" stays in effect.
|
||||
func (m *DefaultManager) enforceSingleExitNode(preferred route.NetID, allIDs []route.NetID) {
|
||||
m.routeSelector.SetExclusiveExitNode(preferred, allIDs)
|
||||
}
|
||||
|
||||
func (m *DefaultManager) logExitNodeUpdate(info exitNodeInfo, preferred route.NetID) {
|
||||
log.Debugf("Exit node selection: %d available, preferred=%q (%d user-selected, %d user-deselected, %d management-selected)",
|
||||
len(info.allIDs), preferred, len(info.userSelected), len(info.userDeselected), len(info.selectedByManagement))
|
||||
}
|
||||
|
||||
// minNetID returns the lexicographically smallest NetID, for a deterministic
|
||||
// default pick that stays stable across restarts.
|
||||
func minNetID(ids []route.NetID) route.NetID {
|
||||
if len(ids) == 0 {
|
||||
return ""
|
||||
}
|
||||
best := ids[0]
|
||||
for _, id := range ids[1:] {
|
||||
if id < best {
|
||||
best = id
|
||||
}
|
||||
}
|
||||
return routesToDeselect
|
||||
}
|
||||
|
||||
func (m *DefaultManager) deselectExitNodes(routesToDeselect []route.NetID) {
|
||||
if len(routesToDeselect) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
err := m.routeSelector.DeselectRoutes(routesToDeselect, routesToDeselect)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to deselect exit nodes: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DefaultManager) selectExitNodesByManagement(selectedByManagement []route.NetID, allIDs []route.NetID) {
|
||||
if len(selectedByManagement) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
err := m.routeSelector.SelectRoutes(selectedByManagement, true, allIDs)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to select exit nodes: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DefaultManager) logExitNodeUpdate(info exitNodeInfo) {
|
||||
log.Debugf("Updated route selector: %d exit nodes available, %d selected by management, %d user-selected, %d user-deselected",
|
||||
len(info.allIDs), len(info.selectedByManagement), len(info.userSelected), len(info.userDeselected))
|
||||
return best
|
||||
}
|
||||
|
||||
@@ -115,7 +115,38 @@ func (rs *RouteSelector) DeselectAllRoutes() {
|
||||
clear(rs.selectedRoutes)
|
||||
}
|
||||
|
||||
// IsDeselectAll reports whether the user has explicitly deselected all routes.
|
||||
// SetExclusiveExitNode atomically makes preferred the only selected exit node
|
||||
// among exitIDs: every other ID in exitIDs is deselected and preferred (when
|
||||
// non-empty) is selected, all under a single lock. Holding the lock across the
|
||||
// whole reconciliation prevents a concurrent DeselectAllRoutes from interleaving
|
||||
// between the deselect and select steps and being silently undone. A global
|
||||
// deselect-all is left untouched so the user's "all off" stays in effect;
|
||||
// non-exit routes are never referenced, so their selection is preserved.
|
||||
func (rs *RouteSelector) SetExclusiveExitNode(preferred route.NetID, exitIDs []route.NetID) {
|
||||
rs.mu.Lock()
|
||||
defer rs.mu.Unlock()
|
||||
|
||||
if rs.deselectAll {
|
||||
return
|
||||
}
|
||||
|
||||
for _, id := range exitIDs {
|
||||
if id == preferred {
|
||||
continue
|
||||
}
|
||||
rs.deselectedRoutes[id] = struct{}{}
|
||||
delete(rs.selectedRoutes, id)
|
||||
}
|
||||
|
||||
if preferred != "" {
|
||||
delete(rs.deselectedRoutes, preferred)
|
||||
rs.selectedRoutes[preferred] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// IsDeselectAll reports whether the global "deselect all" flag is set, i.e. the
|
||||
// user explicitly disabled every route. Callers enforcing per-route invariants
|
||||
// (e.g. single exit node) should leave the selection untouched when it is.
|
||||
func (rs *RouteSelector) IsDeselectAll() bool {
|
||||
rs.mu.RLock()
|
||||
defer rs.mu.RUnlock()
|
||||
|
||||
151
combined/cmd/admin.go
Normal file
151
combined/cmd/admin.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/formatter/hook"
|
||||
admincmd "github.com/netbirdio/netbird/management/cmd/admin"
|
||||
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
activitystore "github.com/netbirdio/netbird/management/server/activity/store"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
// newAdminCommands creates the admin command tree with combined-specific resource openers.
|
||||
func newAdminCommands() *cobra.Command {
|
||||
return admincmd.NewCommands(admincmd.Openers{
|
||||
Resources: withAdminResources,
|
||||
Store: withAdminStoreOnly,
|
||||
IDP: withAdminIDPOnly,
|
||||
})
|
||||
}
|
||||
|
||||
func newLegacyTokenCommand() *cobra.Command {
|
||||
cmd := tokencmd.NewCommands(tokencmd.StoreOpener(withAdminStoreOnly))
|
||||
cmd.Deprecated = "use 'admin token' instead"
|
||||
return cmd
|
||||
}
|
||||
|
||||
// withAdminResources loads the combined YAML config, initializes stores, and calls fn.
|
||||
func withAdminResources(cmd *cobra.Command, fn func(ctx context.Context, resources admincmd.Resources) error) error {
|
||||
return withAdminConfig(cmd, func(ctx context.Context, cfg *CombinedConfig) error {
|
||||
mgmtConfig, err := adminManagementConfig(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
managementStore, err := openAdminStore(ctx, cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer admincmd.CloseStore(ctx, managementStore)
|
||||
|
||||
idpStorage, idpStorageFile, err := admincmd.OpenIDPStorage(mgmtConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer admincmd.CloseIDPStorage(idpStorage)
|
||||
|
||||
eventStore, esErr := openAdminEventStore(ctx, cfg, mgmtConfig)
|
||||
if esErr != nil {
|
||||
_, _ = fmt.Fprintf(cmd.ErrOrStderr(), "Warning: audit events will not be recorded: %v\n", esErr)
|
||||
}
|
||||
if eventStore != nil {
|
||||
defer func() {
|
||||
if err := eventStore.Close(ctx); err != nil {
|
||||
log.Debugf("close activity event store: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return fn(ctx, admincmd.Resources{Store: managementStore, IDPStorage: idpStorage, IDPStorageFile: idpStorageFile, EventStore: eventStore})
|
||||
})
|
||||
}
|
||||
|
||||
// withAdminStoreOnly opens only the management store for admin subcommands that do not
|
||||
// need embedded IdP storage.
|
||||
func withAdminStoreOnly(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
|
||||
return withAdminConfig(cmd, func(ctx context.Context, cfg *CombinedConfig) error {
|
||||
managementStore, err := openAdminStore(ctx, cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer admincmd.CloseStore(ctx, managementStore)
|
||||
|
||||
return fn(ctx, managementStore)
|
||||
})
|
||||
}
|
||||
|
||||
func withAdminIDPOnly(cmd *cobra.Command, fn func(ctx context.Context, idpStorage storage.Storage, storageFile string) error) error {
|
||||
return withAdminConfig(cmd, func(ctx context.Context, cfg *CombinedConfig) error {
|
||||
mgmtConfig, err := adminManagementConfig(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
idpStorage, idpStorageFile, err := admincmd.OpenIDPStorage(mgmtConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer admincmd.CloseIDPStorage(idpStorage)
|
||||
|
||||
return fn(ctx, idpStorage, idpStorageFile)
|
||||
})
|
||||
}
|
||||
|
||||
func withAdminConfig(cmd *cobra.Command, fn func(ctx context.Context, cfg *CombinedConfig) error) error {
|
||||
if err := util.InitLog("error", "console"); err != nil {
|
||||
return fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
|
||||
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
|
||||
|
||||
cfg, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
cfg.ApplyAdminDefaults()
|
||||
applyServerStoreEnv(cfg.Server.Store)
|
||||
|
||||
return fn(ctx, cfg)
|
||||
}
|
||||
|
||||
func adminManagementConfig(cfg *CombinedConfig) (*nbconfig.Config, error) {
|
||||
mgmtConfig, err := cfg.ToManagementConfig()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create management config: %w", err)
|
||||
}
|
||||
return mgmtConfig, nil
|
||||
}
|
||||
|
||||
func openAdminStore(ctx context.Context, cfg *CombinedConfig) (store.Store, error) {
|
||||
managementStore, err := store.NewStore(ctx, types.Engine(cfg.Management.Store.Engine), cfg.Management.DataDir, nil, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create store: %w", err)
|
||||
}
|
||||
return managementStore, nil
|
||||
}
|
||||
|
||||
func openAdminEventStore(ctx context.Context, cfg *CombinedConfig, config *nbconfig.Config) (activity.Store, error) {
|
||||
if config.DataStoreEncryptionKey == "" {
|
||||
return nil, fmt.Errorf("data store encryption key is not configured")
|
||||
}
|
||||
if err := applyActivityStoreEnv(cfg.Server.ActivityStore); err != nil {
|
||||
return nil, fmt.Errorf("configure activity event store: %w", err)
|
||||
}
|
||||
eventStore, err := activitystore.NewSqlStore(ctx, config.Datadir, config.DataStoreEncryptionKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open activity event store: %w", err)
|
||||
}
|
||||
if eventStore == nil {
|
||||
return nil, fmt.Errorf("open activity event store: returned nil store")
|
||||
}
|
||||
return eventStore, nil
|
||||
}
|
||||
47
combined/cmd/admin_config_test.go
Normal file
47
combined/cmd/admin_config_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
)
|
||||
|
||||
func TestApplyAdminDefaultsCopiesServerStoreWithoutExposedAddress(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
cfg.Server.ExposedAddress = ""
|
||||
cfg.Server.DataDir = "/srv/netbird"
|
||||
cfg.Server.Store = StoreConfig{
|
||||
Engine: "postgres",
|
||||
DSN: "postgres://user:pass@example.com/netbird",
|
||||
}
|
||||
|
||||
cfg.ApplyAdminDefaults()
|
||||
|
||||
require.Equal(t, "/srv/netbird", cfg.Management.DataDir)
|
||||
require.Equal(t, "postgres", cfg.Management.Store.Engine)
|
||||
require.Equal(t, cfg.Server.Store.DSN, cfg.Management.Store.DSN)
|
||||
}
|
||||
|
||||
func TestOpenAdminEventStoreMissingEncryptionKeyReturnsNilInterface(t *testing.T) {
|
||||
eventStore, err := openAdminEventStore(context.Background(), &CombinedConfig{}, &nbconfig.Config{})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "encryption key")
|
||||
require.Nil(t, eventStore)
|
||||
}
|
||||
|
||||
func TestApplyServerStoreEnv(t *testing.T) {
|
||||
t.Setenv("NB_STORE_ENGINE_POSTGRES_DSN", "")
|
||||
t.Setenv("NB_STORE_ENGINE_MYSQL_DSN", "")
|
||||
t.Setenv("NB_STORE_ENGINE_SQLITE_FILE", "")
|
||||
|
||||
applyServerStoreEnv(StoreConfig{Engine: "postgres", DSN: "postgres-dsn", File: "store.db"})
|
||||
require.Equal(t, "postgres-dsn", os.Getenv("NB_STORE_ENGINE_POSTGRES_DSN"))
|
||||
require.Equal(t, "store.db", os.Getenv("NB_STORE_ENGINE_SQLITE_FILE"))
|
||||
|
||||
applyServerStoreEnv(StoreConfig{Engine: "mysql", DSN: "mysql-dsn"})
|
||||
require.Equal(t, "mysql-dsn", os.Getenv("NB_STORE_ENGINE_MYSQL_DSN"))
|
||||
}
|
||||
@@ -6,8 +6,7 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
filePath "path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -299,6 +298,19 @@ func (c *CombinedConfig) ApplySimplifiedDefaults() {
|
||||
c.autoConfigureClientSettings(exposedProto, exposedHost, exposedHostPort, hasExternalStuns, hasExternalRelay, hasExternalSignal)
|
||||
}
|
||||
|
||||
// ApplyAdminDefaults applies the management settings needed by admin commands even
|
||||
// when the full server config is invalid and ApplySimplifiedDefaults cannot run.
|
||||
func (c *CombinedConfig) ApplyAdminDefaults() {
|
||||
if c.Management.DataDir == "" || c.Management.DataDir == "/var/lib/netbird/" {
|
||||
c.Management.DataDir = c.Server.DataDir
|
||||
}
|
||||
if c.Management.Store.Engine == "" || c.Management.Store.Engine == "sqlite" {
|
||||
if c.Server.Store.Engine != "" || c.Server.Store.File != "" || c.Server.Store.DSN != "" {
|
||||
c.Management.Store = c.Server.Store
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// applyRelayDefaults configures the relay service if no external relay is configured.
|
||||
func (c *CombinedConfig) applyRelayDefaults(exposedProto, exposedHostPort string, hasExternalRelay, hasExternalStuns bool) {
|
||||
if hasExternalRelay {
|
||||
@@ -576,11 +588,11 @@ func (c *CombinedConfig) buildEmbeddedIdPConfig(mgmt ManagementConfig) (*idp.Emb
|
||||
return nil, fmt.Errorf("authStore.dsn is required when authStore.engine is postgres")
|
||||
}
|
||||
} else {
|
||||
authStorageFile = path.Join(mgmt.DataDir, "idp.db")
|
||||
authStorageFile = filePath.Join(mgmt.DataDir, "idp.db")
|
||||
if c.Server.AuthStore.File != "" {
|
||||
authStorageFile = c.Server.AuthStore.File
|
||||
if !filepath.IsAbs(authStorageFile) {
|
||||
authStorageFile = filepath.Join(mgmt.DataDir, authStorageFile)
|
||||
if !filePath.IsAbs(authStorageFile) {
|
||||
authStorageFile = filePath.Join(mgmt.DataDir, authStorageFile)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -727,7 +739,7 @@ func ApplyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config, mgmtPort
|
||||
cfg.EmbeddedIdP.Storage.Type = "sqlite3"
|
||||
}
|
||||
if cfg.EmbeddedIdP.Storage.Config.File == "" && cfg.Datadir != "" {
|
||||
cfg.EmbeddedIdP.Storage.Config.File = path.Join(cfg.Datadir, "idp.db")
|
||||
cfg.EmbeddedIdP.Storage.Config.File = filePath.Join(cfg.Datadir, "idp.db")
|
||||
}
|
||||
|
||||
issuer := cfg.EmbeddedIdP.Issuer
|
||||
|
||||
@@ -64,7 +64,8 @@ func init() {
|
||||
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "path to YAML configuration file (required)")
|
||||
_ = rootCmd.MarkPersistentFlagRequired("config")
|
||||
|
||||
rootCmd.AddCommand(newTokenCommands())
|
||||
rootCmd.AddCommand(newAdminCommands())
|
||||
rootCmd.AddCommand(newLegacyTokenCommand())
|
||||
}
|
||||
|
||||
func RootCmd() *cobra.Command {
|
||||
@@ -122,6 +123,37 @@ func execute(cmd *cobra.Command, _ []string) error {
|
||||
}
|
||||
|
||||
// initializeConfig loads and validates the configuration, then initializes logging.
|
||||
func applyServerStoreEnv(storeConfig StoreConfig) {
|
||||
if dsn := storeConfig.DSN; dsn != "" {
|
||||
switch strings.ToLower(storeConfig.Engine) {
|
||||
case "postgres":
|
||||
os.Setenv("NB_STORE_ENGINE_POSTGRES_DSN", dsn)
|
||||
case "mysql":
|
||||
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
|
||||
}
|
||||
}
|
||||
if file := storeConfig.File; file != "" {
|
||||
os.Setenv("NB_STORE_ENGINE_SQLITE_FILE", file)
|
||||
}
|
||||
}
|
||||
|
||||
func applyActivityStoreEnv(storeConfig StoreConfig) error {
|
||||
if engine := storeConfig.Engine; engine != "" {
|
||||
engineLower := strings.ToLower(engine)
|
||||
if engineLower == "postgres" && storeConfig.DSN == "" {
|
||||
return fmt.Errorf("activityStore.dsn is required when activityStore.engine is postgres")
|
||||
}
|
||||
os.Setenv("NB_ACTIVITY_EVENT_STORE_ENGINE", engineLower)
|
||||
if dsn := storeConfig.DSN; dsn != "" {
|
||||
os.Setenv("NB_ACTIVITY_EVENT_POSTGRES_DSN", dsn)
|
||||
}
|
||||
}
|
||||
if file := storeConfig.File; file != "" {
|
||||
os.Setenv("NB_ACTIVITY_EVENT_SQLITE_FILE", file)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func initializeConfig() error {
|
||||
var err error
|
||||
config, err = LoadConfig(configPath)
|
||||
@@ -137,30 +169,10 @@ func initializeConfig() error {
|
||||
return fmt.Errorf("failed to initialize log: %w", err)
|
||||
}
|
||||
|
||||
if dsn := config.Server.Store.DSN; dsn != "" {
|
||||
switch strings.ToLower(config.Server.Store.Engine) {
|
||||
case "postgres":
|
||||
os.Setenv("NB_STORE_ENGINE_POSTGRES_DSN", dsn)
|
||||
case "mysql":
|
||||
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
|
||||
}
|
||||
}
|
||||
if file := config.Server.Store.File; file != "" {
|
||||
os.Setenv("NB_STORE_ENGINE_SQLITE_FILE", file)
|
||||
}
|
||||
applyServerStoreEnv(config.Server.Store)
|
||||
|
||||
if engine := config.Server.ActivityStore.Engine; engine != "" {
|
||||
engineLower := strings.ToLower(engine)
|
||||
if engineLower == "postgres" && config.Server.ActivityStore.DSN == "" {
|
||||
return fmt.Errorf("activityStore.dsn is required when activityStore.engine is postgres")
|
||||
}
|
||||
os.Setenv("NB_ACTIVITY_EVENT_STORE_ENGINE", engineLower)
|
||||
if dsn := config.Server.ActivityStore.DSN; dsn != "" {
|
||||
os.Setenv("NB_ACTIVITY_EVENT_POSTGRES_DSN", dsn)
|
||||
}
|
||||
}
|
||||
if file := config.Server.ActivityStore.File; file != "" {
|
||||
os.Setenv("NB_ACTIVITY_EVENT_SQLITE_FILE", file)
|
||||
if err := applyActivityStoreEnv(config.Server.ActivityStore); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Infof("Starting combined NetBird server")
|
||||
|
||||
@@ -1,63 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/formatter/hook"
|
||||
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
// newTokenCommands creates the token command tree with combined-specific store opener.
|
||||
func newTokenCommands() *cobra.Command {
|
||||
return tokencmd.NewCommands(withTokenStore)
|
||||
}
|
||||
|
||||
// withTokenStore loads the combined YAML config, initializes the store, and calls fn.
|
||||
func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
|
||||
if err := util.InitLog("error", "console"); err != nil {
|
||||
return fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
|
||||
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
|
||||
|
||||
cfg, err := LoadConfig(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
|
||||
if dsn := cfg.Server.Store.DSN; dsn != "" {
|
||||
switch strings.ToLower(cfg.Server.Store.Engine) {
|
||||
case "postgres":
|
||||
os.Setenv("NB_STORE_ENGINE_POSTGRES_DSN", dsn)
|
||||
case "mysql":
|
||||
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
|
||||
}
|
||||
}
|
||||
if file := cfg.Server.Store.File; file != "" {
|
||||
os.Setenv("NB_STORE_ENGINE_SQLITE_FILE", file)
|
||||
}
|
||||
|
||||
datadir := cfg.Management.DataDir
|
||||
engine := types.Engine(cfg.Management.Store.Engine)
|
||||
|
||||
s, err := store.NewStore(ctx, engine, datadir, nil, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create store: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := s.Close(ctx); err != nil {
|
||||
log.Debugf("close store: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return fn(ctx, s)
|
||||
}
|
||||
@@ -41,7 +41,7 @@ type Config struct {
|
||||
GRPCAddr string
|
||||
}
|
||||
|
||||
const localConnectorID = "local"
|
||||
const LocalConnectorID = "local"
|
||||
|
||||
// Provider wraps a Dex server
|
||||
type Provider struct {
|
||||
@@ -495,18 +495,60 @@ func (p *Provider) Storage() storage.Storage {
|
||||
return p.storage
|
||||
}
|
||||
|
||||
// SetClientsMFAChain updates the MFAChain field on OAuth2 clients in Dex storage.
|
||||
// Pass a non-empty slice (e.g. []string{"default-totp"}) to enable MFA, or nil to disable it.
|
||||
func SetClientsMFAChain(ctx context.Context, st storage.Storage, clientIDs []string, mfaChain []string) error {
|
||||
previousChains := make(map[string][]string, len(clientIDs))
|
||||
for _, clientID := range clientIDs {
|
||||
client, err := st.GetClient(ctx, clientID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get client %s before MFA chain update: %w", clientID, err)
|
||||
}
|
||||
previousChains[clientID] = cloneMFAChain(client.MFAChain)
|
||||
}
|
||||
|
||||
updatedClientIDs := make([]string, 0, len(clientIDs))
|
||||
for _, clientID := range clientIDs {
|
||||
if err := st.UpdateClient(ctx, clientID, func(old storage.Client) (storage.Client, error) {
|
||||
old.MFAChain = cloneMFAChain(mfaChain)
|
||||
return old, nil
|
||||
}); err != nil {
|
||||
if rollbackErr := rollbackClientsMFAChain(ctx, st, updatedClientIDs, previousChains); rollbackErr != nil {
|
||||
return fmt.Errorf("failed to update MFA chain on client %s: %w (also failed to roll back previous MFA chains: %v)", clientID, err, rollbackErr)
|
||||
}
|
||||
return fmt.Errorf("failed to update MFA chain on client %s: %w", clientID, err)
|
||||
}
|
||||
updatedClientIDs = append(updatedClientIDs, clientID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func rollbackClientsMFAChain(ctx context.Context, st storage.Storage, clientIDs []string, previousChains map[string][]string) error {
|
||||
var rollbackErrs []error
|
||||
for i := len(clientIDs) - 1; i >= 0; i-- {
|
||||
clientID := clientIDs[i]
|
||||
previousChain := cloneMFAChain(previousChains[clientID])
|
||||
if err := st.UpdateClient(ctx, clientID, func(old storage.Client) (storage.Client, error) {
|
||||
old.MFAChain = previousChain
|
||||
return old, nil
|
||||
}); err != nil {
|
||||
rollbackErrs = append(rollbackErrs, fmt.Errorf("client %s: %w", clientID, err))
|
||||
}
|
||||
}
|
||||
return errors.Join(rollbackErrs...)
|
||||
}
|
||||
|
||||
func cloneMFAChain(chain []string) []string {
|
||||
if chain == nil {
|
||||
return nil
|
||||
}
|
||||
return append([]string(nil), chain...)
|
||||
}
|
||||
|
||||
// SetClientsMFAChain updates the MFAChain field on the dashboard and CLI OAuth2 clients.
|
||||
// Pass a non-empty slice (e.g. []string{"default-totp"}) to enable MFA, or nil to disable it.
|
||||
func (p *Provider) SetClientsMFAChain(ctx context.Context, clientIDs []string, mfaChain []string) error {
|
||||
for _, clientID := range clientIDs {
|
||||
if err := p.storage.UpdateClient(ctx, clientID, func(old storage.Client) (storage.Client, error) {
|
||||
old.MFAChain = mfaChain
|
||||
return old, nil
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to update MFA chain on client %s: %w", clientID, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return SetClientsMFAChain(ctx, p.storage, clientIDs, mfaChain)
|
||||
}
|
||||
|
||||
// Handler returns the Dex server as an http.Handler for embedding in another server.
|
||||
@@ -546,7 +588,7 @@ func (p *Provider) CreateUser(ctx context.Context, email, username, password str
|
||||
|
||||
// Encode the user ID in Dex's format: base64(protobuf{user_id, connector_id})
|
||||
// This matches the format Dex uses in JWT tokens
|
||||
encodedID := EncodeDexUserID(userID, localConnectorID)
|
||||
encodedID := EncodeDexUserID(userID, LocalConnectorID)
|
||||
return encodedID, nil
|
||||
}
|
||||
|
||||
@@ -625,7 +667,7 @@ func DecodeDexUserID(encodedID string) (userID, connectorID string, err error) {
|
||||
// local password connector.
|
||||
func IsLocalUserID(encodedID string) bool {
|
||||
_, connectorID, err := DecodeDexUserID(encodedID)
|
||||
return err == nil && connectorID == localConnectorID
|
||||
return err == nil && connectorID == LocalConnectorID
|
||||
}
|
||||
|
||||
// GetUser returns a user by email
|
||||
|
||||
@@ -3,6 +3,8 @@ package dex
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -11,11 +13,44 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/dexidp/dex/storage/memory"
|
||||
sqllib "github.com/dexidp/dex/storage/sql"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type updateFailingStorage struct {
|
||||
storage.Storage
|
||||
failClientID string
|
||||
}
|
||||
|
||||
func (s *updateFailingStorage) UpdateClient(ctx context.Context, id string, updater func(storage.Client) (storage.Client, error)) error {
|
||||
if id == s.failClientID {
|
||||
return errors.New("forced update failure")
|
||||
}
|
||||
return s.Storage.UpdateClient(ctx, id, updater)
|
||||
}
|
||||
|
||||
func TestSetClientsMFAChainRollsBackUpdatedClients(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
st := memory.New(slog.New(slog.NewTextHandler(io.Discard, nil)))
|
||||
|
||||
require.NoError(t, st.CreateClient(ctx, storage.Client{ID: "client-1", MFAChain: []string{"old-1"}}))
|
||||
require.NoError(t, st.CreateClient(ctx, storage.Client{ID: "client-2", MFAChain: []string{"old-2"}}))
|
||||
|
||||
err := SetClientsMFAChain(ctx, &updateFailingStorage{Storage: st, failClientID: "client-2"}, []string{"client-1", "client-2"}, []string{"new"})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "failed to update MFA chain on client client-2")
|
||||
|
||||
client1, err := st.GetClient(ctx, "client-1")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []string{"old-1"}, client1.MFAChain)
|
||||
|
||||
client2, err := st.GetClient(ctx, "client-2")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []string{"old-2"}, client2.MFAChain)
|
||||
}
|
||||
|
||||
func TestUserCreationFlow(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
|
||||
@@ -556,7 +556,7 @@ start_services_and_show_instructions() {
|
||||
echo "Creating proxy access token..."
|
||||
# Use docker exec with bash to run the token command directly
|
||||
PROXY_TOKEN=$($DOCKER_COMPOSE_COMMAND exec -T netbird-server \
|
||||
/go/bin/netbird-server token create --name "default-proxy" --config /etc/netbird/config.yaml 2>/dev/null | grep "^Token:" | awk '{print $2}')
|
||||
/go/bin/netbird-server admin token create --name "default-proxy" --config /etc/netbird/config.yaml 2>/dev/null | grep "^Token:" | awk '{print $2}')
|
||||
|
||||
if [[ -z "$PROXY_TOKEN" ]]; then
|
||||
echo "ERROR: Failed to create proxy token. Check netbird-server logs." > /dev/stderr
|
||||
|
||||
177
management/cmd/admin.go
Normal file
177
management/cmd/admin.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/formatter/hook"
|
||||
admincmd "github.com/netbirdio/netbird/management/cmd/admin"
|
||||
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
activitystore "github.com/netbirdio/netbird/management/server/activity/store"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var adminDatadir string
|
||||
|
||||
// newAdminCommands creates the admin command tree with management-specific resource openers.
|
||||
func newAdminCommands() *cobra.Command {
|
||||
cmd := admincmd.NewCommands(admincmd.Openers{
|
||||
Resources: withAdminResources,
|
||||
Store: withAdminStoreOnly,
|
||||
IDP: withAdminIDPOnly,
|
||||
})
|
||||
cmd.PersistentFlags().StringVar(&adminDatadir, "datadir", "", "Override the data directory from config (used for store.db and the default idp.db)")
|
||||
return cmd
|
||||
}
|
||||
|
||||
func newLegacyTokenCommand() *cobra.Command {
|
||||
cmd := tokencmd.NewCommands(tokencmd.StoreOpener(withAdminStoreOnly))
|
||||
cmd.Deprecated = "use 'admin token' instead"
|
||||
cmd.PersistentFlags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location")
|
||||
return cmd
|
||||
}
|
||||
|
||||
// withAdminResources initializes logging, loads config, opens the management store
|
||||
// and embedded IdP storage, and calls fn.
|
||||
func withAdminResources(cmd *cobra.Command, fn func(ctx context.Context, resources admincmd.Resources) error) error {
|
||||
return withAdminConfig(cmd, true, func(ctx context.Context, config *nbconfig.Config, datadir string) error {
|
||||
managementStore, err := openAdminStore(ctx, config, datadir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer admincmd.CloseStore(ctx, managementStore)
|
||||
|
||||
idpStorage, idpStorageFile, err := admincmd.OpenIDPStorage(config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer admincmd.CloseIDPStorage(idpStorage)
|
||||
|
||||
eventStore, esErr := openAdminEventStore(ctx, config, datadir)
|
||||
if esErr != nil {
|
||||
_, _ = fmt.Fprintf(cmd.ErrOrStderr(), "Warning: audit events will not be recorded: %v\n", esErr)
|
||||
}
|
||||
if eventStore != nil {
|
||||
defer func() {
|
||||
if err := eventStore.Close(ctx); err != nil {
|
||||
log.Debugf("close activity event store: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return fn(ctx, admincmd.Resources{Store: managementStore, IDPStorage: idpStorage, IDPStorageFile: idpStorageFile, EventStore: eventStore})
|
||||
})
|
||||
}
|
||||
|
||||
// withAdminStoreOnly opens only the management store for admin subcommands that do not
|
||||
// need embedded IdP storage.
|
||||
func withAdminStoreOnly(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
|
||||
return withAdminConfig(cmd, false, func(ctx context.Context, config *nbconfig.Config, datadir string) error {
|
||||
managementStore, err := openAdminStore(ctx, config, datadir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer admincmd.CloseStore(ctx, managementStore)
|
||||
|
||||
return fn(ctx, managementStore)
|
||||
})
|
||||
}
|
||||
|
||||
func withAdminIDPOnly(cmd *cobra.Command, fn func(ctx context.Context, idpStorage storage.Storage, storageFile string) error) error {
|
||||
return withAdminConfig(cmd, true, func(ctx context.Context, config *nbconfig.Config, _ string) error {
|
||||
idpStorage, idpStorageFile, err := admincmd.OpenIDPStorage(config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer admincmd.CloseIDPStorage(idpStorage)
|
||||
|
||||
return fn(ctx, idpStorage, idpStorageFile)
|
||||
})
|
||||
}
|
||||
|
||||
func withAdminConfig(cmd *cobra.Command, applyIDPDefaults bool, fn func(ctx context.Context, config *nbconfig.Config, datadir string) error) error {
|
||||
if err := util.InitLog("error", "console"); err != nil {
|
||||
return fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
|
||||
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
|
||||
|
||||
config, datadir, err := loadAdminMgmtConfig(ctx, applyIDPDefaults)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
|
||||
return fn(ctx, config, datadir)
|
||||
}
|
||||
|
||||
func loadAdminMgmtConfig(ctx context.Context, applyIDPDefaults bool) (*nbconfig.Config, string, error) {
|
||||
config := &nbconfig.Config{}
|
||||
if _, err := util.ReadJsonWithEnvSub(nbconfig.MgmtConfigPath, config); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
if applyIDPDefaults {
|
||||
if err := ApplyEmbeddedIdPConfig(ctx, config); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
}
|
||||
|
||||
datadir := config.Datadir
|
||||
applyAdminDatadirOverride(config, &datadir)
|
||||
return config, datadir, nil
|
||||
}
|
||||
|
||||
func applyAdminDatadirOverride(config *nbconfig.Config, datadir *string) {
|
||||
if adminDatadir == "" {
|
||||
return
|
||||
}
|
||||
|
||||
oldDatadir := *datadir
|
||||
*datadir = adminDatadir
|
||||
if config.EmbeddedIdP != nil && config.EmbeddedIdP.Storage.Type == "sqlite3" && isDefaultIDPStorageFile(config.EmbeddedIdP.Storage.Config.File, oldDatadir) {
|
||||
config.EmbeddedIdP.Storage.Config.File = filepath.Join(*datadir, "idp.db")
|
||||
}
|
||||
}
|
||||
|
||||
func isDefaultIDPStorageFile(file, datadir string) bool {
|
||||
if file == "" {
|
||||
return true
|
||||
}
|
||||
defaultFile := filepath.Join(datadir, "idp.db")
|
||||
legacyDefaultFile := path.Join(datadir, "idp.db")
|
||||
legacySlashDefaultFile := path.Join(filepath.ToSlash(datadir), "idp.db")
|
||||
return filepath.Clean(file) == filepath.Clean(defaultFile) ||
|
||||
file == legacyDefaultFile ||
|
||||
filepath.ToSlash(file) == legacySlashDefaultFile
|
||||
}
|
||||
|
||||
func openAdminStore(ctx context.Context, config *nbconfig.Config, datadir string) (store.Store, error) {
|
||||
managementStore, err := store.NewStore(ctx, config.StoreConfig.Engine, datadir, nil, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create store: %w", err)
|
||||
}
|
||||
return managementStore, nil
|
||||
}
|
||||
|
||||
func openAdminEventStore(ctx context.Context, config *nbconfig.Config, datadir string) (activity.Store, error) {
|
||||
if config.DataStoreEncryptionKey == "" {
|
||||
return nil, fmt.Errorf("data store encryption key is not configured")
|
||||
}
|
||||
eventStore, err := activitystore.NewSqlStore(ctx, datadir, config.DataStoreEncryptionKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open activity event store: %w", err)
|
||||
}
|
||||
if eventStore == nil {
|
||||
return nil, fmt.Errorf("open activity event store: returned nil store")
|
||||
}
|
||||
return eventStore, nil
|
||||
}
|
||||
577
management/cmd/admin/admin.go
Normal file
577
management/cmd/admin/admin.go
Normal file
@@ -0,0 +1,577 @@
|
||||
// Package admincmd provides reusable cobra commands for self-hosted administrator helpers.
|
||||
// Both the management and combined binaries use these commands, each providing
|
||||
// their own opener to handle config loading and storage initialization.
|
||||
package admincmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
"github.com/netbirdio/netbird/formatter/hook"
|
||||
nbdex "github.com/netbirdio/netbird/idp/dex"
|
||||
"github.com/netbirdio/netbird/management/cmd/proxy"
|
||||
"github.com/netbirdio/netbird/management/cmd/token"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// Resources contains the storages required by the admin commands.
|
||||
type Resources struct {
|
||||
Store store.Store
|
||||
IDPStorage storage.Storage
|
||||
IDPStorageFile string
|
||||
EventStore activity.Store
|
||||
}
|
||||
|
||||
// Opener initializes command resources from the command context and calls fn.
|
||||
type Opener func(cmd *cobra.Command, fn func(ctx context.Context, resources Resources) error) error
|
||||
|
||||
// StoreOpener initializes only the management store from the command context and calls fn.
|
||||
type StoreOpener func(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error
|
||||
|
||||
// IDPOpener initializes only the embedded IdP storage from the command context and calls fn.
|
||||
type IDPOpener func(cmd *cobra.Command, fn func(ctx context.Context, idpStorage storage.Storage, storageFile string) error) error
|
||||
|
||||
// Openers contains the resource openers needed by the admin command tree.
|
||||
type Openers struct {
|
||||
Resources Opener
|
||||
Store StoreOpener
|
||||
IDP IDPOpener
|
||||
}
|
||||
|
||||
type userSelector struct {
|
||||
email string
|
||||
userID string
|
||||
}
|
||||
|
||||
func (s userSelector) normalized() userSelector {
|
||||
return userSelector{
|
||||
email: strings.TrimSpace(s.email),
|
||||
userID: strings.TrimSpace(s.userID),
|
||||
}
|
||||
}
|
||||
|
||||
func (s userSelector) validate() error {
|
||||
s = s.normalized()
|
||||
if (s.email == "") == (s.userID == "") {
|
||||
return fmt.Errorf("provide exactly one of --email or --user-id")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewCommands creates the admin command tree with the given resource openers.
|
||||
func NewCommands(openers Openers) *cobra.Command {
|
||||
adminCmd := &cobra.Command{
|
||||
Use: "admin",
|
||||
Short: "Self-hosted administrator helpers",
|
||||
Long: "Administrative helpers for self-hosted deployments using the embedded identity provider.",
|
||||
}
|
||||
|
||||
userCmd := &cobra.Command{
|
||||
Use: "user",
|
||||
Short: "Manage local embedded IdP users",
|
||||
}
|
||||
|
||||
var passwordSelector userSelector
|
||||
var password string
|
||||
var passwordFile string
|
||||
passwordCmd := &cobra.Command{
|
||||
Use: "change-password (--email email | --user-id id) (--password password | --password-file path)",
|
||||
Aliases: []string{"set-password"},
|
||||
Short: "Change a local user's password",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||
if err := passwordSelector.validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
newPassword, err := resolvePasswordInput(cmd, password, passwordFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return openers.IDP(cmd, func(ctx context.Context, idpStorage storage.Storage, storageFile string) error {
|
||||
return runChangePassword(ctx, idpStorage, cmd.OutOrStdout(), passwordSelector, newPassword, storageFile)
|
||||
})
|
||||
},
|
||||
}
|
||||
addUserSelectorFlags(passwordCmd, &passwordSelector)
|
||||
passwordCmd.Flags().StringVar(&password, "password", "", "New password for the user")
|
||||
passwordCmd.Flags().StringVar(&passwordFile, "password-file", "", "Read new password from file ('-' for stdin)")
|
||||
|
||||
var resetSelector userSelector
|
||||
resetMFACmd := &cobra.Command{
|
||||
Use: "reset-mfa (--email email | --user-id id)",
|
||||
Short: "Reset a local user's MFA enrollment",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||
if err := resetSelector.validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
return openers.IDP(cmd, func(ctx context.Context, idpStorage storage.Storage, storageFile string) error {
|
||||
return runResetMFA(ctx, idpStorage, cmd.OutOrStdout(), resetSelector, storageFile)
|
||||
})
|
||||
},
|
||||
}
|
||||
addUserSelectorFlags(resetMFACmd, &resetSelector)
|
||||
|
||||
userCmd.AddCommand(passwordCmd, resetMFACmd)
|
||||
|
||||
mfaCmd := &cobra.Command{
|
||||
Use: "mfa",
|
||||
Short: "Manage local MFA for embedded IdP users",
|
||||
}
|
||||
|
||||
enableCmd := &cobra.Command{
|
||||
Use: "enable",
|
||||
Short: "Enable MFA for local embedded IdP users",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||
return openers.Resources(cmd, func(ctx context.Context, resources Resources) error {
|
||||
return runSetMFAEnabled(ctx, resources, cmd.OutOrStdout(), true)
|
||||
})
|
||||
},
|
||||
}
|
||||
|
||||
disableCmd := &cobra.Command{
|
||||
Use: "disable",
|
||||
Short: "Disable MFA for local embedded IdP users",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||
return openers.Resources(cmd, func(ctx context.Context, resources Resources) error {
|
||||
return runSetMFAEnabled(ctx, resources, cmd.OutOrStdout(), false)
|
||||
})
|
||||
},
|
||||
}
|
||||
|
||||
statusCmd := &cobra.Command{
|
||||
Use: "status",
|
||||
Short: "Show local MFA status",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||
return openers.Resources(cmd, func(ctx context.Context, resources Resources) error {
|
||||
return runMFAStatus(ctx, resources, cmd.OutOrStdout())
|
||||
})
|
||||
},
|
||||
}
|
||||
|
||||
mfaCmd.AddCommand(enableCmd, disableCmd, statusCmd)
|
||||
adminCmd.AddCommand(userCmd, mfaCmd)
|
||||
if openers.Store != nil {
|
||||
adminCmd.AddCommand(tokencmd.NewCommands(tokencmd.StoreOpener(openers.Store)))
|
||||
adminCmd.AddCommand(proxycmd.NewCommands(proxycmd.StoreOpener(openers.Store)))
|
||||
}
|
||||
return adminCmd
|
||||
}
|
||||
|
||||
// OpenEmbeddedIDPStorage opens the Dex storage configured for the embedded IdP.
|
||||
func OpenEmbeddedIDPStorage(cfg *idp.EmbeddedIdPConfig) (storage.Storage, error) {
|
||||
if cfg == nil || !cfg.Enabled {
|
||||
return nil, fmt.Errorf("admin commands require the embedded IdP to be enabled")
|
||||
}
|
||||
|
||||
yamlConfig, err := cfg.ToYAMLConfig()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build embedded IdP config: %w", err)
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
st, err := yamlConfig.Storage.OpenStorage(logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open embedded IdP storage: %w", err)
|
||||
}
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// CloseStore closes the management store and logs cleanup errors at debug level.
|
||||
func CloseStore(ctx context.Context, s store.Store) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
if err := s.Close(ctx); err != nil {
|
||||
log.Debugf("close store: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// OpenIDPStorage opens embedded IdP storage and returns its sqlite file path when applicable.
|
||||
func OpenIDPStorage(config *nbconfig.Config) (storage.Storage, string, error) {
|
||||
if config == nil {
|
||||
return nil, "", fmt.Errorf("management config is required")
|
||||
}
|
||||
idpStorage, err := OpenEmbeddedIDPStorage(config.EmbeddedIdP)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return idpStorage, embeddedIDPStorageFile(config), nil
|
||||
}
|
||||
|
||||
func embeddedIDPStorageFile(config *nbconfig.Config) string {
|
||||
if config.EmbeddedIdP == nil || config.EmbeddedIdP.Storage.Type != "sqlite3" {
|
||||
return ""
|
||||
}
|
||||
return config.EmbeddedIdP.Storage.Config.File
|
||||
}
|
||||
|
||||
// CloseIDPStorage closes embedded IdP storage and logs cleanup errors at debug level.
|
||||
func CloseIDPStorage(s storage.Storage) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
if err := s.Close(); err != nil {
|
||||
log.Debugf("close embedded IdP storage: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func addUserSelectorFlags(cmd *cobra.Command, selector *userSelector) {
|
||||
cmd.Flags().StringVar(&selector.email, "email", "", "User email")
|
||||
cmd.Flags().StringVar(&selector.userID, "user-id", "", "User ID")
|
||||
}
|
||||
|
||||
func resolvePasswordInput(cmd *cobra.Command, password, passwordFile string) (string, error) {
|
||||
if password != "" && passwordFile != "" {
|
||||
return "", fmt.Errorf("provide only one of --password or --password-file")
|
||||
}
|
||||
if passwordFile == "" {
|
||||
return password, nil
|
||||
}
|
||||
|
||||
var data []byte
|
||||
var err error
|
||||
if passwordFile == "-" {
|
||||
data, err = io.ReadAll(cmd.InOrStdin())
|
||||
} else {
|
||||
data, err = os.ReadFile(passwordFile)
|
||||
}
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read password: %w", err)
|
||||
}
|
||||
return strings.TrimRight(string(data), "\r\n"), nil
|
||||
}
|
||||
|
||||
func runChangePassword(ctx context.Context, idpStorage storage.Storage, w io.Writer, selector userSelector, password string, idpStorageFile string) error {
|
||||
if idpStorage == nil {
|
||||
return fmt.Errorf("embedded IdP storage is required")
|
||||
}
|
||||
selector = selector.normalized()
|
||||
if err := selector.validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
if password == "" {
|
||||
return fmt.Errorf("password is required")
|
||||
}
|
||||
if err := server.ValidatePassword(password); err != nil {
|
||||
return fmt.Errorf("invalid password: %w", err)
|
||||
}
|
||||
|
||||
user, err := findLocalUser(ctx, idpStorage, selector, idpStorageFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return fmt.Errorf("hash password: %w", err)
|
||||
}
|
||||
|
||||
if err := idpStorage.UpdatePassword(ctx, user.Email, func(old storage.Password) (storage.Password, error) {
|
||||
old.Hash = hash
|
||||
return old, nil
|
||||
}); err != nil {
|
||||
return fmt.Errorf("update password for %s: %w", user.Email, err)
|
||||
}
|
||||
|
||||
if err := deleteLocalAuthSession(ctx, idpStorage, user.UserID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(w, "Password updated for %s.\n", user.Email)
|
||||
return nil
|
||||
}
|
||||
|
||||
func runResetMFA(ctx context.Context, idpStorage storage.Storage, w io.Writer, selector userSelector, idpStorageFile string) error {
|
||||
if idpStorage == nil {
|
||||
return fmt.Errorf("embedded IdP storage is required")
|
||||
}
|
||||
selector = selector.normalized()
|
||||
if err := selector.validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := findLocalUser(ctx, idpStorage, selector, idpStorageFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
reset := false
|
||||
err = idpStorage.UpdateUserIdentity(ctx, user.UserID, idp.LocalConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) {
|
||||
reset = reset || len(old.MFASecrets) > 0 || len(old.WebAuthnCredentials) > 0
|
||||
old.MFASecrets = map[string]*storage.MFASecret{}
|
||||
old.WebAuthnCredentials = map[string][]storage.WebAuthnCredential{}
|
||||
return old, nil
|
||||
})
|
||||
if errors.Is(err, storage.ErrNotFound) {
|
||||
if err := deleteLocalAuthSession(ctx, idpStorage, user.UserID); err != nil {
|
||||
return err
|
||||
}
|
||||
_, _ = fmt.Fprintf(w, "No MFA enrollment found for %s.\n", user.Email)
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("reset MFA for %s: %w", user.Email, err)
|
||||
}
|
||||
|
||||
if err := deleteLocalAuthSession(ctx, idpStorage, user.UserID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if reset {
|
||||
_, _ = fmt.Fprintf(w, "MFA reset for %s. The user will re-enroll at next login.\n", user.Email)
|
||||
} else {
|
||||
_, _ = fmt.Fprintf(w, "No MFA enrollment found for %s.\n", user.Email)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func runSetMFAEnabled(ctx context.Context, resources Resources, w io.Writer, enabled bool) error {
|
||||
if resources.Store == nil {
|
||||
return fmt.Errorf("management store is required")
|
||||
}
|
||||
if resources.IDPStorage == nil {
|
||||
return fmt.Errorf("embedded IdP storage is required")
|
||||
}
|
||||
|
||||
accountID, settings, err := getSingleAccountSettings(ctx, resources.Store)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oldEnabled := settings.LocalMfaEnabled
|
||||
newSettings := settings.Copy()
|
||||
newSettings.LocalMfaEnabled = enabled
|
||||
|
||||
if err := setIDPClientsMFA(ctx, resources.IDPStorage, enabled); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := resources.Store.SaveAccountSettings(ctx, accountID, newSettings); err != nil {
|
||||
if rollbackErr := setIDPClientsMFA(ctx, resources.IDPStorage, oldEnabled); rollbackErr != nil {
|
||||
return fmt.Errorf("save local MFA account setting: %w (also failed to roll back embedded IdP MFA state: %v)", err, rollbackErr)
|
||||
}
|
||||
return fmt.Errorf("save local MFA account setting: %w", err)
|
||||
}
|
||||
|
||||
if err := storeMFAActivity(ctx, resources.EventStore, accountID, enabled); err != nil {
|
||||
_, _ = fmt.Fprintf(w, "Warning: failed to record audit event: %v\n", err)
|
||||
}
|
||||
|
||||
state := "disabled"
|
||||
if enabled {
|
||||
state = "enabled"
|
||||
}
|
||||
_, _ = fmt.Fprintf(w, "Local MFA %s.\n", state)
|
||||
return nil
|
||||
}
|
||||
|
||||
func runMFAStatus(ctx context.Context, resources Resources, w io.Writer) error {
|
||||
if resources.Store == nil {
|
||||
return fmt.Errorf("management store is required")
|
||||
}
|
||||
if resources.IDPStorage == nil {
|
||||
return fmt.Errorf("embedded IdP storage is required")
|
||||
}
|
||||
|
||||
_, settings, err := getSingleAccountSettings(ctx, resources.Store)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
accountStatus := "disabled"
|
||||
if settings.LocalMfaEnabled {
|
||||
accountStatus = "enabled"
|
||||
}
|
||||
|
||||
clientStatus, err := idpClientsMFAStatus(ctx, resources.IDPStorage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(w, "Account setting: %s\n", accountStatus)
|
||||
_, _ = fmt.Fprintf(w, "Embedded IdP clients: %s\n", clientStatus)
|
||||
return nil
|
||||
}
|
||||
|
||||
func getSingleAccountSettings(ctx context.Context, s store.Store) (string, *types.Settings, error) {
|
||||
count, err := s.GetAccountsCounter(ctx)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("count accounts: %w", err)
|
||||
}
|
||||
if count != 1 {
|
||||
return "", nil, fmt.Errorf("expected exactly one account, got %d; local MFA is supported only in single-account embedded IdP deployments", count)
|
||||
}
|
||||
|
||||
accountID, err := s.GetAnyAccountID(ctx)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("get account ID: %w", err)
|
||||
}
|
||||
|
||||
settings, err := s.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("get account settings: %w", err)
|
||||
}
|
||||
if settings == nil {
|
||||
settings = &types.Settings{}
|
||||
}
|
||||
return accountID, settings, nil
|
||||
}
|
||||
|
||||
func storeMFAActivity(ctx context.Context, eventStore activity.Store, accountID string, enabled bool) error {
|
||||
if eventStore == nil {
|
||||
return nil
|
||||
}
|
||||
event := activity.AccountLocalMfaDisabled
|
||||
if enabled {
|
||||
event = activity.AccountLocalMfaEnabled
|
||||
}
|
||||
_, err := eventStore.Save(ctx, &activity.Event{
|
||||
Timestamp: time.Now().UTC(),
|
||||
Activity: event,
|
||||
InitiatorID: string(hook.SystemSource),
|
||||
TargetID: accountID,
|
||||
AccountID: accountID,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("save local MFA audit event: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func findLocalUser(ctx context.Context, idpStorage storage.Storage, selector userSelector, idpStorageFile string) (storage.Password, error) {
|
||||
selector = selector.normalized()
|
||||
if err := selector.validate(); err != nil {
|
||||
return storage.Password{}, err
|
||||
}
|
||||
|
||||
if selector.email != "" {
|
||||
user, err := idpStorage.GetPassword(ctx, selector.email)
|
||||
if errors.Is(err, storage.ErrNotFound) {
|
||||
if empty, listErr := localUsersEmpty(ctx, idpStorage); listErr != nil {
|
||||
return storage.Password{}, listErr
|
||||
} else if empty {
|
||||
return storage.Password{}, noLocalUsersError(idpStorageFile)
|
||||
}
|
||||
return storage.Password{}, fmt.Errorf("local user with email %q not found", selector.email)
|
||||
}
|
||||
if err != nil {
|
||||
return storage.Password{}, fmt.Errorf("get local user by email %q: %w", selector.email, err)
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
rawUserID := selector.userID
|
||||
if decodedUserID, _, err := nbdex.DecodeDexUserID(selector.userID); err == nil && decodedUserID != "" {
|
||||
rawUserID = decodedUserID
|
||||
}
|
||||
|
||||
users, err := idpStorage.ListPasswords(ctx)
|
||||
if err != nil {
|
||||
return storage.Password{}, fmt.Errorf("list local users: %w", err)
|
||||
}
|
||||
for _, user := range users {
|
||||
if user.UserID == rawUserID || user.UserID == selector.userID {
|
||||
return user, nil
|
||||
}
|
||||
}
|
||||
|
||||
if len(users) == 0 {
|
||||
return storage.Password{}, noLocalUsersError(idpStorageFile)
|
||||
}
|
||||
|
||||
return storage.Password{}, fmt.Errorf("local user with ID %q not found", selector.userID)
|
||||
}
|
||||
|
||||
func localUsersEmpty(ctx context.Context, idpStorage storage.Storage) (bool, error) {
|
||||
users, err := idpStorage.ListPasswords(ctx)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("list local users: %w", err)
|
||||
}
|
||||
return len(users) == 0, nil
|
||||
}
|
||||
|
||||
func noLocalUsersError(idpStorageFile string) error {
|
||||
location := ""
|
||||
if idpStorageFile != "" {
|
||||
location = fmt.Sprintf(" (%s)", idpStorageFile)
|
||||
}
|
||||
return fmt.Errorf("no local users exist in the embedded IdP storage%s; the management server may never have started with this config, or --datadir points at the wrong location", location)
|
||||
}
|
||||
|
||||
func deleteLocalAuthSession(ctx context.Context, idpStorage storage.Storage, userID string) error {
|
||||
err := idpStorage.DeleteAuthSession(ctx, userID, idp.LocalConnectorID)
|
||||
if err == nil || errors.Is(err, storage.ErrNotFound) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("delete local auth session for user %s: %w", userID, err)
|
||||
}
|
||||
|
||||
func setIDPClientsMFA(ctx context.Context, idpStorage storage.Storage, enabled bool) error {
|
||||
var mfaChain []string
|
||||
if enabled {
|
||||
mfaChain = []string{idp.DefaultTOTPAuthenticatorID}
|
||||
}
|
||||
|
||||
clientIDs := []string{idp.StaticClientCLI, idp.StaticClientDashboard}
|
||||
if err := nbdex.SetClientsMFAChain(ctx, idpStorage, clientIDs, mfaChain); err != nil {
|
||||
if errors.Is(err, storage.ErrNotFound) {
|
||||
return fmt.Errorf("embedded IdP client not found; start the management server once before toggling MFA: %w", err)
|
||||
}
|
||||
return fmt.Errorf("update MFA chain on embedded IdP clients: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func idpClientsMFAStatus(ctx context.Context, idpStorage storage.Storage) (string, error) {
|
||||
clientIDs := []string{idp.StaticClientCLI, idp.StaticClientDashboard}
|
||||
enabledCount := 0
|
||||
for _, clientID := range clientIDs {
|
||||
client, err := idpStorage.GetClient(ctx, clientID)
|
||||
if errors.Is(err, storage.ErrNotFound) {
|
||||
return "unknown", fmt.Errorf("embedded IdP client %q not found", clientID)
|
||||
}
|
||||
if err != nil {
|
||||
return "unknown", fmt.Errorf("get embedded IdP client %q: %w", clientID, err)
|
||||
}
|
||||
if hasAuthenticator(client.MFAChain, idp.DefaultTOTPAuthenticatorID) {
|
||||
enabledCount++
|
||||
}
|
||||
}
|
||||
|
||||
switch enabledCount {
|
||||
case 0:
|
||||
return "disabled", nil
|
||||
case len(clientIDs):
|
||||
return "enabled", nil
|
||||
default:
|
||||
return "partially enabled", nil
|
||||
}
|
||||
}
|
||||
|
||||
func hasAuthenticator(chain []string, authenticatorID string) bool {
|
||||
for _, id := range chain {
|
||||
if id == authenticatorID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
250
management/cmd/admin/admin_test.go
Normal file
250
management/cmd/admin/admin_test.go
Normal file
@@ -0,0 +1,250 @@
|
||||
package admincmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/dexidp/dex/storage/memory"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
nbdex "github.com/netbirdio/netbird/idp/dex"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
mgmtstore "github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
func newTestIDPStorage(t *testing.T) storage.Storage {
|
||||
t.Helper()
|
||||
|
||||
st := memory.New(slog.New(slog.NewTextHandler(io.Discard, nil)))
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte("OldPass1!"), bcrypt.DefaultCost)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, st.CreatePassword(context.Background(), storage.Password{
|
||||
Email: "user@example.com",
|
||||
Username: "User",
|
||||
UserID: "user-1",
|
||||
Hash: hash,
|
||||
}))
|
||||
require.NoError(t, st.CreateUserIdentity(context.Background(), storage.UserIdentity{
|
||||
UserID: "user-1",
|
||||
ConnectorID: idp.LocalConnectorID,
|
||||
MFASecrets: map[string]*storage.MFASecret{
|
||||
idp.DefaultTOTPAuthenticatorID: {
|
||||
AuthenticatorID: idp.DefaultTOTPAuthenticatorID,
|
||||
Type: "TOTP",
|
||||
Secret: "otpauth://totp/NetBird:user@example.com?secret=ABC",
|
||||
Confirmed: true,
|
||||
CreatedAt: time.Now(),
|
||||
},
|
||||
},
|
||||
WebAuthnCredentials: map[string][]storage.WebAuthnCredential{
|
||||
"webauthn": {{CredentialID: []byte("credential")}},
|
||||
},
|
||||
}))
|
||||
require.NoError(t, st.CreateAuthSession(context.Background(), storage.AuthSession{
|
||||
UserID: "user-1",
|
||||
ConnectorID: idp.LocalConnectorID,
|
||||
Nonce: "nonce",
|
||||
}))
|
||||
require.NoError(t, st.CreateClient(context.Background(), storage.Client{ID: idp.StaticClientCLI, Name: "CLI"}))
|
||||
require.NoError(t, st.CreateClient(context.Background(), storage.Client{ID: idp.StaticClientDashboard, Name: "Dashboard"}))
|
||||
|
||||
return st
|
||||
}
|
||||
|
||||
func TestRunChangePassword(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
st := newTestIDPStorage(t)
|
||||
var out bytes.Buffer
|
||||
|
||||
err := runChangePassword(ctx, st, &out, userSelector{email: "user@example.com"}, "NewPass1!", "")
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, out.String(), "Password updated")
|
||||
|
||||
user, err := st.GetPassword(ctx, "user@example.com")
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, bcrypt.CompareHashAndPassword(user.Hash, []byte("NewPass1!")))
|
||||
|
||||
_, err = st.GetAuthSession(ctx, "user-1", idp.LocalConnectorID)
|
||||
require.ErrorIs(t, err, storage.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestRunChangePasswordValidatesPassword(t *testing.T) {
|
||||
st := newTestIDPStorage(t)
|
||||
err := runChangePassword(context.Background(), st, io.Discard, userSelector{email: "user@example.com"}, "short", "")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid password")
|
||||
}
|
||||
|
||||
func TestRunResetMFA(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
st := newTestIDPStorage(t)
|
||||
var out bytes.Buffer
|
||||
|
||||
encodedUserID := nbdex.EncodeDexUserID("user-1", idp.LocalConnectorID)
|
||||
err := runResetMFA(ctx, st, &out, userSelector{userID: encodedUserID}, "")
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, out.String(), "MFA reset")
|
||||
|
||||
identity, err := st.GetUserIdentity(ctx, "user-1", idp.LocalConnectorID)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, identity.MFASecrets)
|
||||
require.Empty(t, identity.WebAuthnCredentials)
|
||||
|
||||
_, err = st.GetAuthSession(ctx, "user-1", idp.LocalConnectorID)
|
||||
require.ErrorIs(t, err, storage.ErrNotFound)
|
||||
}
|
||||
|
||||
func TestRunResetMFAWithoutEnrollment(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
st := newTestIDPStorage(t)
|
||||
require.NoError(t, st.UpdateUserIdentity(ctx, "user-1", idp.LocalConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) {
|
||||
old.MFASecrets = nil
|
||||
old.WebAuthnCredentials = nil
|
||||
return old, nil
|
||||
}))
|
||||
|
||||
var out bytes.Buffer
|
||||
err := runResetMFA(ctx, st, &out, userSelector{email: "user@example.com"}, "")
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, out.String(), "No MFA enrollment found")
|
||||
}
|
||||
|
||||
func TestSetIDPClientsMFA(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
st := newTestIDPStorage(t)
|
||||
|
||||
require.NoError(t, setIDPClientsMFA(ctx, st, true))
|
||||
status, err := idpClientsMFAStatus(ctx, st)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "enabled", status)
|
||||
|
||||
require.NoError(t, setIDPClientsMFA(ctx, st, false))
|
||||
status, err = idpClientsMFAStatus(ctx, st)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "disabled", status)
|
||||
}
|
||||
|
||||
func newTestManagementStore(t *testing.T, localMFAEnabled bool) mgmtstore.Store {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
st, err := mgmtstore.NewStore(ctx, types.SqliteStoreEngine, t.TempDir(), nil, false)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { require.NoError(t, st.Close(ctx)) })
|
||||
require.NoError(t, st.SaveAccount(ctx, &types.Account{
|
||||
Id: "account-1",
|
||||
Settings: &types.Settings{LocalMfaEnabled: localMFAEnabled},
|
||||
}))
|
||||
return st
|
||||
}
|
||||
|
||||
func TestRunSetMFAEnabledDoesNotSaveWhenIDPUpdateFails(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
managementStore := newTestManagementStore(t, false)
|
||||
idpStorage := memory.New(slog.New(slog.NewTextHandler(io.Discard, nil)))
|
||||
|
||||
err := runSetMFAEnabled(ctx, Resources{Store: managementStore, IDPStorage: idpStorage}, io.Discard, true)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "embedded IdP client")
|
||||
|
||||
settings, err := managementStore.GetAccountSettings(ctx, mgmtstore.LockingStrengthNone, "account-1")
|
||||
require.NoError(t, err)
|
||||
require.False(t, settings.LocalMfaEnabled)
|
||||
}
|
||||
|
||||
func TestRunSetMFAEnabledUpdatesSettingsAfterIDP(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
managementStore := newTestManagementStore(t, false)
|
||||
idpStorage := newTestIDPStorage(t)
|
||||
|
||||
err := runSetMFAEnabled(ctx, Resources{Store: managementStore, IDPStorage: idpStorage}, io.Discard, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
settings, err := managementStore.GetAccountSettings(ctx, mgmtstore.LockingStrengthNone, "account-1")
|
||||
require.NoError(t, err)
|
||||
require.True(t, settings.LocalMfaEnabled)
|
||||
clientStatus, err := idpClientsMFAStatus(ctx, idpStorage)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "enabled", clientStatus)
|
||||
}
|
||||
|
||||
func TestRunSetMFAEnabledSucceedsWithNilEventStore(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
managementStore := newTestManagementStore(t, false)
|
||||
idpStorage := newTestIDPStorage(t)
|
||||
var out bytes.Buffer
|
||||
var err error
|
||||
|
||||
require.NotPanics(t, func() {
|
||||
err = runSetMFAEnabled(ctx, Resources{Store: managementStore, IDPStorage: idpStorage, EventStore: nil}, &out, true)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, out.String(), "Local MFA enabled")
|
||||
|
||||
settings, err := managementStore.GetAccountSettings(ctx, mgmtstore.LockingStrengthNone, "account-1")
|
||||
require.NoError(t, err)
|
||||
require.True(t, settings.LocalMfaEnabled)
|
||||
}
|
||||
|
||||
func TestUserSelectorValidate(t *testing.T) {
|
||||
require.NoError(t, userSelector{email: " user@example.com "}.validate())
|
||||
require.NoError(t, userSelector{userID: "user-1"}.validate())
|
||||
require.Error(t, userSelector{}.validate())
|
||||
require.Error(t, userSelector{email: "user@example.com", userID: "user-1"}.validate())
|
||||
}
|
||||
|
||||
func TestFindLocalUserNotFound(t *testing.T) {
|
||||
st := newTestIDPStorage(t)
|
||||
_, err := findLocalUser(context.Background(), st, userSelector{email: "missing@example.com"}, "")
|
||||
require.Error(t, err)
|
||||
require.True(t, strings.Contains(err.Error(), "not found"))
|
||||
}
|
||||
|
||||
func TestFindLocalUserZeroUsersIncludesStoragePath(t *testing.T) {
|
||||
st := memory.New(slog.New(slog.NewTextHandler(io.Discard, nil)))
|
||||
_, err := findLocalUser(context.Background(), st, userSelector{email: "missing@example.com"}, "/var/lib/netbird/idp.db")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "no local users exist")
|
||||
require.Contains(t, err.Error(), "/var/lib/netbird/idp.db")
|
||||
}
|
||||
|
||||
func TestUserCommandValidatesSelectorBeforeOpeningStorage(t *testing.T) {
|
||||
opened := false
|
||||
cmd := NewCommands(Openers{
|
||||
IDP: func(cmd *cobra.Command, fn func(ctx context.Context, idpStorage storage.Storage, storageFile string) error) error {
|
||||
opened = true
|
||||
return nil
|
||||
},
|
||||
})
|
||||
cmd.SetArgs([]string{"user", "change-password", "--password", "NewPass1!"})
|
||||
cmd.SetOut(io.Discard)
|
||||
cmd.SetErr(io.Discard)
|
||||
|
||||
err := cmd.Execute()
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "provide exactly one")
|
||||
require.False(t, opened)
|
||||
}
|
||||
|
||||
func TestResolvePasswordInputFromStdin(t *testing.T) {
|
||||
cmd := &cobra.Command{}
|
||||
cmd.SetIn(strings.NewReader("NewPass1!\n"))
|
||||
|
||||
password, err := resolvePasswordInput(cmd, "", "-")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "NewPass1!", password)
|
||||
}
|
||||
|
||||
func TestResolvePasswordInputRejectsMultipleSources(t *testing.T) {
|
||||
_, err := resolvePasswordInput(&cobra.Command{}, "NewPass1!", "-")
|
||||
require.Error(t, err)
|
||||
}
|
||||
80
management/cmd/admin_config_test.go
Normal file
80
management/cmd/admin_config_test.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
)
|
||||
|
||||
func TestApplyAdminDatadirOverrideRelocatesDefaultIDPStorage(t *testing.T) {
|
||||
oldDatadir := filepath.Join(t.TempDir(), "old")
|
||||
newDatadir := filepath.Join(t.TempDir(), "new")
|
||||
|
||||
for _, defaultFile := range []string{
|
||||
"",
|
||||
filepath.Join(oldDatadir, "idp.db"),
|
||||
path.Join(oldDatadir, "idp.db"),
|
||||
} {
|
||||
t.Run(defaultFile, func(t *testing.T) {
|
||||
cfg := &nbconfig.Config{
|
||||
EmbeddedIdP: &idp.EmbeddedIdPConfig{
|
||||
Enabled: true,
|
||||
Storage: idp.EmbeddedStorageConfig{
|
||||
Type: "sqlite3",
|
||||
Config: idp.EmbeddedStorageTypeConfig{
|
||||
File: defaultFile,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
datadir := oldDatadir
|
||||
oldAdminDatadir := adminDatadir
|
||||
adminDatadir = newDatadir
|
||||
t.Cleanup(func() { adminDatadir = oldAdminDatadir })
|
||||
|
||||
applyAdminDatadirOverride(cfg, &datadir)
|
||||
|
||||
require.Equal(t, newDatadir, datadir)
|
||||
require.Equal(t, filepath.Join(newDatadir, "idp.db"), cfg.EmbeddedIdP.Storage.Config.File)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAdminEventStoreMissingEncryptionKeyReturnsNilInterface(t *testing.T) {
|
||||
eventStore, err := openAdminEventStore(context.Background(), &nbconfig.Config{}, t.TempDir())
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "encryption key")
|
||||
require.Nil(t, eventStore)
|
||||
}
|
||||
|
||||
func TestApplyAdminDatadirOverrideKeepsExplicitIDPStorage(t *testing.T) {
|
||||
oldDatadir := filepath.Join(t.TempDir(), "old")
|
||||
newDatadir := filepath.Join(t.TempDir(), "new")
|
||||
explicitFile := filepath.Join(t.TempDir(), "custom-idp.db")
|
||||
cfg := &nbconfig.Config{
|
||||
EmbeddedIdP: &idp.EmbeddedIdPConfig{
|
||||
Enabled: true,
|
||||
Storage: idp.EmbeddedStorageConfig{
|
||||
Type: "sqlite3",
|
||||
Config: idp.EmbeddedStorageTypeConfig{
|
||||
File: explicitFile,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
datadir := oldDatadir
|
||||
oldAdminDatadir := adminDatadir
|
||||
adminDatadir = newDatadir
|
||||
t.Cleanup(func() { adminDatadir = oldAdminDatadir })
|
||||
|
||||
applyAdminDatadirOverride(cfg, &datadir)
|
||||
|
||||
require.Equal(t, newDatadir, datadir)
|
||||
require.Equal(t, explicitFile, cfg.EmbeddedIdP.Storage.Config.File)
|
||||
}
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
@@ -209,7 +210,7 @@ func ApplyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
|
||||
cfg.EmbeddedIdP.Storage.Type = "sqlite3"
|
||||
}
|
||||
if cfg.EmbeddedIdP.Storage.Config.File == "" && cfg.Datadir != "" {
|
||||
cfg.EmbeddedIdP.Storage.Config.File = path.Join(cfg.Datadir, "idp.db")
|
||||
cfg.EmbeddedIdP.Storage.Config.File = filepath.Join(cfg.Datadir, "idp.db")
|
||||
}
|
||||
|
||||
issuer := cfg.EmbeddedIdP.Issuer
|
||||
|
||||
141
management/cmd/proxy/proxy.go
Normal file
141
management/cmd/proxy/proxy.go
Normal file
@@ -0,0 +1,141 @@
|
||||
// Package proxycmd provides reusable cobra commands for managing reverse proxy instances.
|
||||
// Both the management and combined binaries use these commands, each providing
|
||||
// their own StoreOpener to handle config loading and store initialization.
|
||||
package proxycmd
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"text/tabwriter"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
rpproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
)
|
||||
|
||||
// StoreOpener initializes a store from the command context and calls fn.
|
||||
type StoreOpener func(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error
|
||||
|
||||
const disconnectAllConfirmation = "disconnect all proxies"
|
||||
|
||||
// NewCommands creates the proxy command tree with the given store opener.
|
||||
// Returns the parent "proxy" command with the disconnect-all subcommand.
|
||||
func NewCommands(opener StoreOpener) *cobra.Command {
|
||||
var dryRun bool
|
||||
var force bool
|
||||
|
||||
proxyCmd := &cobra.Command{
|
||||
Use: "proxy",
|
||||
Short: "Manage reverse proxy instances",
|
||||
Long: "Commands for inspecting and repairing the reverse proxy instances registered with the management server.",
|
||||
}
|
||||
|
||||
disconnectAllCmd := &cobra.Command{
|
||||
Use: "disconnect-all",
|
||||
Short: "Force-mark all reverse proxy instances as disconnected",
|
||||
Long: "Lists all reverse proxy instances and force-marks them as disconnected, regardless of their session state. " +
|
||||
"Use this to repair stale connection state, e.g. after an unclean management server shutdown. " +
|
||||
"By default, it asks for manual confirmation before changing state. Use --dry-run to preview without changing state, or --force to skip confirmation. " +
|
||||
"Run during a maintenance window; affected live proxies may stay hidden until their next heartbeat or reconnect/re-register.",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: func(cmd *cobra.Command, _ []string) error {
|
||||
return opener(cmd, func(ctx context.Context, s store.Store) error {
|
||||
return runDisconnectAll(ctx, s, cmd.OutOrStdout(), cmd.InOrStdin(), dryRun, force)
|
||||
})
|
||||
},
|
||||
}
|
||||
disconnectAllCmd.Flags().BoolVar(&dryRun, "dry-run", false, "List reverse proxy instances that would be disconnected without changing state")
|
||||
disconnectAllCmd.Flags().BoolVar(&force, "force", false, "Skip the confirmation prompt and apply the repair")
|
||||
|
||||
proxyCmd.AddCommand(disconnectAllCmd)
|
||||
return proxyCmd
|
||||
}
|
||||
|
||||
func runDisconnectAll(ctx context.Context, s store.Store, out io.Writer, in io.Reader, dryRun, force bool) error {
|
||||
proxies, err := s.GetAllProxies(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list proxies: %w", err)
|
||||
}
|
||||
|
||||
if len(proxies) == 0 {
|
||||
_, _ = fmt.Fprintln(out, "No reverse proxy instances found.")
|
||||
return nil
|
||||
}
|
||||
|
||||
toDisconnect := 0
|
||||
w := tabwriter.NewWriter(out, 0, 0, 2, ' ', 0)
|
||||
_, _ = fmt.Fprintln(w, "ID\tCLUSTER\tIP\tACCOUNT\tSTATUS\tLAST SEEN")
|
||||
_, _ = fmt.Fprintln(w, "--\t-------\t--\t-------\t------\t---------")
|
||||
|
||||
for _, p := range proxies {
|
||||
if p.Status != rpproxy.StatusDisconnected {
|
||||
toDisconnect++
|
||||
}
|
||||
|
||||
account := "-"
|
||||
if p.AccountID != nil {
|
||||
account = *p.AccountID
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(w, "%s\t%s\t%s\t%s\t%s\t%s\n",
|
||||
p.ID,
|
||||
p.ClusterAddress,
|
||||
p.IPAddress,
|
||||
account,
|
||||
p.Status,
|
||||
p.LastSeen.Format("2006-01-02 15:04:05"),
|
||||
)
|
||||
}
|
||||
if err := w.Flush(); err != nil {
|
||||
return fmt.Errorf("write proxy list: %w", err)
|
||||
}
|
||||
|
||||
if dryRun {
|
||||
_, _ = fmt.Fprintf(out, "\nDry run: would force-mark %d of %d reverse proxy instance(s) as disconnected.\n", toDisconnect, len(proxies))
|
||||
return nil
|
||||
}
|
||||
|
||||
if !force {
|
||||
confirmed, err := confirmDisconnectAll(out, in)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !confirmed {
|
||||
_, _ = fmt.Fprintln(out, "Aborted. No reverse proxy instances were changed.")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
disconnected, err := s.DisconnectAllProxies(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("disconnect proxies: %w", err)
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(out, "\nForce-marked %d of %d reverse proxy instance(s) as disconnected.\n", disconnected, len(proxies))
|
||||
return nil
|
||||
}
|
||||
|
||||
func confirmDisconnectAll(out io.Writer, in io.Reader) (bool, error) {
|
||||
if in == nil {
|
||||
in = strings.NewReader("")
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintln(out, "\nWARNING: This command changes stored reverse proxy state for every non-disconnected instance.")
|
||||
_, _ = fmt.Fprintln(out, "Run it during a maintenance window; affected live proxies may stay hidden until "+
|
||||
"their next heartbeat or reconnect/re-register.")
|
||||
_, _ = fmt.Fprintf(out, "Type %q to continue: ", disconnectAllConfirmation)
|
||||
|
||||
scanner := bufio.NewScanner(in)
|
||||
if !scanner.Scan() {
|
||||
if err := scanner.Err(); err != nil {
|
||||
return false, fmt.Errorf("read confirmation: %w", err)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return strings.EqualFold(strings.TrimSpace(scanner.Text()), disconnectAllConfirmation), nil
|
||||
}
|
||||
180
management/cmd/proxy/proxy_test.go
Normal file
180
management/cmd/proxy/proxy_test.go
Normal file
@@ -0,0 +1,180 @@
|
||||
package proxycmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
rpproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
)
|
||||
|
||||
func newTestStore(t *testing.T) store.Store {
|
||||
t.Helper()
|
||||
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func seedProxies(t *testing.T, ctx context.Context, s store.Store) {
|
||||
t.Helper()
|
||||
|
||||
accountID := "account-1"
|
||||
alreadyDisconnectedAt := time.Now().Add(-time.Hour)
|
||||
seed := []*rpproxy.Proxy{
|
||||
{
|
||||
ID: "proxy-1",
|
||||
SessionID: "session-1",
|
||||
ClusterAddress: "cluster-a.example.com",
|
||||
IPAddress: "10.0.0.1",
|
||||
LastSeen: time.Now(),
|
||||
Status: rpproxy.StatusConnected,
|
||||
},
|
||||
{
|
||||
ID: "proxy-2",
|
||||
SessionID: "session-2",
|
||||
ClusterAddress: "cluster-b.example.com",
|
||||
IPAddress: "10.0.0.2",
|
||||
AccountID: &accountID,
|
||||
LastSeen: time.Now(),
|
||||
Status: rpproxy.StatusConnected,
|
||||
},
|
||||
{
|
||||
ID: "proxy-3",
|
||||
SessionID: "session-3",
|
||||
ClusterAddress: "cluster-a.example.com",
|
||||
IPAddress: "10.0.0.3",
|
||||
LastSeen: time.Now().Add(-time.Hour),
|
||||
Status: rpproxy.StatusDisconnected,
|
||||
DisconnectedAt: &alreadyDisconnectedAt,
|
||||
},
|
||||
}
|
||||
for _, p := range seed {
|
||||
require.NoError(t, s.SaveProxy(ctx, p))
|
||||
}
|
||||
}
|
||||
|
||||
func proxiesByID(t *testing.T, ctx context.Context, s store.Store) map[string]*rpproxy.Proxy {
|
||||
t.Helper()
|
||||
|
||||
proxies, err := s.GetAllProxies(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, proxies, 3)
|
||||
|
||||
byID := make(map[string]*rpproxy.Proxy, len(proxies))
|
||||
for _, p := range proxies {
|
||||
byID[p.ID] = p
|
||||
}
|
||||
return byID
|
||||
}
|
||||
|
||||
func TestRunDisconnectAllWithConfirmation(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s := newTestStore(t)
|
||||
seedProxies(t, ctx, s)
|
||||
|
||||
var out bytes.Buffer
|
||||
require.NoError(t, runDisconnectAll(ctx, s, &out, strings.NewReader(disconnectAllConfirmation+"\n"), false, false))
|
||||
|
||||
output := out.String()
|
||||
require.Contains(t, output, "proxy-1")
|
||||
require.Contains(t, output, "proxy-2")
|
||||
require.Contains(t, output, "proxy-3")
|
||||
require.Contains(t, output, "cluster-a.example.com")
|
||||
require.Contains(t, output, "account-1")
|
||||
require.Contains(t, output, "Type \"disconnect all proxies\" to continue")
|
||||
require.Contains(t, output, "Force-marked 2 of 3 reverse proxy instance(s) as disconnected.")
|
||||
|
||||
for _, p := range proxiesByID(t, ctx, s) {
|
||||
require.Equal(t, rpproxy.StatusDisconnected, p.Status, "proxy %s should be disconnected", p.ID)
|
||||
require.NotNil(t, p.DisconnectedAt, "proxy %s should have a disconnected timestamp", p.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunDisconnectAllForceSkipsConfirmation(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s := newTestStore(t)
|
||||
seedProxies(t, ctx, s)
|
||||
|
||||
var out bytes.Buffer
|
||||
require.NoError(t, runDisconnectAll(ctx, s, &out, strings.NewReader(""), false, true))
|
||||
|
||||
output := out.String()
|
||||
require.NotContains(t, output, "Type \"disconnect all proxies\" to continue")
|
||||
require.Contains(t, output, "Force-marked 2 of 3 reverse proxy instance(s) as disconnected.")
|
||||
}
|
||||
|
||||
func TestRunDisconnectAllAbortLeavesProxiesUnchanged(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s := newTestStore(t)
|
||||
seedProxies(t, ctx, s)
|
||||
|
||||
var out bytes.Buffer
|
||||
require.NoError(t, runDisconnectAll(ctx, s, &out, strings.NewReader("no\n"), false, false))
|
||||
|
||||
output := out.String()
|
||||
require.Contains(t, output, "Type \"disconnect all proxies\" to continue")
|
||||
require.Contains(t, output, "Aborted. No reverse proxy instances were changed.")
|
||||
|
||||
byID := proxiesByID(t, ctx, s)
|
||||
require.Equal(t, rpproxy.StatusConnected, byID["proxy-1"].Status)
|
||||
require.Equal(t, rpproxy.StatusConnected, byID["proxy-2"].Status)
|
||||
require.Equal(t, rpproxy.StatusDisconnected, byID["proxy-3"].Status)
|
||||
}
|
||||
|
||||
func TestRunDisconnectAllDryRunLeavesProxiesUnchanged(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s := newTestStore(t)
|
||||
seedProxies(t, ctx, s)
|
||||
|
||||
var out bytes.Buffer
|
||||
require.NoError(t, runDisconnectAll(ctx, s, &out, strings.NewReader(""), true, false))
|
||||
|
||||
output := out.String()
|
||||
require.Contains(t, output, "Dry run: would force-mark 2 of 3 reverse proxy instance(s) as disconnected.")
|
||||
require.NotContains(t, output, "Type \"disconnect all proxies\" to continue")
|
||||
|
||||
byID := proxiesByID(t, ctx, s)
|
||||
require.Equal(t, rpproxy.StatusConnected, byID["proxy-1"].Status)
|
||||
require.Equal(t, rpproxy.StatusConnected, byID["proxy-2"].Status)
|
||||
require.Equal(t, rpproxy.StatusDisconnected, byID["proxy-3"].Status)
|
||||
}
|
||||
|
||||
func TestNewCommandsDisconnectAllDryRun(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s := newTestStore(t)
|
||||
seedProxies(t, ctx, s)
|
||||
|
||||
opened := false
|
||||
cmd := NewCommands(func(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
|
||||
opened = true
|
||||
return fn(cmd.Context(), s)
|
||||
})
|
||||
|
||||
var out bytes.Buffer
|
||||
cmd.SetOut(&out)
|
||||
cmd.SetErr(&out)
|
||||
cmd.SetIn(strings.NewReader(""))
|
||||
cmd.SetArgs([]string{"disconnect-all", "--dry-run"})
|
||||
|
||||
require.NoError(t, cmd.ExecuteContext(ctx))
|
||||
require.True(t, opened)
|
||||
require.Contains(t, out.String(), "Dry run: would force-mark 2 of 3 reverse proxy instance(s) as disconnected.")
|
||||
}
|
||||
|
||||
func TestRunDisconnectAllEmpty(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
s := newTestStore(t)
|
||||
|
||||
var out bytes.Buffer
|
||||
require.NoError(t, runDisconnectAll(ctx, s, &out, strings.NewReader(""), false, false))
|
||||
require.Contains(t, out.String(), "No reverse proxy instances found.")
|
||||
}
|
||||
@@ -83,7 +83,8 @@ func init() {
|
||||
|
||||
rootCmd.AddCommand(migrationCmd)
|
||||
|
||||
tc := newTokenCommands()
|
||||
tc.PersistentFlags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location")
|
||||
rootCmd.AddCommand(tc)
|
||||
ac := newAdminCommands()
|
||||
ac.PersistentFlags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location")
|
||||
rootCmd.AddCommand(ac)
|
||||
rootCmd.AddCommand(newLegacyTokenCommand())
|
||||
}
|
||||
|
||||
@@ -1,55 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/formatter/hook"
|
||||
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var tokenDatadir string
|
||||
|
||||
// newTokenCommands creates the token command tree with management-specific store opener.
|
||||
func newTokenCommands() *cobra.Command {
|
||||
cmd := tokencmd.NewCommands(withTokenStore)
|
||||
cmd.PersistentFlags().StringVar(&tokenDatadir, "datadir", "", "Override the data directory from config (where store.db is located)")
|
||||
return cmd
|
||||
}
|
||||
|
||||
// withTokenStore initializes logging, loads config, opens the store, and calls fn.
|
||||
func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
|
||||
if err := util.InitLog("error", "console"); err != nil {
|
||||
return fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
|
||||
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
|
||||
|
||||
config, err := LoadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
|
||||
datadir := config.Datadir
|
||||
if tokenDatadir != "" {
|
||||
datadir = tokenDatadir
|
||||
}
|
||||
|
||||
s, err := store.NewStore(ctx, config.StoreConfig.Engine, datadir, nil, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create store: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := s.Close(ctx); err != nil {
|
||||
log.Debugf("close store: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return fn(ctx, s)
|
||||
}
|
||||
@@ -608,11 +608,11 @@ func (s *ProxyServiceServer) disconnectProxy(conn *proxyConnection) {
|
||||
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, conn.proxyID); err != nil {
|
||||
log.Warnf("Failed to unregister proxy %s from cluster: %v", conn.proxyID, err)
|
||||
}
|
||||
conn.cancel()
|
||||
if err := s.proxyManager.Disconnect(context.Background(), conn.proxyID, conn.sessionID); err != nil {
|
||||
log.Warnf("Failed to mark proxy %s as disconnected: %v", conn.proxyID, err)
|
||||
}
|
||||
|
||||
conn.cancel()
|
||||
log.Infof("Proxy %s session %s disconnected", conn.proxyID, conn.sessionID)
|
||||
}
|
||||
|
||||
|
||||
@@ -21,8 +21,11 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
staticClientDashboard = "netbird-dashboard"
|
||||
staticClientCLI = "netbird-cli"
|
||||
StaticClientDashboard = "netbird-dashboard"
|
||||
StaticClientCLI = "netbird-cli"
|
||||
DefaultTOTPAuthenticatorID = "default-totp"
|
||||
LocalConnectorID = dex.LocalConnectorID
|
||||
|
||||
defaultCLIRedirectURL1 = "http://localhost:53000/"
|
||||
defaultCLIRedirectURL2 = "http://localhost:54000/"
|
||||
defaultScopes = "openid profile email groups"
|
||||
@@ -185,14 +188,14 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
|
||||
EnablePasswordDB: true,
|
||||
StaticClients: []storage.Client{
|
||||
{
|
||||
ID: staticClientDashboard,
|
||||
ID: StaticClientDashboard,
|
||||
Name: "NetBird Dashboard",
|
||||
Public: true,
|
||||
RedirectURIs: redirectURIs,
|
||||
PostLogoutRedirectURIs: sanitizePostLogoutRedirectURIs(dashboardPostLogoutRedirectURIs),
|
||||
},
|
||||
{
|
||||
ID: staticClientCLI,
|
||||
ID: StaticClientCLI,
|
||||
Name: "NetBird CLI",
|
||||
Public: true,
|
||||
RedirectURIs: redirectURIs,
|
||||
@@ -254,13 +257,13 @@ func sanitizePostLogoutRedirectURIs(uris []string) []string {
|
||||
|
||||
func configureMFA(cfg *dex.YAMLConfig, sessionMaxLifetime, sessionIdleTimeout string, rememberMe bool, sessionCookieEncryptionKey string) error {
|
||||
cfg.MFA.Authenticators = []dex.MFAAuthenticator{{
|
||||
ID: "default-totp",
|
||||
ID: DefaultTOTPAuthenticatorID,
|
||||
// Has to be caps otherwise it will fail
|
||||
Type: "TOTP",
|
||||
Config: map[string]interface{}{
|
||||
"issuer": "NetBird",
|
||||
},
|
||||
ConnectorTypes: []string{"local"},
|
||||
ConnectorTypes: []string{LocalConnectorID},
|
||||
}}
|
||||
|
||||
if sessionMaxLifetime == "" {
|
||||
@@ -736,7 +739,7 @@ func (m *EmbeddedIdPManager) GetDefaultScopes() string {
|
||||
|
||||
// GetCLIClientID returns the client ID for CLI authentication.
|
||||
func (m *EmbeddedIdPManager) GetCLIClientID() string {
|
||||
return staticClientCLI
|
||||
return StaticClientCLI
|
||||
}
|
||||
|
||||
// GetCLIRedirectURLs returns the redirect URLs configured for the CLI client.
|
||||
@@ -775,7 +778,7 @@ func (m *EmbeddedIdPManager) GetLocalKeysLocation() string {
|
||||
|
||||
// GetClientIDs returns the OAuth2 client IDs configured for this provider.
|
||||
func (m *EmbeddedIdPManager) GetClientIDs() []string {
|
||||
return []string{staticClientDashboard, staticClientCLI}
|
||||
return []string{StaticClientDashboard, StaticClientCLI}
|
||||
}
|
||||
|
||||
// GetUserIDClaim returns the JWT claim name used for user identification.
|
||||
@@ -792,11 +795,11 @@ func (m *EmbeddedIdPManager) IsLocalAuthDisabled() bool {
|
||||
func (m *EmbeddedIdPManager) SetMFAEnabled(ctx context.Context, enabled bool) error {
|
||||
var mfaChain []string
|
||||
if enabled {
|
||||
mfaChain = []string{"default-totp"}
|
||||
mfaChain = []string{DefaultTOTPAuthenticatorID}
|
||||
}
|
||||
if err := m.provider.SetClientsMFAChain(ctx, []string{
|
||||
staticClientCLI,
|
||||
staticClientDashboard,
|
||||
StaticClientCLI,
|
||||
StaticClientDashboard,
|
||||
}, mfaChain); err != nil {
|
||||
return fmt.Errorf("failed to set MFA enabled=%v: %w", enabled, err)
|
||||
}
|
||||
|
||||
@@ -331,7 +331,7 @@ func TestEmbeddedIdPConfig_ToYAMLConfig_IncludesDeviceCallbackRedirectURI(t *tes
|
||||
|
||||
var cliRedirectURIs []string
|
||||
for _, client := range yamlConfig.StaticClients {
|
||||
if client.ID == staticClientCLI {
|
||||
if client.ID == StaticClientCLI {
|
||||
cliRedirectURIs = client.RedirectURIs
|
||||
break
|
||||
}
|
||||
|
||||
@@ -6088,6 +6088,37 @@ func (s *SqlStore) DisconnectProxy(ctx context.Context, proxyID, sessionID strin
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAllProxies returns all reverse proxy instance rows.
|
||||
func (s *SqlStore) GetAllProxies(ctx context.Context) ([]*proxy.Proxy, error) {
|
||||
var proxies []*proxy.Proxy
|
||||
result := s.db.Order("cluster_address, id").Find(&proxies)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get proxies: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get proxies")
|
||||
}
|
||||
return proxies, nil
|
||||
}
|
||||
|
||||
// DisconnectAllProxies force-marks every proxy that is not already disconnected
|
||||
// as disconnected, regardless of session ID. Unlike DisconnectProxy it is not
|
||||
// session-guarded: it is an administrative repair helper, not part of the
|
||||
// connection lifecycle. last_seen is left untouched so the stale-proxy reaper
|
||||
// keeps working off the real last heartbeat. Returns the number of proxies updated.
|
||||
func (s *SqlStore) DisconnectAllProxies(ctx context.Context) (int64, error) {
|
||||
result := s.db.
|
||||
Model(&proxy.Proxy{}).
|
||||
Where("status != ?", proxy.StatusDisconnected).
|
||||
Updates(map[string]any{
|
||||
"status": proxy.StatusDisconnected,
|
||||
"disconnected_at": time.Now(),
|
||||
})
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to disconnect all proxies: %v", result.Error)
|
||||
return 0, status.Errorf(status.Internal, "failed to disconnect all proxies")
|
||||
}
|
||||
return result.RowsAffected, nil
|
||||
}
|
||||
|
||||
// UpdateProxyHeartbeat updates the last_seen timestamp for the proxy's current session.
|
||||
func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error {
|
||||
now := time.Now()
|
||||
@@ -6095,7 +6126,11 @@ func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) err
|
||||
result := s.db.
|
||||
Model(&proxy.Proxy{}).
|
||||
Where("id = ? AND session_id = ?", p.ID, p.SessionID).
|
||||
Update("last_seen", now)
|
||||
Updates(map[string]any{
|
||||
"last_seen": now,
|
||||
"status": proxy.StatusConnected,
|
||||
"disconnected_at": nil,
|
||||
})
|
||||
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update proxy heartbeat: %v", result.Error)
|
||||
|
||||
156
management/server/store/sql_store_proxy_disconnect_test.go
Normal file
156
management/server/store/sql_store_proxy_disconnect_test.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
rpproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
)
|
||||
|
||||
// TestSqlStore_DisconnectAllProxies guards the administrative
|
||||
// force-disconnect helper:
|
||||
//
|
||||
// 1. Every proxy that is not already disconnected is marked
|
||||
// disconnected regardless of its session ID (unlike
|
||||
// DisconnectProxy, which is session-guarded).
|
||||
// 2. Rows that are already disconnected are left untouched, so their
|
||||
// original disconnected_at is preserved and the returned count
|
||||
// reflects only the rows that actually changed.
|
||||
// 3. last_seen is not modified — the stale-proxy reaper keeps working
|
||||
// off the real last heartbeat.
|
||||
func TestSqlStore_DisconnectAllProxies(t *testing.T) {
|
||||
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
|
||||
t.Skip("skip CI tests on darwin and windows")
|
||||
}
|
||||
|
||||
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
|
||||
ctx := context.Background()
|
||||
|
||||
lastSeenFresh := time.Now().Add(-30 * time.Second)
|
||||
lastSeenStale := time.Now().Add(-30 * time.Minute)
|
||||
oldDisconnectedAt := time.Now().Add(-time.Hour)
|
||||
|
||||
accountID := "acct-disconnect"
|
||||
proxies := []*rpproxy.Proxy{
|
||||
{
|
||||
ID: "p-connected-fresh",
|
||||
SessionID: "sess-1",
|
||||
ClusterAddress: "cluster-a.example.com",
|
||||
IPAddress: "10.0.0.1",
|
||||
LastSeen: lastSeenFresh,
|
||||
Status: rpproxy.StatusConnected,
|
||||
},
|
||||
{
|
||||
ID: "p-connected-stale",
|
||||
SessionID: "sess-2",
|
||||
ClusterAddress: "cluster-b.example.com",
|
||||
IPAddress: "10.0.0.2",
|
||||
AccountID: &accountID,
|
||||
LastSeen: lastSeenStale,
|
||||
Status: rpproxy.StatusConnected,
|
||||
},
|
||||
{
|
||||
ID: "p-already-disconnected",
|
||||
SessionID: "sess-3",
|
||||
ClusterAddress: "cluster-a.example.com",
|
||||
IPAddress: "10.0.0.3",
|
||||
LastSeen: lastSeenStale,
|
||||
Status: rpproxy.StatusDisconnected,
|
||||
DisconnectedAt: &oldDisconnectedAt,
|
||||
},
|
||||
}
|
||||
for _, p := range proxies {
|
||||
require.NoError(t, store.SaveProxy(ctx, p))
|
||||
}
|
||||
|
||||
all, err := store.GetAllProxies(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, all, 3)
|
||||
|
||||
disconnected, err := store.DisconnectAllProxies(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(2), disconnected)
|
||||
|
||||
all, err = store.GetAllProxies(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, all, 3)
|
||||
|
||||
byID := make(map[string]*rpproxy.Proxy, len(all))
|
||||
for _, p := range all {
|
||||
byID[p.ID] = p
|
||||
}
|
||||
|
||||
for id, p := range byID {
|
||||
assert.Equal(t, rpproxy.StatusDisconnected, p.Status, "proxy %s should be disconnected", id)
|
||||
require.NotNil(t, p.DisconnectedAt, "proxy %s should have disconnected_at set", id)
|
||||
}
|
||||
|
||||
// force-marked rows carry a fresh disconnected_at; the untouched row keeps its original one
|
||||
assert.WithinDuration(t, time.Now(), *byID["p-connected-fresh"].DisconnectedAt, 10*time.Second)
|
||||
assert.WithinDuration(t, time.Now(), *byID["p-connected-stale"].DisconnectedAt, 10*time.Second)
|
||||
assert.WithinDuration(t, oldDisconnectedAt, *byID["p-already-disconnected"].DisconnectedAt, time.Second)
|
||||
|
||||
// last_seen is preserved so the stale reaper schedule is unaffected
|
||||
assert.WithinDuration(t, lastSeenFresh, byID["p-connected-fresh"].LastSeen, time.Second)
|
||||
assert.WithinDuration(t, lastSeenStale, byID["p-connected-stale"].LastSeen, time.Second)
|
||||
|
||||
// idempotent: a second run has nothing left to update
|
||||
disconnected, err = store.DisconnectAllProxies(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(0), disconnected)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSqlStore_UpdateProxyHeartbeatRestoresDisconnectedCurrentSession(t *testing.T) {
|
||||
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
|
||||
t.Skip("skip CI tests on darwin and windows")
|
||||
}
|
||||
|
||||
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
|
||||
ctx := context.Background()
|
||||
proxy := &rpproxy.Proxy{
|
||||
ID: "p-heartbeat",
|
||||
SessionID: "sess-heartbeat",
|
||||
ClusterAddress: "cluster-heartbeat.example.com",
|
||||
IPAddress: "10.0.0.10",
|
||||
LastSeen: time.Now().Add(-30 * time.Second),
|
||||
Status: rpproxy.StatusConnected,
|
||||
}
|
||||
require.NoError(t, store.SaveProxy(ctx, proxy))
|
||||
|
||||
disconnected, err := store.DisconnectAllProxies(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), disconnected)
|
||||
|
||||
require.NoError(t, store.UpdateProxyHeartbeat(ctx, &rpproxy.Proxy{ID: proxy.ID, SessionID: proxy.SessionID}))
|
||||
|
||||
all, err := store.GetAllProxies(ctx)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, all, 1)
|
||||
assert.Equal(t, rpproxy.StatusConnected, all[0].Status)
|
||||
assert.Nil(t, all[0].DisconnectedAt)
|
||||
assert.WithinDuration(t, time.Now(), all[0].LastSeen, 10*time.Second)
|
||||
|
||||
addresses, err := store.GetActiveProxyClusterAddresses(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, addresses, proxy.ClusterAddress)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSqlStore_GetAllProxies_Empty(t *testing.T) {
|
||||
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
|
||||
t.Skip("skip CI tests on darwin and windows")
|
||||
}
|
||||
|
||||
runTestForAllEngines(t, "", func(t *testing.T, store Store) {
|
||||
all, err := store.GetAllProxies(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, all)
|
||||
})
|
||||
}
|
||||
@@ -323,6 +323,8 @@ type Store interface {
|
||||
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterSupportsPrivate(ctx context.Context, clusterAddr string) *bool
|
||||
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
|
||||
GetAllProxies(ctx context.Context) ([]*proxy.Proxy, error)
|
||||
DisconnectAllProxies(ctx context.Context) (int64, error)
|
||||
GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error)
|
||||
CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error)
|
||||
IsClusterAddressConflicting(ctx context.Context, clusterAddress, accountID string) (bool, error)
|
||||
|
||||
@@ -745,6 +745,21 @@ func (mr *MockStoreMockRecorder) DeleteZoneDNSRecords(ctx, accountID, zoneID int
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteZoneDNSRecords", reflect.TypeOf((*MockStore)(nil).DeleteZoneDNSRecords), ctx, accountID, zoneID)
|
||||
}
|
||||
|
||||
// DisconnectAllProxies mocks base method.
|
||||
func (m *MockStore) DisconnectAllProxies(ctx context.Context) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DisconnectAllProxies", ctx)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// DisconnectAllProxies indicates an expected call of DisconnectAllProxies.
|
||||
func (mr *MockStoreMockRecorder) DisconnectAllProxies(ctx interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectAllProxies", reflect.TypeOf((*MockStore)(nil).DisconnectAllProxies), ctx)
|
||||
}
|
||||
|
||||
// DisconnectProxy mocks base method.
|
||||
func (m *MockStore) DisconnectProxy(ctx context.Context, proxyID, sessionID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1389,6 +1404,21 @@ func (mr *MockStoreMockRecorder) GetAllEphemeralPeers(ctx, lockStrength interfac
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllEphemeralPeers", reflect.TypeOf((*MockStore)(nil).GetAllEphemeralPeers), ctx, lockStrength)
|
||||
}
|
||||
|
||||
// GetAllProxies mocks base method.
|
||||
func (m *MockStore) GetAllProxies(ctx context.Context) ([]*proxy.Proxy, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAllProxies", ctx)
|
||||
ret0, _ := ret[0].([]*proxy.Proxy)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAllProxies indicates an expected call of GetAllProxies.
|
||||
func (mr *MockStoreMockRecorder) GetAllProxies(ctx interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllProxies", reflect.TypeOf((*MockStore)(nil).GetAllProxies), ctx)
|
||||
}
|
||||
|
||||
// GetAllProxyAccessTokens mocks base method.
|
||||
func (m *MockStore) GetAllProxyAccessTokens(ctx context.Context, lockStrength LockingStrength) ([]*types2.ProxyAccessToken, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1521,6 +1551,21 @@ func (mr *MockStoreMockRecorder) GetDNSRecordByID(ctx, lockStrength, accountID,
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDNSRecordByID", reflect.TypeOf((*MockStore)(nil).GetDNSRecordByID), ctx, lockStrength, accountID, zoneID, recordID)
|
||||
}
|
||||
|
||||
// GetEmbeddedProxyPeerIDsByCluster mocks base method.
|
||||
func (m *MockStore) GetEmbeddedProxyPeerIDsByCluster(ctx context.Context, accountID string) (map[string][]string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetEmbeddedProxyPeerIDsByCluster", ctx, accountID)
|
||||
ret0, _ := ret[0].(map[string][]string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetEmbeddedProxyPeerIDsByCluster indicates an expected call of GetEmbeddedProxyPeerIDsByCluster.
|
||||
func (mr *MockStoreMockRecorder) GetEmbeddedProxyPeerIDsByCluster(ctx, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEmbeddedProxyPeerIDsByCluster", reflect.TypeOf((*MockStore)(nil).GetEmbeddedProxyPeerIDsByCluster), ctx, accountID)
|
||||
}
|
||||
|
||||
// GetExpiredEphemeralServices mocks base method.
|
||||
func (m *MockStore) GetExpiredEphemeralServices(ctx context.Context, ttl time.Duration, limit int) ([]*service.Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1566,6 +1611,21 @@ func (mr *MockStoreMockRecorder) GetGroupByName(ctx, lockStrength, accountID, gr
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockStore)(nil).GetGroupByName), ctx, lockStrength, accountID, groupName)
|
||||
}
|
||||
|
||||
// GetGroupIDsByPeerIDs mocks base method.
|
||||
func (m *MockStore) GetGroupIDsByPeerIDs(ctx context.Context, accountID string, peerIDs []string) ([]string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetGroupIDsByPeerIDs", ctx, accountID, peerIDs)
|
||||
ret0, _ := ret[0].([]string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetGroupIDsByPeerIDs indicates an expected call of GetGroupIDsByPeerIDs.
|
||||
func (mr *MockStoreMockRecorder) GetGroupIDsByPeerIDs(ctx, accountID, peerIDs interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupIDsByPeerIDs", reflect.TypeOf((*MockStore)(nil).GetGroupIDsByPeerIDs), ctx, accountID, peerIDs)
|
||||
}
|
||||
|
||||
// GetGroupsByIDs mocks base method.
|
||||
func (m *MockStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types2.Group, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1850,6 +1910,21 @@ func (mr *MockStoreMockRecorder) GetPeerIDByKey(ctx, lockStrength, key interface
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerIDByKey", reflect.TypeOf((*MockStore)(nil).GetPeerIDByKey), ctx, lockStrength, key)
|
||||
}
|
||||
|
||||
// GetPeerIDsByGroups mocks base method.
|
||||
func (m *MockStore) GetPeerIDsByGroups(ctx context.Context, accountID string, groupIDs []string) ([]string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetPeerIDsByGroups", ctx, accountID, groupIDs)
|
||||
ret0, _ := ret[0].([]string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetPeerIDsByGroups indicates an expected call of GetPeerIDsByGroups.
|
||||
func (mr *MockStoreMockRecorder) GetPeerIDsByGroups(ctx, accountID, groupIDs interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerIDsByGroups", reflect.TypeOf((*MockStore)(nil).GetPeerIDsByGroups), ctx, accountID, groupIDs)
|
||||
}
|
||||
|
||||
// GetPeerIdByLabel mocks base method.
|
||||
func (m *MockStore) GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID, hostname string) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1925,51 +2000,6 @@ func (mr *MockStoreMockRecorder) GetPeersByGroupIDs(ctx, accountID, groupIDs int
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeersByGroupIDs", reflect.TypeOf((*MockStore)(nil).GetPeersByGroupIDs), ctx, accountID, groupIDs)
|
||||
}
|
||||
|
||||
// GetPeerIDsByGroups mocks base method.
|
||||
func (m *MockStore) GetPeerIDsByGroups(ctx context.Context, accountID string, groupIDs []string) ([]string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetPeerIDsByGroups", ctx, accountID, groupIDs)
|
||||
ret0, _ := ret[0].([]string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetPeerIDsByGroups indicates an expected call of GetPeerIDsByGroups.
|
||||
func (mr *MockStoreMockRecorder) GetPeerIDsByGroups(ctx, accountID, groupIDs interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerIDsByGroups", reflect.TypeOf((*MockStore)(nil).GetPeerIDsByGroups), ctx, accountID, groupIDs)
|
||||
}
|
||||
|
||||
// GetGroupIDsByPeerIDs mocks base method.
|
||||
func (m *MockStore) GetGroupIDsByPeerIDs(ctx context.Context, accountID string, peerIDs []string) ([]string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetGroupIDsByPeerIDs", ctx, accountID, peerIDs)
|
||||
ret0, _ := ret[0].([]string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetGroupIDsByPeerIDs indicates an expected call of GetGroupIDsByPeerIDs.
|
||||
func (mr *MockStoreMockRecorder) GetGroupIDsByPeerIDs(ctx, accountID, peerIDs interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupIDsByPeerIDs", reflect.TypeOf((*MockStore)(nil).GetGroupIDsByPeerIDs), ctx, accountID, peerIDs)
|
||||
}
|
||||
|
||||
// GetEmbeddedProxyPeerIDsByCluster mocks base method.
|
||||
func (m *MockStore) GetEmbeddedProxyPeerIDsByCluster(ctx context.Context, accountID string) (map[string][]string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetEmbeddedProxyPeerIDsByCluster", ctx, accountID)
|
||||
ret0, _ := ret[0].(map[string][]string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetEmbeddedProxyPeerIDsByCluster indicates an expected call of GetEmbeddedProxyPeerIDsByCluster.
|
||||
func (mr *MockStoreMockRecorder) GetEmbeddedProxyPeerIDsByCluster(ctx, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEmbeddedProxyPeerIDsByCluster", reflect.TypeOf((*MockStore)(nil).GetEmbeddedProxyPeerIDsByCluster), ctx, accountID)
|
||||
}
|
||||
|
||||
// GetPeersByIDs mocks base method.
|
||||
func (m *MockStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*peer.Peer, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -1849,12 +1849,17 @@ func (am *DefaultAccountManager) DeleteUserInvite(ctx context.Context, accountID
|
||||
|
||||
const minPasswordLength = 8
|
||||
|
||||
// validatePassword checks password strength requirements:
|
||||
// validatePassword checks password strength requirements.
|
||||
func validatePassword(password string) error {
|
||||
return ValidatePassword(password)
|
||||
}
|
||||
|
||||
// ValidatePassword checks password strength requirements:
|
||||
// - Minimum 8 characters
|
||||
// - At least 1 digit
|
||||
// - At least 1 uppercase letter
|
||||
// - At least 1 special character
|
||||
func validatePassword(password string) error {
|
||||
func ValidatePassword(password string) error {
|
||||
if len(password) < minPasswordLength {
|
||||
return errors.New("password must be at least 8 characters long")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user