mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-23 09:09:56 +00:00
Compare commits
5 Commits
refactor/e
...
refactor/m
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5d6585bd43 | ||
|
|
be1c6f594b | ||
|
|
63be4170e0 | ||
|
|
27a1b2243d | ||
|
|
e3b3396d10 |
@@ -109,7 +109,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
iv, _ := integrations.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore)
|
||||
iv, _ := integrations.NewIntegratedValidator(ctx, nil, nil, peersmanager, settingsManagerMock, eventStore, cacheStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(ctx)
|
||||
require.NoError(t, err)
|
||||
@@ -124,18 +124,18 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config, nil)
|
||||
|
||||
accountManager, err := mgmt.BuildManager(ctx, config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil, nil)
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -66,8 +66,8 @@ import (
|
||||
"github.com/netbirdio/netbird/route"
|
||||
mgmt "github.com/netbirdio/netbird/shared/management/client"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
"github.com/netbirdio/netbird/shared/netiputil"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||
signalServer "github.com/netbirdio/netbird/signal/server"
|
||||
@@ -1641,7 +1641,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), nil, nil, peersManager, nil, eventStore, cacheStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
@@ -1662,17 +1662,17 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
||||
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config, nil)
|
||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -315,7 +315,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore)
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), nil, nil, peersManager, settingsManagerMock, eventStore, cacheStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
@@ -325,17 +325,17 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
|
||||
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
|
||||
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
|
||||
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config, nil)
|
||||
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -332,7 +332,7 @@ func setupServerHooks(servers *serverInstances, cfg *CombinedConfig) {
|
||||
log.Infof("Signal server registered on port %s", cfg.Server.ListenAddress)
|
||||
}
|
||||
|
||||
s.SetHandlerFunc(createCombinedHandler(grpcSrv, s.APIHandler(), servers.relaySrv, servers.metricsServer.Meter, cfg))
|
||||
s.SetHandlerFunc(createCombinedHandler(grpcSrv, s.APIHandler(), s.IDPHandler(), servers.relaySrv, servers.metricsServer.Meter, cfg))
|
||||
if servers.relaySrv != nil {
|
||||
log.Infof("Relay WebSocket handler added (path: /relay)")
|
||||
}
|
||||
@@ -521,7 +521,7 @@ func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*
|
||||
}
|
||||
|
||||
// createCombinedHandler creates an HTTP handler that multiplexes Management, Signal (via wsproxy), and Relay WebSocket traffic
|
||||
func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler {
|
||||
func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, idpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler {
|
||||
wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter))
|
||||
|
||||
var relayAcceptFn func(conn listener.Conn)
|
||||
@@ -556,6 +556,10 @@ func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, re
|
||||
http.Error(w, "Relay service not enabled", http.StatusNotFound)
|
||||
}
|
||||
|
||||
// Embedded IdP (Dex)
|
||||
case idpHandler != nil && strings.HasPrefix(r.URL.Path, "/oauth2"):
|
||||
idpHandler.ServeHTTP(w, r)
|
||||
|
||||
// Management HTTP API (default)
|
||||
default:
|
||||
httpHandler.ServeHTTP(w, r)
|
||||
|
||||
@@ -55,6 +55,8 @@ type Controller struct {
|
||||
proxyController port_forwarding.Controller
|
||||
|
||||
integratedPeerValidator integrated_validator.IntegratedValidator
|
||||
|
||||
configExtender grpc.ConfigExtender
|
||||
}
|
||||
|
||||
type bufferUpdate struct {
|
||||
@@ -65,7 +67,7 @@ type bufferUpdate struct {
|
||||
|
||||
var _ network_map.Controller = (*Controller)(nil)
|
||||
|
||||
func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller, ephemeralPeersManager ephemeral.Manager, config *config.Config) *Controller {
|
||||
func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller, ephemeralPeersManager ephemeral.Manager, config *config.Config, configExtender grpc.ConfigExtender) *Controller {
|
||||
nMetrics, err := newMetrics(metrics.UpdateChannelMetrics())
|
||||
if err != nil {
|
||||
log.Fatal(fmt.Errorf("error creating metrics: %w", err))
|
||||
@@ -84,6 +86,8 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
|
||||
|
||||
proxyController: proxyController,
|
||||
EphemeralPeersManager: ephemeralPeersManager,
|
||||
|
||||
configExtender: configExtender,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -203,7 +207,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
|
||||
peerGroups := account.GetPeerGroups(p.ID)
|
||||
start = time.Now()
|
||||
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
|
||||
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort, c.configExtender)
|
||||
c.metrics.CountToSyncResponseDuration(time.Since(start))
|
||||
|
||||
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{
|
||||
@@ -329,7 +333,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
||||
peerGroups := account.GetPeerGroups(peerId)
|
||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
|
||||
|
||||
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
|
||||
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort, c.configExtender)
|
||||
c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{
|
||||
Update: update,
|
||||
MessageType: network_map.MessageTypeNetworkMap,
|
||||
|
||||
@@ -2,7 +2,6 @@ package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -12,76 +11,44 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
)
|
||||
|
||||
const (
|
||||
// cleanupWindow is the small grace period added on top of the
|
||||
// staleness horizon before a sweep fires. It absorbs minor clock
|
||||
// skew between the management server and the database and avoids
|
||||
// firing a sweep right at the boundary where last_seen could still
|
||||
// be one tick under the threshold.
|
||||
// cleanupWindow is the time window to wait after nearest peer deadline to start the cleanup procedure.
|
||||
cleanupWindow = 1 * time.Minute
|
||||
|
||||
// initialLoadMinDelay and initialLoadMaxDelay bracket the random
|
||||
// delay applied before the post-restart catch-up query runs. Spread
|
||||
// across replicas this prevents a thundering herd of catch-up
|
||||
// queries hitting the database simultaneously after a deploy.
|
||||
initialLoadMinDelay = 8 * time.Minute
|
||||
initialLoadMaxDelay = 10 * time.Minute
|
||||
)
|
||||
|
||||
var (
|
||||
timeNow = time.Now
|
||||
)
|
||||
|
||||
// accountEntry is the per-account state held by the cleanup tracker.
|
||||
// We don't track which peers are pending — the sweep query gets the
|
||||
// authoritative list straight from the database every time. We only
|
||||
// need to know the latest disconnect we've observed for this account
|
||||
// (so we can decide when it's safe to drop the entry) and the timer
|
||||
// that will fire the next sweep.
|
||||
type accountEntry struct {
|
||||
lastDisconnectedAt time.Time
|
||||
timer *time.Timer
|
||||
type ephemeralPeer struct {
|
||||
id string
|
||||
accountID string
|
||||
deadline time.Time
|
||||
next *ephemeralPeer
|
||||
}
|
||||
|
||||
// EphemeralManager tracks accounts that may have ephemeral peers in
|
||||
// need of cleanup and runs a per-account sweep at the appropriate
|
||||
// time. State is in-memory and account-scoped: a sweep deletes any
|
||||
// ephemeral peer in the account that has been disconnected for at
|
||||
// least lifeTime, then either drops the account from the tracker
|
||||
// (when no recent disconnects have arrived) or re-arms the timer.
|
||||
// todo: consider to remove peer from ephemeral list when the peer has been deleted via API. If we do not do it
|
||||
// in worst case we will get invalid error message in this manager.
|
||||
|
||||
// EphemeralManager keep a list of ephemeral peers. After EphemeralLifeTime inactivity the peer will be deleted
|
||||
// automatically. Inactivity means the peer disconnected from the Management server.
|
||||
type EphemeralManager struct {
|
||||
store store.Store
|
||||
peersManager peers.Manager
|
||||
|
||||
accountsLock sync.Mutex
|
||||
accounts map[string]*accountEntry
|
||||
|
||||
// initialLoadTimer is the one-shot timer used to defer the
|
||||
// post-restart catch-up query; held so Stop() can cancel it.
|
||||
initialLoadTimer *time.Timer
|
||||
// stopped is flipped by Stop() so any timer that fires after
|
||||
// teardown becomes a no-op instead of touching a half-dismantled
|
||||
// store.
|
||||
stopped bool
|
||||
headPeer *ephemeralPeer
|
||||
tailPeer *ephemeralPeer
|
||||
peersLock sync.Mutex
|
||||
timer *time.Timer
|
||||
|
||||
lifeTime time.Duration
|
||||
cleanupWindow time.Duration
|
||||
|
||||
// initialLoadDelay returns the wall-clock delay to wait before
|
||||
// running the post-restart catch-up query. Pluggable so tests can
|
||||
// fire the load immediately.
|
||||
initialLoadDelay func() time.Duration
|
||||
|
||||
// bgCtx is the long-lived context captured at LoadInitialPeers
|
||||
// time. Timer-driven sweeps use it because they fire long after
|
||||
// the original gRPC handler ctx that produced an OnPeerDisconnected
|
||||
// call has been cancelled.
|
||||
bgCtx context.Context
|
||||
|
||||
// metrics is nil-safe; methods on telemetry.EphemeralPeersMetrics
|
||||
// no-op when the receiver is nil so deployments without an app
|
||||
// metrics provider work unchanged.
|
||||
@@ -91,265 +58,228 @@ type EphemeralManager struct {
|
||||
// NewEphemeralManager instantiate new EphemeralManager
|
||||
func NewEphemeralManager(store store.Store, peersManager peers.Manager) *EphemeralManager {
|
||||
return &EphemeralManager{
|
||||
store: store,
|
||||
peersManager: peersManager,
|
||||
accounts: make(map[string]*accountEntry),
|
||||
lifeTime: ephemeral.EphemeralLifeTime,
|
||||
cleanupWindow: cleanupWindow,
|
||||
initialLoadDelay: defaultInitialLoadDelay,
|
||||
store: store,
|
||||
peersManager: peersManager,
|
||||
|
||||
lifeTime: ephemeral.EphemeralLifeTime,
|
||||
cleanupWindow: cleanupWindow,
|
||||
}
|
||||
}
|
||||
|
||||
// SetMetrics attaches a metrics collector. Pass nil to detach.
|
||||
// SetMetrics attaches a metrics collector. Safe to call once before
|
||||
// LoadInitialPeers; later attachment is fine but earlier loads won't be
|
||||
// reflected in the gauge. Pass nil to detach.
|
||||
func (e *EphemeralManager) SetMetrics(m *telemetry.EphemeralPeersMetrics) {
|
||||
e.accountsLock.Lock()
|
||||
e.peersLock.Lock()
|
||||
e.metrics = m
|
||||
e.accountsLock.Unlock()
|
||||
e.peersLock.Unlock()
|
||||
}
|
||||
|
||||
// LoadInitialPeers schedules the post-restart catch-up query for a
|
||||
// random moment 8-10 minutes from now. Returns immediately. The
|
||||
// catch-up populates the per-account tracker from the database so any
|
||||
// peers that disconnected before the restart still get cleaned up.
|
||||
//
|
||||
// The random delay is critical: without it, every management replica
|
||||
// hitting the same Postgres instance after a deploy would issue the
|
||||
// catch-up query simultaneously.
|
||||
// LoadInitialPeers load from the database the ephemeral type of peers and schedule a cleanup procedure to the head
|
||||
// of the linked list (to the most deprecated peer). At the end of cleanup it schedules the next cleanup to the new
|
||||
// head.
|
||||
func (e *EphemeralManager) LoadInitialPeers(ctx context.Context) {
|
||||
e.accountsLock.Lock()
|
||||
defer e.accountsLock.Unlock()
|
||||
if e.stopped {
|
||||
e.peersLock.Lock()
|
||||
defer e.peersLock.Unlock()
|
||||
|
||||
e.loadEphemeralPeers(ctx)
|
||||
if e.headPeer != nil {
|
||||
e.timer = time.AfterFunc(e.lifeTime, func() {
|
||||
e.cleanup(ctx)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Stop timer
|
||||
func (e *EphemeralManager) Stop() {
|
||||
e.peersLock.Lock()
|
||||
defer e.peersLock.Unlock()
|
||||
|
||||
if e.timer != nil {
|
||||
e.timer.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
// OnPeerConnected remove the peer from the linked list of ephemeral peers. Because it has been called when the peer
|
||||
// is active the manager will not delete it while it is active.
|
||||
func (e *EphemeralManager) OnPeerConnected(ctx context.Context, peer *nbpeer.Peer) {
|
||||
if !peer.Ephemeral {
|
||||
return
|
||||
}
|
||||
|
||||
e.bgCtx = ctx
|
||||
log.WithContext(ctx).Tracef("remove peer from ephemeral list: %s", peer.ID)
|
||||
|
||||
delay := e.initialLoadDelay()
|
||||
log.WithContext(ctx).Infof("ephemeral peer initial load scheduled in %s", delay)
|
||||
e.initialLoadTimer = time.AfterFunc(delay, func() {
|
||||
e.loadInitialAccounts(e.bgCtx)
|
||||
})
|
||||
}
|
||||
e.peersLock.Lock()
|
||||
defer e.peersLock.Unlock()
|
||||
|
||||
// Stop cancels the deferred initial load and any per-account timers.
|
||||
func (e *EphemeralManager) Stop() {
|
||||
e.accountsLock.Lock()
|
||||
defer e.accountsLock.Unlock()
|
||||
|
||||
e.stopped = true
|
||||
if e.initialLoadTimer != nil {
|
||||
e.initialLoadTimer.Stop()
|
||||
e.initialLoadTimer = nil
|
||||
if e.removePeer(peer.ID) {
|
||||
e.metrics.DecPending(1)
|
||||
}
|
||||
for _, entry := range e.accounts {
|
||||
if entry.timer != nil {
|
||||
entry.timer.Stop()
|
||||
}
|
||||
|
||||
// stop the unnecessary timer
|
||||
if e.headPeer == nil && e.timer != nil {
|
||||
e.timer.Stop()
|
||||
e.timer = nil
|
||||
}
|
||||
e.accounts = make(map[string]*accountEntry)
|
||||
}
|
||||
|
||||
// OnPeerConnected is a no-op in the account-scoped design. The sweep
|
||||
// query filters out connected peers at the database level, so we don't
|
||||
// need an explicit "remove from list" signal when a peer reconnects.
|
||||
// Kept on the interface to preserve the existing call sites.
|
||||
func (e *EphemeralManager) OnPeerConnected(_ context.Context, _ *nbpeer.Peer) {
|
||||
}
|
||||
|
||||
// OnPeerDisconnected registers a disconnect for the peer's account and
|
||||
// arms a sweep if one isn't already scheduled. Non-ephemeral peers are
|
||||
// ignored.
|
||||
// OnPeerDisconnected add the peer to the linked list of ephemeral peers. Because of the peer
|
||||
// is inactive it will be deleted after the EphemeralLifeTime period.
|
||||
func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.Peer) {
|
||||
if !peer.Ephemeral {
|
||||
return
|
||||
}
|
||||
|
||||
now := timeNow()
|
||||
log.WithContext(ctx).Tracef("add peer to ephemeral list: %s", peer.ID)
|
||||
|
||||
e.accountsLock.Lock()
|
||||
defer e.accountsLock.Unlock()
|
||||
if e.stopped {
|
||||
e.peersLock.Lock()
|
||||
defer e.peersLock.Unlock()
|
||||
|
||||
if e.isPeerOnList(peer.ID) {
|
||||
return
|
||||
}
|
||||
|
||||
entry, existed := e.accounts[peer.AccountID]
|
||||
if !existed {
|
||||
entry = &accountEntry{}
|
||||
e.accounts[peer.AccountID] = entry
|
||||
e.metrics.IncPending()
|
||||
}
|
||||
entry.lastDisconnectedAt = now
|
||||
|
||||
if entry.timer == nil {
|
||||
delay := e.lifeTime + e.cleanupWindow
|
||||
log.WithContext(ctx).Tracef("ephemeral: scheduling sweep for account %s in %s", peer.AccountID, delay)
|
||||
accountID := peer.AccountID
|
||||
entry.timer = time.AfterFunc(delay, func() {
|
||||
e.sweep(e.bgCtxOrFallback(ctx), accountID)
|
||||
e.addPeer(peer.AccountID, peer.ID, e.newDeadLine())
|
||||
e.metrics.IncPending()
|
||||
if e.timer == nil {
|
||||
delay := e.headPeer.deadline.Sub(timeNow()) + e.cleanupWindow
|
||||
if delay < 0 {
|
||||
delay = 0
|
||||
}
|
||||
e.timer = time.AfterFunc(delay, func() {
|
||||
e.cleanup(ctx)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// bgCtxOrFallback returns the long-lived background context captured at
|
||||
// LoadInitialPeers time, falling back to the supplied ctx when the
|
||||
// manager hasn't been started through LoadInitialPeers (e.g. in tests
|
||||
// that drive the manager directly). Must be called with the lock held
|
||||
// or before the timer is armed.
|
||||
func (e *EphemeralManager) bgCtxOrFallback(ctx context.Context) context.Context {
|
||||
if e.bgCtx != nil {
|
||||
return e.bgCtx
|
||||
func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) {
|
||||
peers, err := e.store.GetAllEphemeralPeers(ctx, store.LockingStrengthNone)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to load ephemeral peers: %s", err)
|
||||
return
|
||||
}
|
||||
return ctx
|
||||
|
||||
t := e.newDeadLine()
|
||||
for _, p := range peers {
|
||||
e.addPeer(p.AccountID, p.ID, t)
|
||||
}
|
||||
e.metrics.AddPending(int64(len(peers)))
|
||||
|
||||
log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", len(peers))
|
||||
}
|
||||
|
||||
// loadInitialAccounts runs the post-restart catch-up query and seeds
|
||||
// the tracker with one entry per account that has at least one
|
||||
// disconnected ephemeral peer.
|
||||
func (e *EphemeralManager) loadInitialAccounts(ctx context.Context) {
|
||||
accounts, err := e.store.GetEphemeralAccountsLastDisconnect(ctx)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to load ephemeral accounts on startup: %v", err)
|
||||
return
|
||||
}
|
||||
func (e *EphemeralManager) cleanup(ctx context.Context) {
|
||||
log.Tracef("on ephemeral cleanup")
|
||||
deletePeers := make(map[string]*ephemeralPeer)
|
||||
|
||||
e.peersLock.Lock()
|
||||
now := timeNow()
|
||||
added := 0
|
||||
for p := e.headPeer; p != nil; p = p.next {
|
||||
if now.Before(p.deadline) {
|
||||
break
|
||||
}
|
||||
|
||||
e.accountsLock.Lock()
|
||||
defer e.accountsLock.Unlock()
|
||||
if e.stopped {
|
||||
return
|
||||
deletePeers[p.id] = p
|
||||
e.headPeer = p.next
|
||||
if p.next == nil {
|
||||
e.tailPeer = nil
|
||||
}
|
||||
}
|
||||
|
||||
for accountID, lastDisc := range accounts {
|
||||
// If we already learned about this account via an
|
||||
// OnPeerDisconnected that arrived during the random delay
|
||||
// window, prefer the live timestamp.
|
||||
if _, alreadyTracked := e.accounts[accountID]; alreadyTracked {
|
||||
if e.headPeer != nil {
|
||||
delay := e.headPeer.deadline.Sub(timeNow()) + e.cleanupWindow
|
||||
if delay < 0 {
|
||||
delay = 0
|
||||
}
|
||||
e.timer = time.AfterFunc(delay, func() {
|
||||
e.cleanup(ctx)
|
||||
})
|
||||
} else {
|
||||
e.timer = nil
|
||||
}
|
||||
|
||||
e.peersLock.Unlock()
|
||||
|
||||
// Drop the gauge by the number of entries we just took off the list,
|
||||
// regardless of whether the subsequent DeletePeers call succeeds. The
|
||||
// list invariant is what the gauge tracks; failed delete batches are
|
||||
// counted separately via CountCleanupError so we can still see them.
|
||||
if len(deletePeers) > 0 {
|
||||
e.metrics.CountCleanupRun()
|
||||
e.metrics.DecPending(int64(len(deletePeers)))
|
||||
}
|
||||
|
||||
peerIDsPerAccount := make(map[string][]string)
|
||||
for id, p := range deletePeers {
|
||||
peerIDsPerAccount[p.accountID] = append(peerIDsPerAccount[p.accountID], id)
|
||||
}
|
||||
|
||||
for accountID, peerIDs := range peerIDsPerAccount {
|
||||
log.WithContext(ctx).Tracef("cleanup: deleting %d ephemeral peers for account %s", len(peerIDs), accountID)
|
||||
err := e.peersManager.DeletePeers(ctx, accountID, peerIDs, activity.SystemInitiator, true)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete ephemeral peers: %s", err)
|
||||
e.metrics.CountCleanupError()
|
||||
continue
|
||||
}
|
||||
e.metrics.CountPeersCleaned(int64(len(peerIDs)))
|
||||
}
|
||||
}
|
||||
|
||||
entry := &accountEntry{lastDisconnectedAt: lastDisc}
|
||||
horizon := lastDisc.Add(e.lifeTime)
|
||||
func (e *EphemeralManager) addPeer(accountID string, peerID string, deadline time.Time) {
|
||||
ep := &ephemeralPeer{
|
||||
id: peerID,
|
||||
accountID: accountID,
|
||||
deadline: deadline,
|
||||
}
|
||||
|
||||
var delay time.Duration
|
||||
if horizon.After(now) {
|
||||
delay = horizon.Sub(now) + e.cleanupWindow
|
||||
} else {
|
||||
// Already past the staleness window — sweep right away
|
||||
// (one cleanupWindow later, to keep startup load smooth
|
||||
// when many accounts qualify at once).
|
||||
delay = e.cleanupWindow
|
||||
if e.headPeer == nil {
|
||||
e.headPeer = ep
|
||||
}
|
||||
if e.tailPeer != nil {
|
||||
e.tailPeer.next = ep
|
||||
}
|
||||
e.tailPeer = ep
|
||||
}
|
||||
|
||||
// removePeer drops the entry from the linked list. Returns true if a
|
||||
// matching entry was found and removed so callers can keep the pending
|
||||
// metric gauge in sync.
|
||||
func (e *EphemeralManager) removePeer(id string) bool {
|
||||
if e.headPeer == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if e.headPeer.id == id {
|
||||
e.headPeer = e.headPeer.next
|
||||
if e.tailPeer.id == id {
|
||||
e.tailPeer = nil
|
||||
}
|
||||
idForClosure := accountID
|
||||
entry.timer = time.AfterFunc(delay, func() {
|
||||
e.sweep(ctx, idForClosure)
|
||||
})
|
||||
e.accounts[accountID] = entry
|
||||
added++
|
||||
return true
|
||||
}
|
||||
|
||||
e.metrics.AddPending(int64(added))
|
||||
log.WithContext(ctx).Debugf("ephemeral: loaded %d account(s) for cleanup tracking", added)
|
||||
}
|
||||
|
||||
// sweep runs the cleanup pass for a single account. It queries the
|
||||
// database for disconnected ephemeral peers that have crossed the
|
||||
// staleness window, deletes them via peers.Manager, and then decides
|
||||
// whether to drop the account from the tracker or re-arm the timer.
|
||||
func (e *EphemeralManager) sweep(ctx context.Context, accountID string) {
|
||||
now := timeNow()
|
||||
|
||||
e.accountsLock.Lock()
|
||||
entry, ok := e.accounts[accountID]
|
||||
if !ok || e.stopped {
|
||||
e.accountsLock.Unlock()
|
||||
return
|
||||
}
|
||||
lastDisc := entry.lastDisconnectedAt
|
||||
entry.timer = nil
|
||||
e.accountsLock.Unlock()
|
||||
|
||||
threshold := now.Add(-e.lifeTime)
|
||||
stalePeerIDs, err := e.store.GetStaleEphemeralPeerIDsForAccount(ctx, accountID, threshold)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("ephemeral: failed to query stale peers for account %s: %v", accountID, err)
|
||||
e.metrics.CountCleanupError()
|
||||
e.rearm(ctx, accountID, e.cleanupWindow)
|
||||
return
|
||||
}
|
||||
|
||||
if len(stalePeerIDs) > 0 {
|
||||
log.WithContext(ctx).Tracef("ephemeral: deleting %d peer(s) for account %s", len(stalePeerIDs), accountID)
|
||||
if err := e.peersManager.DeletePeers(ctx, accountID, stalePeerIDs, activity.SystemInitiator, true); err != nil {
|
||||
log.WithContext(ctx).Errorf("ephemeral: failed to delete peers for account %s: %v", accountID, err)
|
||||
e.metrics.CountCleanupError()
|
||||
e.rearm(ctx, accountID, e.cleanupWindow)
|
||||
return
|
||||
for p := e.headPeer; p.next != nil; p = p.next {
|
||||
if p.next.id == id {
|
||||
// if we remove the last element from the chain then set the last-1 as tail
|
||||
if e.tailPeer.id == id {
|
||||
e.tailPeer = p
|
||||
}
|
||||
p.next = p.next.next
|
||||
return true
|
||||
}
|
||||
e.metrics.CountCleanupRun()
|
||||
e.metrics.CountPeersCleaned(int64(len(stalePeerIDs)))
|
||||
}
|
||||
|
||||
e.accountsLock.Lock()
|
||||
defer e.accountsLock.Unlock()
|
||||
if e.stopped {
|
||||
return
|
||||
}
|
||||
entry, ok = e.accounts[accountID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Drop rule: if every disconnect we've observed has now crossed
|
||||
// the staleness window, the sweep we just ran saw everything that
|
||||
// could possibly need cleaning. Dropping is safe — a future
|
||||
// disconnect will recreate the entry. The check uses the latest
|
||||
// lastDisc, which may have advanced (concurrently with the sweep
|
||||
// itself) due to a new OnPeerDisconnected, in which case we
|
||||
// correctly re-arm.
|
||||
horizon := entry.lastDisconnectedAt.Add(e.lifeTime)
|
||||
if !horizon.After(now) {
|
||||
delete(e.accounts, accountID)
|
||||
e.metrics.DecPending(1)
|
||||
log.WithContext(ctx).Tracef("ephemeral: dropping account %s (lastDisc=%s, horizon=%s, now=%s)",
|
||||
accountID, lastDisc, horizon, now)
|
||||
return
|
||||
}
|
||||
|
||||
delay := horizon.Sub(now) + e.cleanupWindow
|
||||
idForClosure := accountID
|
||||
entry.timer = time.AfterFunc(delay, func() {
|
||||
e.sweep(ctx, idForClosure)
|
||||
})
|
||||
return false
|
||||
}
|
||||
|
||||
// rearm reschedules a sweep `delay` from now. Used after a recoverable
|
||||
// error in the sweep path so the account doesn't get stuck.
|
||||
func (e *EphemeralManager) rearm(ctx context.Context, accountID string, delay time.Duration) {
|
||||
e.accountsLock.Lock()
|
||||
defer e.accountsLock.Unlock()
|
||||
if e.stopped {
|
||||
return
|
||||
func (e *EphemeralManager) isPeerOnList(id string) bool {
|
||||
for p := e.headPeer; p != nil; p = p.next {
|
||||
if p.id == id {
|
||||
return true
|
||||
}
|
||||
}
|
||||
entry, ok := e.accounts[accountID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
idForClosure := accountID
|
||||
entry.timer = time.AfterFunc(delay, func() {
|
||||
e.sweep(ctx, idForClosure)
|
||||
})
|
||||
return false
|
||||
}
|
||||
|
||||
// defaultInitialLoadDelay returns a random duration in
|
||||
// [initialLoadMinDelay, initialLoadMaxDelay). Process-wide
|
||||
// math/rand is acceptable here — the delay is purely a smoothing
|
||||
// jitter, not a security primitive.
|
||||
func defaultInitialLoadDelay() time.Duration {
|
||||
span := int64(initialLoadMaxDelay - initialLoadMinDelay)
|
||||
if span <= 0 {
|
||||
return initialLoadMinDelay
|
||||
}
|
||||
return initialLoadMinDelay + time.Duration(rand.Int63n(span))
|
||||
func (e *EphemeralManager) newDeadLine() time.Time {
|
||||
return timeNow().Add(e.lifeTime)
|
||||
}
|
||||
|
||||
@@ -2,544 +2,299 @@ package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
|
||||
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
// MockStore is a thin in-memory stand-in that implements only the two
|
||||
// methods the EphemeralManager uses. It honors the account / ephemeral
|
||||
// / connected / lastSeen attributes of each peer so the cleanup logic
|
||||
// can be exercised end-to-end without bringing up sqlite or Postgres.
|
||||
type MockStore struct {
|
||||
store.Store
|
||||
mu sync.Mutex
|
||||
account *types.Account
|
||||
}
|
||||
|
||||
func (s *MockStore) GetStaleEphemeralPeerIDsForAccount(_ context.Context, accountID string, olderThan time.Time) ([]string, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.account == nil || s.account.Id != accountID {
|
||||
return nil, nil
|
||||
}
|
||||
var ids []string
|
||||
for _, p := range s.account.Peers {
|
||||
if !p.Ephemeral {
|
||||
continue
|
||||
}
|
||||
if p.Status == nil || p.Status.Connected {
|
||||
continue
|
||||
}
|
||||
if p.Status.LastSeen.Before(olderThan) {
|
||||
ids = append(ids, p.ID)
|
||||
func (s *MockStore) GetAllEphemeralPeers(_ context.Context, _ store.LockingStrength) ([]*nbpeer.Peer, error) {
|
||||
var peers []*nbpeer.Peer
|
||||
for _, v := range s.account.Peers {
|
||||
if v.Ephemeral {
|
||||
peers = append(peers, v)
|
||||
}
|
||||
}
|
||||
return ids, nil
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
func (s *MockStore) GetEphemeralAccountsLastDisconnect(_ context.Context) (map[string]time.Time, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
out := map[string]time.Time{}
|
||||
if s.account == nil {
|
||||
return out, nil
|
||||
}
|
||||
var latest time.Time
|
||||
hasAny := false
|
||||
for _, p := range s.account.Peers {
|
||||
if !p.Ephemeral || p.Status == nil || p.Status.Connected {
|
||||
continue
|
||||
}
|
||||
if !hasAny || p.Status.LastSeen.After(latest) {
|
||||
latest = p.Status.LastSeen
|
||||
hasAny = true
|
||||
}
|
||||
}
|
||||
if hasAny {
|
||||
out[s.account.Id] = latest
|
||||
}
|
||||
return out, nil
|
||||
type MockAccountManager struct {
|
||||
mu sync.Mutex
|
||||
nbAccount.Manager
|
||||
store *MockStore
|
||||
deletePeerCalls int
|
||||
bufferUpdateCalls map[string]int
|
||||
wg *sync.WaitGroup
|
||||
}
|
||||
|
||||
// withFakeClock pins timeNow to a settable value for the duration of t.
|
||||
// Returns a getter and a setter so subtests can advance virtual time.
|
||||
func withFakeClock(t *testing.T, start time.Time) (get func() time.Time, set func(time.Time)) {
|
||||
t.Helper()
|
||||
var mu sync.Mutex
|
||||
now := start
|
||||
func (a *MockAccountManager) DeletePeer(_ context.Context, accountID, peerID, userID string) error {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
a.deletePeerCalls++
|
||||
delete(a.store.account.Peers, peerID)
|
||||
if a.wg != nil {
|
||||
a.wg.Done()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *MockAccountManager) GetDeletePeerCalls() int {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
return a.deletePeerCalls
|
||||
}
|
||||
|
||||
func (a *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
if a.bufferUpdateCalls == nil {
|
||||
a.bufferUpdateCalls = make(map[string]int)
|
||||
}
|
||||
a.bufferUpdateCalls[accountID]++
|
||||
}
|
||||
|
||||
func (a *MockAccountManager) GetBufferUpdateCalls(accountID string) int {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
if a.bufferUpdateCalls == nil {
|
||||
return 0
|
||||
}
|
||||
return a.bufferUpdateCalls[accountID]
|
||||
}
|
||||
|
||||
func (a *MockAccountManager) GetStore() store.Store {
|
||||
return a.store
|
||||
}
|
||||
|
||||
func TestNewManager(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
timeNow = time.Now
|
||||
})
|
||||
startTime := time.Now()
|
||||
timeNow = func() time.Time {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return now
|
||||
return startTime
|
||||
}
|
||||
t.Cleanup(func() { timeNow = time.Now })
|
||||
|
||||
return func() time.Time {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return now
|
||||
}, func(v time.Time) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
now = v
|
||||
}
|
||||
}
|
||||
|
||||
// newManagerForTest builds a manager with short timers and no random
|
||||
// initial-load delay so tests run instantly.
|
||||
func newManagerForTest(t *testing.T, st store.Store, peersMgr peers.Manager) *EphemeralManager {
|
||||
t.Helper()
|
||||
mgr := NewEphemeralManager(st, peersMgr)
|
||||
mgr.lifeTime = 100 * time.Millisecond
|
||||
mgr.cleanupWindow = 10 * time.Millisecond
|
||||
mgr.initialLoadDelay = func() time.Duration { return 0 }
|
||||
t.Cleanup(mgr.Stop)
|
||||
return mgr
|
||||
}
|
||||
|
||||
// TestOnPeerDisconnected_RegistersAndSweeps drives the OnPeerDisconnected
|
||||
// path with a fake clock: a single ephemeral peer disconnects, we
|
||||
// advance past the staleness window, and the sweep deletes it.
|
||||
func TestOnPeerDisconnected_RegistersAndSweeps(t *testing.T) {
|
||||
mockStore := &MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)}
|
||||
|
||||
getNow, setNow := withFakeClock(t, time.Now())
|
||||
|
||||
store := &MockStore{}
|
||||
ctrl := gomock.NewController(t)
|
||||
peersMgr := peers.NewMockManager(ctrl)
|
||||
peersManager := peers.NewMockManager(ctrl)
|
||||
|
||||
var deletedMu sync.Mutex
|
||||
var deleted []string
|
||||
var deleteCalls atomic.Int32
|
||||
peersMgr.EXPECT().
|
||||
DeletePeers(gomock.Any(), "acc-1", gomock.Any(), gomock.Any(), true).
|
||||
DoAndReturn(func(_ context.Context, accountID string, peerIDs []string, _ string, _ bool) error {
|
||||
deleteCalls.Add(1)
|
||||
mockStore.mu.Lock()
|
||||
for _, id := range peerIDs {
|
||||
delete(mockStore.account.Peers, id)
|
||||
numberOfPeers := 5
|
||||
numberOfEphemeralPeers := 3
|
||||
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
|
||||
|
||||
// Expect DeletePeers to be called for ephemeral peers
|
||||
peersManager.EXPECT().
|
||||
DeletePeers(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true).
|
||||
DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
|
||||
for _, peerID := range peerIDs {
|
||||
delete(store.account.Peers, peerID)
|
||||
}
|
||||
mockStore.mu.Unlock()
|
||||
deletedMu.Lock()
|
||||
deleted = append(deleted, peerIDs...)
|
||||
deletedMu.Unlock()
|
||||
return nil
|
||||
}).AnyTimes()
|
||||
}).
|
||||
AnyTimes()
|
||||
|
||||
mgr := newManagerForTest(t, mockStore, peersMgr)
|
||||
mgr := NewEphemeralManager(store, peersManager)
|
||||
mgr.loadEphemeralPeers(context.Background())
|
||||
startTime = startTime.Add(ephemeral.EphemeralLifeTime + 1)
|
||||
mgr.cleanup(context.Background())
|
||||
|
||||
// One ephemeral peer that disconnected "now".
|
||||
now := getNow()
|
||||
p := &nbpeer.Peer{
|
||||
ID: "p1",
|
||||
AccountID: "acc-1",
|
||||
Ephemeral: true,
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: now},
|
||||
if len(store.account.Peers) != numberOfPeers {
|
||||
t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", numberOfPeers, len(store.account.Peers))
|
||||
}
|
||||
mockStore.account.Peers[p.ID] = p
|
||||
mgr.OnPeerDisconnected(context.Background(), p)
|
||||
|
||||
// Advance past lifeTime + cleanupWindow so the timer-driven sweep fires.
|
||||
setNow(now.Add(mgr.lifeTime + 5*mgr.cleanupWindow))
|
||||
require.Eventually(t, func() bool { return deleteCalls.Load() >= 1 }, 2*time.Second, 5*time.Millisecond,
|
||||
"sweep should fire and delete the stale peer")
|
||||
|
||||
deletedMu.Lock()
|
||||
deletedCopy := append([]string(nil), deleted...)
|
||||
deletedMu.Unlock()
|
||||
require.Equal(t, []string{"p1"}, deletedCopy, "only the one ephemeral peer should be deleted")
|
||||
}
|
||||
|
||||
// TestOnPeerDisconnected_NonEphemeralIgnored: a non-ephemeral disconnect
|
||||
// must not register the account or arm any timer.
|
||||
func TestOnPeerDisconnected_NonEphemeralIgnored(t *testing.T) {
|
||||
mockStore := &MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)}
|
||||
withFakeClock(t, time.Now())
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
peersMgr := peers.NewMockManager(ctrl)
|
||||
// No DeletePeers expectation — must not be called.
|
||||
|
||||
mgr := newManagerForTest(t, mockStore, peersMgr)
|
||||
mgr.OnPeerDisconnected(context.Background(), &nbpeer.Peer{
|
||||
ID: "p1",
|
||||
AccountID: "acc-1",
|
||||
Ephemeral: false,
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: timeNow()},
|
||||
func TestNewManagerPeerConnected(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
timeNow = time.Now
|
||||
})
|
||||
startTime := time.Now()
|
||||
timeNow = func() time.Time {
|
||||
return startTime
|
||||
}
|
||||
|
||||
mgr.accountsLock.Lock()
|
||||
require.Empty(t, mgr.accounts, "non-ephemeral disconnect must not register an account")
|
||||
mgr.accountsLock.Unlock()
|
||||
store := &MockStore{}
|
||||
ctrl := gomock.NewController(t)
|
||||
peersManager := peers.NewMockManager(ctrl)
|
||||
|
||||
numberOfPeers := 5
|
||||
numberOfEphemeralPeers := 3
|
||||
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
|
||||
|
||||
// Expect DeletePeers to be called for ephemeral peers (except the connected one)
|
||||
peersManager.EXPECT().
|
||||
DeletePeers(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true).
|
||||
DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
|
||||
for _, peerID := range peerIDs {
|
||||
delete(store.account.Peers, peerID)
|
||||
}
|
||||
return nil
|
||||
}).
|
||||
AnyTimes()
|
||||
|
||||
mgr := NewEphemeralManager(store, peersManager)
|
||||
mgr.loadEphemeralPeers(context.Background())
|
||||
mgr.OnPeerConnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
|
||||
|
||||
startTime = startTime.Add(ephemeral.EphemeralLifeTime + 1)
|
||||
mgr.cleanup(context.Background())
|
||||
|
||||
expected := numberOfPeers + 1
|
||||
if len(store.account.Peers) != expected {
|
||||
t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", expected, len(store.account.Peers))
|
||||
}
|
||||
}
|
||||
|
||||
// TestSweep_DropsAccountWhenIdle: after a sweep cleans the stale peers,
|
||||
// if no more disconnects have arrived the account must be dropped from
|
||||
// the in-memory tracker.
|
||||
func TestSweep_DropsAccountWhenIdle(t *testing.T) {
|
||||
mockStore := &MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)}
|
||||
getNow, setNow := withFakeClock(t, time.Now())
|
||||
func TestNewManagerPeerDisconnected(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
timeNow = time.Now
|
||||
})
|
||||
startTime := time.Now()
|
||||
timeNow = func() time.Time {
|
||||
return startTime
|
||||
}
|
||||
|
||||
store := &MockStore{}
|
||||
ctrl := gomock.NewController(t)
|
||||
peersMgr := peers.NewMockManager(ctrl)
|
||||
peersMgr.EXPECT().
|
||||
DeletePeers(gomock.Any(), "acc-1", gomock.Any(), gomock.Any(), true).
|
||||
DoAndReturn(func(_ context.Context, _ string, peerIDs []string, _ string, _ bool) error {
|
||||
mockStore.mu.Lock()
|
||||
for _, id := range peerIDs {
|
||||
delete(mockStore.account.Peers, id)
|
||||
peersManager := peers.NewMockManager(ctrl)
|
||||
|
||||
numberOfPeers := 5
|
||||
numberOfEphemeralPeers := 3
|
||||
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
|
||||
|
||||
// Expect DeletePeers to be called for the one disconnected peer
|
||||
peersManager.EXPECT().
|
||||
DeletePeers(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true).
|
||||
DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
|
||||
for _, peerID := range peerIDs {
|
||||
delete(store.account.Peers, peerID)
|
||||
}
|
||||
mockStore.mu.Unlock()
|
||||
return nil
|
||||
}).AnyTimes()
|
||||
}).
|
||||
AnyTimes()
|
||||
|
||||
mgr := newManagerForTest(t, mockStore, peersMgr)
|
||||
mgr := NewEphemeralManager(store, peersManager)
|
||||
mgr.loadEphemeralPeers(context.Background())
|
||||
for _, v := range store.account.Peers {
|
||||
mgr.OnPeerConnected(context.Background(), v)
|
||||
|
||||
now := getNow()
|
||||
p := &nbpeer.Peer{ID: "p1", AccountID: "acc-1", Ephemeral: true,
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: now}}
|
||||
mockStore.account.Peers[p.ID] = p
|
||||
mgr.OnPeerDisconnected(context.Background(), p)
|
||||
}
|
||||
mgr.OnPeerDisconnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
|
||||
|
||||
setNow(now.Add(mgr.lifeTime + 5*mgr.cleanupWindow))
|
||||
startTime = startTime.Add(ephemeral.EphemeralLifeTime + 1)
|
||||
mgr.cleanup(context.Background())
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
mgr.accountsLock.Lock()
|
||||
defer mgr.accountsLock.Unlock()
|
||||
return len(mgr.accounts) == 0
|
||||
}, 2*time.Second, 5*time.Millisecond, "account should be dropped after sweep with no new disconnects")
|
||||
expected := numberOfPeers + numberOfEphemeralPeers - 1
|
||||
if len(store.account.Peers) != expected {
|
||||
t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", expected, len(store.account.Peers))
|
||||
}
|
||||
}
|
||||
|
||||
// TestSweep_ReArmsWhenNewDisconnectArrived: simulate the race where a
|
||||
// fresh disconnect arrives just before the sweep fires. The sweep must
|
||||
// observe the updated lastDisc and re-arm rather than drop.
|
||||
func TestSweep_ReArmsWhenNewDisconnectArrived(t *testing.T) {
|
||||
mockStore := &MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)}
|
||||
getNow, setNow := withFakeClock(t, time.Now())
|
||||
func TestCleanupSchedulingBehaviorIsBatched(t *testing.T) {
|
||||
const (
|
||||
ephemeralPeers = 10
|
||||
testLifeTime = 1 * time.Second
|
||||
testCleanupWindow = 100 * time.Millisecond
|
||||
)
|
||||
|
||||
t.Cleanup(func() {
|
||||
timeNow = time.Now
|
||||
})
|
||||
startTime := time.Now()
|
||||
timeNow = func() time.Time {
|
||||
return startTime
|
||||
}
|
||||
|
||||
mockStore := &MockStore{}
|
||||
account := newAccountWithId(context.Background(), "account", "", "", false)
|
||||
mockStore.account = account
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(ephemeralPeers)
|
||||
mockAM := &MockAccountManager{
|
||||
store: mockStore,
|
||||
wg: wg,
|
||||
}
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
peersMgr := peers.NewMockManager(ctrl)
|
||||
peersMgr.EXPECT().
|
||||
DeletePeers(gomock.Any(), "acc-1", gomock.Any(), gomock.Any(), true).
|
||||
DoAndReturn(func(_ context.Context, _ string, peerIDs []string, _ string, _ bool) error {
|
||||
mockStore.mu.Lock()
|
||||
for _, id := range peerIDs {
|
||||
delete(mockStore.account.Peers, id)
|
||||
peersManager := peers.NewMockManager(ctrl)
|
||||
|
||||
// Set up expectation that DeletePeers will be called once with all peer IDs
|
||||
peersManager.EXPECT().
|
||||
DeletePeers(gomock.Any(), account.Id, gomock.Any(), gomock.Any(), true).
|
||||
DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
|
||||
// Simulate the actual deletion behavior
|
||||
for _, peerID := range peerIDs {
|
||||
err := mockAM.DeletePeer(ctx, accountID, peerID, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
mockStore.mu.Unlock()
|
||||
mockAM.BufferUpdateAccountPeers(ctx, accountID, types.UpdateReason{})
|
||||
return nil
|
||||
}).AnyTimes()
|
||||
}).
|
||||
Times(1)
|
||||
|
||||
mgr := newManagerForTest(t, mockStore, peersMgr)
|
||||
mgr := NewEphemeralManager(mockStore, peersManager)
|
||||
mgr.lifeTime = testLifeTime
|
||||
mgr.cleanupWindow = testCleanupWindow
|
||||
|
||||
now := getNow()
|
||||
p1 := &nbpeer.Peer{ID: "p1", AccountID: "acc-1", Ephemeral: true,
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: now}}
|
||||
mockStore.account.Peers[p1.ID] = p1
|
||||
mgr.OnPeerDisconnected(context.Background(), p1)
|
||||
|
||||
// Advance most of the way toward the first sweep, then introduce
|
||||
// a fresh disconnect that resets lastDisc.
|
||||
setNow(now.Add(mgr.lifeTime - 10*time.Millisecond))
|
||||
p2 := &nbpeer.Peer{ID: "p2", AccountID: "acc-1", Ephemeral: true,
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: getNow()}}
|
||||
mockStore.account.Peers[p2.ID] = p2
|
||||
mgr.OnPeerDisconnected(context.Background(), p2)
|
||||
|
||||
// Push past p1's staleness so the first sweep runs and cleans p1
|
||||
// but observes p2 already on the account entry. It must re-arm.
|
||||
setNow(now.Add(mgr.lifeTime + 5*mgr.cleanupWindow))
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
mockStore.mu.Lock()
|
||||
defer mockStore.mu.Unlock()
|
||||
_, gone := mockStore.account.Peers["p1"]
|
||||
return !gone
|
||||
}, 2*time.Second, 5*time.Millisecond, "p1 should be cleaned at the first sweep")
|
||||
|
||||
// The account should still be tracked because p2 is younger than lifeTime
|
||||
// from the sweep's vantage point at this moment.
|
||||
mgr.accountsLock.Lock()
|
||||
_, stillTracked := mgr.accounts["acc-1"]
|
||||
mgr.accountsLock.Unlock()
|
||||
require.True(t, stillTracked, "account should remain tracked because p2's disconnect kept it active")
|
||||
|
||||
// Push past p2's staleness; second sweep cleans p2 and drops the account.
|
||||
setNow(getNow().Add(mgr.lifeTime + 5*mgr.cleanupWindow))
|
||||
require.Eventually(t, func() bool {
|
||||
mgr.accountsLock.Lock()
|
||||
defer mgr.accountsLock.Unlock()
|
||||
return len(mgr.accounts) == 0
|
||||
}, 2*time.Second, 5*time.Millisecond, "account should drop after the final sweep")
|
||||
}
|
||||
|
||||
// TestSweep_BatchesPeersPerAccount: many ephemeral peers disconnect on
|
||||
// the same account; a single sweep must delete them all in one
|
||||
// DeletePeers call.
|
||||
func TestSweep_BatchesPeersPerAccount(t *testing.T) {
|
||||
const ephemeralPeers = 8
|
||||
|
||||
mockStore := &MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)}
|
||||
getNow, setNow := withFakeClock(t, time.Now())
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
peersMgr := peers.NewMockManager(ctrl)
|
||||
|
||||
deleteBatches := make(chan []string, 4)
|
||||
peersMgr.EXPECT().
|
||||
DeletePeers(gomock.Any(), "acc-1", gomock.Any(), gomock.Any(), true).
|
||||
DoAndReturn(func(_ context.Context, _ string, peerIDs []string, _ string, _ bool) error {
|
||||
cp := append([]string(nil), peerIDs...)
|
||||
mockStore.mu.Lock()
|
||||
for _, id := range peerIDs {
|
||||
delete(mockStore.account.Peers, id)
|
||||
}
|
||||
mockStore.mu.Unlock()
|
||||
deleteBatches <- cp
|
||||
return nil
|
||||
}).Times(1)
|
||||
|
||||
mgr := newManagerForTest(t, mockStore, peersMgr)
|
||||
|
||||
now := getNow()
|
||||
for i := 0; i < ephemeralPeers; i++ {
|
||||
id := fmt.Sprintf("p-%d", i)
|
||||
// Stagger by a fraction of cleanupWindow so they all fall on
|
||||
// the same sweep tick.
|
||||
when := now.Add(time.Duration(i) * time.Millisecond)
|
||||
p := &nbpeer.Peer{ID: id, AccountID: "acc-1", Ephemeral: true,
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: when}}
|
||||
mockStore.account.Peers[id] = p
|
||||
// Add peers and disconnect them at slightly different times (within cleanup window)
|
||||
for i := range ephemeralPeers {
|
||||
p := &nbpeer.Peer{ID: fmt.Sprintf("peer-%d", i), AccountID: account.Id, Ephemeral: true}
|
||||
mockStore.account.Peers[p.ID] = p
|
||||
mgr.OnPeerDisconnected(context.Background(), p)
|
||||
startTime = startTime.Add(testCleanupWindow / (ephemeralPeers * 2))
|
||||
}
|
||||
|
||||
setNow(now.Add(mgr.lifeTime + 5*mgr.cleanupWindow))
|
||||
// Advance time past the lifetime to trigger cleanup
|
||||
startTime = startTime.Add(testLifeTime + testCleanupWindow)
|
||||
|
||||
select {
|
||||
case batch := <-deleteBatches:
|
||||
require.Len(t, batch, ephemeralPeers, "all peers should be deleted in a single batch")
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("expected one batched DeletePeers call")
|
||||
}
|
||||
// Wait for all deletions to complete
|
||||
wg.Wait()
|
||||
|
||||
assert.Len(t, mockStore.account.Peers, 0, "all ephemeral peers should be cleaned up after the lifetime")
|
||||
assert.Equal(t, 1, mockAM.GetBufferUpdateCalls(account.Id), "buffer update should be called once")
|
||||
assert.Equal(t, ephemeralPeers, mockAM.GetDeletePeerCalls(), "should have deleted all peers")
|
||||
}
|
||||
|
||||
// TestLoadInitialAccounts_SeedsFromStore exercises the post-restart
|
||||
// catch-up path: pre-populate the store, point the manager at it, and
|
||||
// confirm both already-stale and not-yet-stale peers get cleaned at
|
||||
// their proper times.
|
||||
func TestLoadInitialAccounts_SeedsFromStore(t *testing.T) {
|
||||
mockStore := &MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)}
|
||||
getNow, setNow := withFakeClock(t, time.Now())
|
||||
func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) {
|
||||
store.account = newAccountWithId(context.Background(), "my account", "", "", false)
|
||||
|
||||
now := getNow()
|
||||
// p-stale: already past the staleness window when load runs.
|
||||
mockStore.account.Peers["p-stale"] = &nbpeer.Peer{
|
||||
ID: "p-stale", AccountID: "acc-1", Ephemeral: true,
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: now.Add(-time.Hour)},
|
||||
}
|
||||
// p-fresh: disconnected but not yet stale.
|
||||
mockStore.account.Peers["p-fresh"] = &nbpeer.Peer{
|
||||
ID: "p-fresh", AccountID: "acc-1", Ephemeral: true,
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: now},
|
||||
for i := 0; i < numberOfPeers; i++ {
|
||||
peerId := fmt.Sprintf("peer_%d", i)
|
||||
p := &nbpeer.Peer{
|
||||
ID: peerId,
|
||||
Ephemeral: false,
|
||||
}
|
||||
store.account.Peers[p.ID] = p
|
||||
}
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
peersMgr := peers.NewMockManager(ctrl)
|
||||
peersMgr.EXPECT().
|
||||
DeletePeers(gomock.Any(), "acc-1", gomock.Any(), gomock.Any(), true).
|
||||
DoAndReturn(func(_ context.Context, _ string, peerIDs []string, _ string, _ bool) error {
|
||||
mockStore.mu.Lock()
|
||||
for _, id := range peerIDs {
|
||||
delete(mockStore.account.Peers, id)
|
||||
}
|
||||
mockStore.mu.Unlock()
|
||||
return nil
|
||||
}).AnyTimes()
|
||||
|
||||
mgr := newManagerForTest(t, mockStore, peersMgr)
|
||||
// Drive loadInitialAccounts directly with the fake-clock-aware now.
|
||||
mgr.loadInitialAccounts(context.Background())
|
||||
|
||||
// First sweep should fire shortly (cleanupWindow) for the stale peer.
|
||||
setNow(now.Add(5 * mgr.cleanupWindow))
|
||||
require.Eventually(t, func() bool {
|
||||
mockStore.mu.Lock()
|
||||
defer mockStore.mu.Unlock()
|
||||
_, gone := mockStore.account.Peers["p-stale"]
|
||||
return !gone
|
||||
}, 2*time.Second, 5*time.Millisecond, "p-stale should be deleted on the first sweep")
|
||||
|
||||
// p-fresh is not yet stale; advance past its window.
|
||||
setNow(now.Add(mgr.lifeTime + 5*mgr.cleanupWindow))
|
||||
require.Eventually(t, func() bool {
|
||||
mockStore.mu.Lock()
|
||||
defer mockStore.mu.Unlock()
|
||||
_, gone := mockStore.account.Peers["p-fresh"]
|
||||
return !gone
|
||||
}, 2*time.Second, 5*time.Millisecond, "p-fresh should be deleted once it crosses the staleness window")
|
||||
}
|
||||
|
||||
// TestStop_CancelsPendingWork verifies that Stop() cancels both the
|
||||
// deferred initial load and per-account sweep timers and that
|
||||
// subsequent OnPeerDisconnected calls are ignored.
|
||||
func TestStop_CancelsPendingWork(t *testing.T) {
|
||||
mockStore := &MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)}
|
||||
withFakeClock(t, time.Now())
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
peersMgr := peers.NewMockManager(ctrl)
|
||||
// DeletePeers must NOT be called after Stop.
|
||||
|
||||
mgr := NewEphemeralManager(mockStore, peersMgr)
|
||||
mgr.lifeTime = 100 * time.Millisecond
|
||||
mgr.cleanupWindow = 10 * time.Millisecond
|
||||
// Use a long delay so the initial-load timer is still pending.
|
||||
mgr.initialLoadDelay = func() time.Duration { return time.Hour }
|
||||
|
||||
mgr.LoadInitialPeers(context.Background())
|
||||
mgr.OnPeerDisconnected(context.Background(), &nbpeer.Peer{
|
||||
ID: "p1", AccountID: "acc-1", Ephemeral: true,
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: timeNow()},
|
||||
})
|
||||
|
||||
mgr.accountsLock.Lock()
|
||||
require.NotNil(t, mgr.initialLoadTimer, "initial-load timer should be armed")
|
||||
require.Len(t, mgr.accounts, 1, "account should be tracked after disconnect")
|
||||
mgr.accountsLock.Unlock()
|
||||
|
||||
mgr.Stop()
|
||||
|
||||
mgr.accountsLock.Lock()
|
||||
require.Empty(t, mgr.accounts, "Stop should clear tracked accounts")
|
||||
require.True(t, mgr.stopped, "stopped flag must be set")
|
||||
mgr.accountsLock.Unlock()
|
||||
|
||||
// Post-stop disconnect must be ignored.
|
||||
mgr.OnPeerDisconnected(context.Background(), &nbpeer.Peer{
|
||||
ID: "p2", AccountID: "acc-1", Ephemeral: true,
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: timeNow()},
|
||||
})
|
||||
mgr.accountsLock.Lock()
|
||||
require.Empty(t, mgr.accounts, "disconnects after Stop must be ignored")
|
||||
mgr.accountsLock.Unlock()
|
||||
}
|
||||
|
||||
// TestOnPeerConnected_IsNoop: the OnPeerConnected hook is preserved on
|
||||
// the interface but does nothing in the per-account model — the sweep
|
||||
// query filters connected peers at the DB level.
|
||||
func TestOnPeerConnected_IsNoop(t *testing.T) {
|
||||
mockStore := &MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)}
|
||||
withFakeClock(t, time.Now())
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
peersMgr := peers.NewMockManager(ctrl)
|
||||
|
||||
mgr := newManagerForTest(t, mockStore, peersMgr)
|
||||
mgr.OnPeerDisconnected(context.Background(), &nbpeer.Peer{
|
||||
ID: "p1", AccountID: "acc-1", Ephemeral: true,
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: timeNow()},
|
||||
})
|
||||
mgr.accountsLock.Lock()
|
||||
require.Len(t, mgr.accounts, 1, "disconnect should track the account")
|
||||
mgr.accountsLock.Unlock()
|
||||
|
||||
mgr.OnPeerConnected(context.Background(), &nbpeer.Peer{ID: "p1", AccountID: "acc-1", Ephemeral: true})
|
||||
mgr.accountsLock.Lock()
|
||||
require.Len(t, mgr.accounts, 1, "OnPeerConnected must be a no-op")
|
||||
mgr.accountsLock.Unlock()
|
||||
}
|
||||
|
||||
// TestSweep_StoreErrorReArms: if the stale-peer query fails, the
|
||||
// account must remain tracked and a follow-up sweep gets scheduled.
|
||||
func TestSweep_StoreErrorReArms(t *testing.T) {
|
||||
mockStore := &erroringStore{
|
||||
MockStore: MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)},
|
||||
}
|
||||
getNow, setNow := withFakeClock(t, time.Now())
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
peersMgr := peers.NewMockManager(ctrl)
|
||||
|
||||
mgr := newManagerForTest(t, mockStore, peersMgr)
|
||||
|
||||
p := &nbpeer.Peer{ID: "p1", AccountID: "acc-1", Ephemeral: true,
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: getNow()}}
|
||||
mockStore.account.Peers[p.ID] = p
|
||||
mgr.OnPeerDisconnected(context.Background(), p)
|
||||
|
||||
mockStore.fail.Store(true)
|
||||
setNow(getNow().Add(mgr.lifeTime + 5*mgr.cleanupWindow))
|
||||
|
||||
// Wait until the failing sweep has run at least once.
|
||||
require.Eventually(t, func() bool { return mockStore.failedCalls.Load() >= 1 },
|
||||
2*time.Second, 5*time.Millisecond, "expected at least one failing sweep")
|
||||
|
||||
mgr.accountsLock.Lock()
|
||||
_, stillTracked := mgr.accounts["acc-1"]
|
||||
mgr.accountsLock.Unlock()
|
||||
require.True(t, stillTracked, "account must remain tracked after a sweep error")
|
||||
|
||||
// Recover and ensure the rearmed sweep cleans up.
|
||||
peersMgr.EXPECT().
|
||||
DeletePeers(gomock.Any(), "acc-1", gomock.Any(), gomock.Any(), true).
|
||||
DoAndReturn(func(_ context.Context, _ string, peerIDs []string, _ string, _ bool) error {
|
||||
mockStore.mu.Lock()
|
||||
for _, id := range peerIDs {
|
||||
delete(mockStore.account.Peers, id)
|
||||
}
|
||||
mockStore.mu.Unlock()
|
||||
return nil
|
||||
}).AnyTimes()
|
||||
mockStore.fail.Store(false)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
mockStore.mu.Lock()
|
||||
defer mockStore.mu.Unlock()
|
||||
_, gone := mockStore.account.Peers["p1"]
|
||||
return !gone
|
||||
}, 2*time.Second, 5*time.Millisecond, "rearmed sweep should clean up after the store recovers")
|
||||
}
|
||||
|
||||
// erroringStore is a MockStore that can be flipped into a failing mode
|
||||
// to exercise the sweep's error-rearm path.
|
||||
type erroringStore struct {
|
||||
MockStore
|
||||
fail atomic.Bool
|
||||
failedCalls atomic.Int32
|
||||
}
|
||||
|
||||
func (s *erroringStore) GetStaleEphemeralPeerIDsForAccount(ctx context.Context, accountID string, olderThan time.Time) ([]string, error) {
|
||||
if s.fail.Load() {
|
||||
s.failedCalls.Add(1)
|
||||
return nil, errors.New("synthetic store error")
|
||||
}
|
||||
return s.MockStore.GetStaleEphemeralPeerIDsForAccount(ctx, accountID, olderThan)
|
||||
}
|
||||
|
||||
// TestDefaultInitialLoadDelay confirms the jitter falls inside the
|
||||
// documented [8m, 10m) range — sanity check for the production timer.
|
||||
func TestDefaultInitialLoadDelay(t *testing.T) {
|
||||
for i := 0; i < 1000; i++ {
|
||||
d := defaultInitialLoadDelay()
|
||||
assert.GreaterOrEqual(t, d, initialLoadMinDelay)
|
||||
assert.Less(t, d, initialLoadMaxDelay)
|
||||
for i := 0; i < numberOfEphemeralPeers; i++ {
|
||||
peerId := fmt.Sprintf("ephemeral_peer_%d", i)
|
||||
p := &nbpeer.Peer{
|
||||
ID: peerId,
|
||||
Ephemeral: true,
|
||||
}
|
||||
store.account.Peers[p.ID] = p
|
||||
}
|
||||
}
|
||||
|
||||
@@ -596,7 +351,3 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string, dis
|
||||
}
|
||||
return acc
|
||||
}
|
||||
|
||||
// silence the import "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
|
||||
// (still needed indirectly for ephemeral.EphemeralLifeTime in production paths).
|
||||
var _ = ephemeral.EphemeralLifeTime
|
||||
|
||||
@@ -10,8 +10,10 @@ import (
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
|
||||
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
|
||||
"github.com/rs/cors"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
@@ -19,7 +21,6 @@ import (
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
cachestore "github.com/eko/gocache/lib/v4/store"
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/formatter/hook"
|
||||
@@ -27,16 +28,20 @@ import (
|
||||
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
activitystore "github.com/netbirdio/netbird/management/server/activity/store"
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||
nbhttp "github.com/netbirdio/netbird/management/server/http"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/util/crypt"
|
||||
)
|
||||
|
||||
const apiPrefix = "/api"
|
||||
|
||||
var (
|
||||
kaep = keepalive.EnforcementPolicy{
|
||||
MinTime: 15 * time.Second,
|
||||
@@ -94,12 +99,17 @@ func (s *BaseServer) Store() store.Store {
|
||||
|
||||
func (s *BaseServer) EventStore() activity.Store {
|
||||
return Create(s, func() activity.Store {
|
||||
integrationMetrics, err := integrations.InitIntegrationMetrics(context.Background(), s.Metrics())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to initialize integration metrics: %v", err)
|
||||
var err error
|
||||
key := s.Config.DataStoreEncryptionKey
|
||||
if key == "" {
|
||||
log.Debugf("generate new activity store encryption key")
|
||||
key, err = crypt.GenerateKey()
|
||||
if err != nil {
|
||||
log.Fatalf("failed to generate event store encryption key: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
eventStore, _, err := integrations.InitEventStore(context.Background(), s.Config.Datadir, s.Config.DataStoreEncryptionKey, integrationMetrics)
|
||||
eventStore, err := activitystore.NewSqlStore(context.Background(), s.Config.Datadir, key)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to initialize event store: %v", err)
|
||||
}
|
||||
@@ -110,7 +120,7 @@ func (s *BaseServer) EventStore() activity.Store {
|
||||
|
||||
func (s *BaseServer) APIHandler() http.Handler {
|
||||
return Create(s, func() http.Handler {
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter())
|
||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.Router(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.PermissionsManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter(), s.IsValidChildAccount)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create API handler: %v", err)
|
||||
}
|
||||
@@ -118,6 +128,22 @@ func (s *BaseServer) APIHandler() http.Handler {
|
||||
})
|
||||
}
|
||||
|
||||
// IDPHandler returns the HTTP handler for the embedded IdP (Dex), or nil if
|
||||
// the deployment isn't using the embedded variant.
|
||||
func (s *BaseServer) IDPHandler() http.Handler {
|
||||
embeddedIdP, ok := s.IdpManager().(*idp.EmbeddedIdPManager)
|
||||
if !ok || embeddedIdP == nil {
|
||||
return nil
|
||||
}
|
||||
return cors.AllowAll().Handler(embeddedIdP.Handler())
|
||||
}
|
||||
|
||||
func (s *BaseServer) Router() *mux.Router {
|
||||
return Create(s, func() *mux.Router {
|
||||
return mux.NewRouter().PathPrefix(apiPrefix).Subrouter()
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) RateLimiter() *middleware.APIRateLimiter {
|
||||
return Create(s, func() *middleware.APIRateLimiter {
|
||||
cfg, enabled := middleware.RateLimiterConfigFromEnv()
|
||||
@@ -129,68 +155,72 @@ func (s *BaseServer) RateLimiter() *middleware.APIRateLimiter {
|
||||
|
||||
func (s *BaseServer) GRPCServer() *grpc.Server {
|
||||
return Create(s, func() *grpc.Server {
|
||||
trustedPeers := s.Config.ReverseProxy.TrustedPeers
|
||||
defaultTrustedPeers := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0"), netip.MustParsePrefix("::/0")}
|
||||
if len(trustedPeers) == 0 || slices.Equal[[]netip.Prefix](trustedPeers, defaultTrustedPeers) {
|
||||
log.WithContext(context.Background()).Warn("TrustedPeers are configured to default value '0.0.0.0/0', '::/0'. This allows connection IP spoofing.")
|
||||
trustedPeers = defaultTrustedPeers
|
||||
}
|
||||
trustedHTTPProxies := s.Config.ReverseProxy.TrustedHTTPProxies
|
||||
trustedProxiesCount := s.Config.ReverseProxy.TrustedHTTPProxiesCount
|
||||
if len(trustedHTTPProxies) > 0 && trustedProxiesCount > 0 {
|
||||
log.WithContext(context.Background()).Warn("TrustedHTTPProxies and TrustedHTTPProxiesCount both are configured. " +
|
||||
"This is not recommended way to extract X-Forwarded-For. Consider using one of these options.")
|
||||
}
|
||||
realipOpts := []realip.Option{
|
||||
realip.WithTrustedPeers(trustedPeers),
|
||||
realip.WithTrustedProxies(trustedHTTPProxies),
|
||||
realip.WithTrustedProxiesCount(trustedProxiesCount),
|
||||
realip.WithHeaders([]string{realip.XForwardedFor, realip.XRealIp}),
|
||||
}
|
||||
proxyUnary, proxyStream, proxyAuthClose := nbgrpc.NewProxyAuthInterceptors(s.Store())
|
||||
s.proxyAuthClose = proxyAuthClose
|
||||
gRPCOpts := []grpc.ServerOption{
|
||||
grpc.KeepaliveEnforcementPolicy(kaep),
|
||||
grpc.KeepaliveParams(kasp),
|
||||
grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...), unaryInterceptor, proxyUnary),
|
||||
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor, proxyStream),
|
||||
}
|
||||
|
||||
if s.Config.HttpConfig.LetsEncryptDomain != "" {
|
||||
certManager, err := encryption.CreateCertManager(s.Config.Datadir, s.Config.HttpConfig.LetsEncryptDomain)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create certificate service: %v", err)
|
||||
}
|
||||
transportCredentials := credentials.NewTLS(certManager.TLSConfig())
|
||||
gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials))
|
||||
} else if s.Config.HttpConfig.CertFile != "" && s.Config.HttpConfig.CertKey != "" {
|
||||
tlsConfig, err := loadTLSConfig(s.Config.HttpConfig.CertFile, s.Config.HttpConfig.CertKey)
|
||||
if err != nil {
|
||||
log.Fatalf("cannot load TLS credentials: %v", err)
|
||||
}
|
||||
transportCredentials := credentials.NewTLS(tlsConfig)
|
||||
gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials))
|
||||
}
|
||||
|
||||
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
|
||||
srv, err := nbgrpc.NewServer(s.Config, s.AccountManager(), s.SettingsManager(), s.JobManager(), s.SecretsManager(), s.Metrics(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController(), s.OAuthConfigProvider(), s.SessionStore())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create management server: %v", err)
|
||||
}
|
||||
serviceMgr := s.ServiceManager()
|
||||
srv.SetReverseProxyManager(serviceMgr)
|
||||
if serviceMgr != nil {
|
||||
serviceMgr.StartExposeReaper(context.Background())
|
||||
}
|
||||
mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv)
|
||||
|
||||
mgmtProto.RegisterProxyServiceServer(gRPCAPIHandler, s.ReverseProxyGRPCServer())
|
||||
log.Info("ProxyService registered on gRPC server")
|
||||
|
||||
return gRPCAPIHandler
|
||||
return s.BuildGRPCServer(s.ExtendNetBirdConfig)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) BuildGRPCServer(configExtender nbgrpc.ConfigExtender) *grpc.Server {
|
||||
trustedPeers := s.Config.ReverseProxy.TrustedPeers
|
||||
defaultTrustedPeers := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0"), netip.MustParsePrefix("::/0")}
|
||||
if len(trustedPeers) == 0 || slices.Equal(trustedPeers, defaultTrustedPeers) {
|
||||
log.WithContext(context.Background()).Warn("TrustedPeers are configured to default value '0.0.0.0/0', '::/0'. This allows connection IP spoofing.")
|
||||
trustedPeers = defaultTrustedPeers
|
||||
}
|
||||
trustedHTTPProxies := s.Config.ReverseProxy.TrustedHTTPProxies
|
||||
trustedProxiesCount := s.Config.ReverseProxy.TrustedHTTPProxiesCount
|
||||
if len(trustedHTTPProxies) > 0 && trustedProxiesCount > 0 {
|
||||
log.WithContext(context.Background()).Warn("TrustedHTTPProxies and TrustedHTTPProxiesCount both are configured. " +
|
||||
"This is not recommended way to extract X-Forwarded-For. Consider using one of these options.")
|
||||
}
|
||||
realipOpts := []realip.Option{
|
||||
realip.WithTrustedPeers(trustedPeers),
|
||||
realip.WithTrustedProxies(trustedHTTPProxies),
|
||||
realip.WithTrustedProxiesCount(trustedProxiesCount),
|
||||
realip.WithHeaders([]string{realip.XForwardedFor, realip.XRealIp}),
|
||||
}
|
||||
proxyUnary, proxyStream, proxyAuthClose := nbgrpc.NewProxyAuthInterceptors(s.Store())
|
||||
s.proxyAuthClose = proxyAuthClose
|
||||
gRPCOpts := []grpc.ServerOption{
|
||||
grpc.KeepaliveEnforcementPolicy(kaep),
|
||||
grpc.KeepaliveParams(kasp),
|
||||
grpc.ChainUnaryInterceptor(realip.UnaryServerInterceptorOpts(realipOpts...), unaryInterceptor, proxyUnary),
|
||||
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor, proxyStream),
|
||||
}
|
||||
|
||||
if s.Config.HttpConfig.LetsEncryptDomain != "" {
|
||||
certManager, err := encryption.CreateCertManager(s.Config.Datadir, s.Config.HttpConfig.LetsEncryptDomain)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create certificate service: %v", err)
|
||||
}
|
||||
transportCredentials := credentials.NewTLS(certManager.TLSConfig())
|
||||
gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials))
|
||||
} else if s.Config.HttpConfig.CertFile != "" && s.Config.HttpConfig.CertKey != "" {
|
||||
tlsConfig, err := loadTLSConfig(s.Config.HttpConfig.CertFile, s.Config.HttpConfig.CertKey)
|
||||
if err != nil {
|
||||
log.Fatalf("cannot load TLS credentials: %v", err)
|
||||
}
|
||||
transportCredentials := credentials.NewTLS(tlsConfig)
|
||||
gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials))
|
||||
}
|
||||
|
||||
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
|
||||
srv, err := nbgrpc.NewServer(s.Config, s.AccountManager(), s.SettingsManager(), s.JobManager(), s.SecretsManager(), s.Metrics(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController(), s.OAuthConfigProvider(), s.SessionStore(), configExtender)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create management server: %v", err)
|
||||
}
|
||||
serviceMgr := s.ServiceManager()
|
||||
srv.SetReverseProxyManager(serviceMgr)
|
||||
if serviceMgr != nil {
|
||||
serviceMgr.StartExposeReaper(context.Background())
|
||||
}
|
||||
mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv)
|
||||
|
||||
mgmtProto.RegisterProxyServiceServer(gRPCAPIHandler, s.ReverseProxyGRPCServer())
|
||||
log.Info("ProxyService registered on gRPC server")
|
||||
|
||||
return gRPCAPIHandler
|
||||
}
|
||||
|
||||
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
|
||||
return Create(s, func() *nbgrpc.ProxyServiceServer {
|
||||
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager(), s.Store())
|
||||
|
||||
@@ -21,7 +21,9 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
func (s *BaseServer) PeersUpdateManager() network_map.PeersUpdateManager {
|
||||
@@ -38,8 +40,11 @@ func (s *BaseServer) JobManager() *job.Manager {
|
||||
|
||||
func (s *BaseServer) IntegratedValidator() integrated_validator.IntegratedValidator {
|
||||
return Create(s, func() integrated_validator.IntegratedValidator {
|
||||
// TODO: Replace
|
||||
integratedPeerValidator, err := integrations.NewIntegratedValidator(
|
||||
context.Background(),
|
||||
nil,
|
||||
s.Store(),
|
||||
s.PeersManager(),
|
||||
s.SettingsManager(),
|
||||
s.EventStore(),
|
||||
@@ -59,7 +64,7 @@ func (s *BaseServer) ProxyController() port_forwarding.Controller {
|
||||
|
||||
func (s *BaseServer) SecretsManager() grpc.SecretsManager {
|
||||
return Create(s, func() grpc.SecretsManager {
|
||||
secretsManager, err := grpc.NewTimeBasedAuthSecretsManager(s.PeersUpdateManager(), s.Config.TURNConfig, s.Config.Relay, s.SettingsManager(), s.GroupsManager())
|
||||
secretsManager, err := grpc.NewTimeBasedAuthSecretsManager(s.PeersUpdateManager(), s.Config.TURNConfig, s.Config.Relay, s.SettingsManager(), s.GroupsManager(), s.ExtendNetBirdConfig)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create secrets manager: %v", err)
|
||||
}
|
||||
@@ -122,7 +127,7 @@ func (s *BaseServer) EphemeralManager() ephemeral.Manager {
|
||||
|
||||
func (s *BaseServer) NetworkMapController() network_map.Controller {
|
||||
return Create(s, func() network_map.Controller {
|
||||
return nmapcontroller.NewController(context.Background(), s.Store(), s.Metrics(), s.PeersUpdateManager(), s.AccountRequestBuffer(), s.IntegratedValidator(), s.SettingsManager(), s.DNSDomain(), s.ProxyController(), s.EphemeralManager(), s.Config)
|
||||
return nmapcontroller.NewController(context.Background(), s.Store(), s.Metrics(), s.PeersUpdateManager(), s.AccountRequestBuffer(), s.IntegratedValidator(), s.SettingsManager(), s.DNSDomain(), s.ProxyController(), s.EphemeralManager(), s.Config, s.ExtendNetBirdConfig)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -145,3 +150,7 @@ func (s *BaseServer) AccountRequestBuffer() *server.AccountRequestBuffer {
|
||||
func (s *BaseServer) DNSDomain() string {
|
||||
return s.dnsDomain
|
||||
}
|
||||
|
||||
func (s *BaseServer) ExtendNetBirdConfig(_ string, _ []string, config *proto.NetbirdConfig, _ *types.ExtraSettings) *proto.NetbirdConfig {
|
||||
return config
|
||||
}
|
||||
|
||||
@@ -57,13 +57,7 @@ func (s *BaseServer) GeoLocationManager() geolocation.Geolocation {
|
||||
|
||||
func (s *BaseServer) PermissionsManager() permissions.Manager {
|
||||
return Create(s, func() permissions.Manager {
|
||||
manager := integrations.InitPermissionsManager(s.Store(), s.Metrics().GetMeter())
|
||||
|
||||
s.AfterInit(func(s *BaseServer) {
|
||||
manager.SetAccountManager(s.AccountManager())
|
||||
})
|
||||
|
||||
return manager
|
||||
return permissions.NewManager(s.Store())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -153,7 +147,6 @@ func (s *BaseServer) IdpManager() idp.Manager {
|
||||
return idpManager
|
||||
}
|
||||
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
@@ -235,3 +228,7 @@ func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
|
||||
return &m
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) IsValidChildAccount(_ context.Context, _, _, _ string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -188,7 +188,7 @@ func (s *BaseServer) Start(ctx context.Context) error {
|
||||
log.WithContext(srvCtx).Infof("running gRPC backward compatibility server: %s", compatListener.Addr().String())
|
||||
}
|
||||
|
||||
rootHandler := s.handlerFunc(srvCtx, s.GRPCServer(), s.APIHandler(), s.Metrics().GetMeter())
|
||||
rootHandler := s.handlerFunc(srvCtx, s.GRPCServer(), s.APIHandler(), s.IDPHandler(), s.Metrics().GetMeter())
|
||||
switch {
|
||||
case s.certManager != nil:
|
||||
// a call to certManager.Listener() always creates a new listener so we do it once
|
||||
@@ -299,7 +299,7 @@ func (s *BaseServer) SetHandlerFunc(handler http.Handler) {
|
||||
log.Tracef("custom handler set successfully")
|
||||
}
|
||||
|
||||
func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler {
|
||||
func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, httpHandler http.Handler, idpHandler http.Handler, meter metric.Meter) http.Handler {
|
||||
// Check if a custom handler was set (for multiplexing additional services)
|
||||
if customHandler, ok := s.GetContainer("customHandler"); ok {
|
||||
if handler, ok := customHandler.(http.Handler); ok {
|
||||
@@ -318,6 +318,8 @@ func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, ht
|
||||
gRPCHandler.ServeHTTP(writer, request)
|
||||
case request.URL.Path == wsproxy.ProxyPath+wsproxy.ManagementComponent:
|
||||
wsProxy.Handler().ServeHTTP(writer, request)
|
||||
case idpHandler != nil && strings.HasPrefix(request.URL.Path, "/oauth2"):
|
||||
idpHandler.ServeHTTP(writer, request)
|
||||
default:
|
||||
httpHandler.ServeHTTP(writer, request)
|
||||
}
|
||||
|
||||
@@ -10,8 +10,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
goproto "google.golang.org/protobuf/proto"
|
||||
|
||||
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh/auth"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
@@ -26,6 +24,8 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/sshauth"
|
||||
)
|
||||
|
||||
type ConfigExtender func(peerID string, peerGroups []string, config *proto.NetbirdConfig, extras *types.ExtraSettings) *proto.NetbirdConfig
|
||||
|
||||
func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
|
||||
if config == nil {
|
||||
return nil
|
||||
@@ -127,7 +127,7 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set
|
||||
return peerConfig
|
||||
}
|
||||
|
||||
func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *cache.DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse {
|
||||
func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *cache.DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64, configExtender ConfigExtender) *proto.SyncResponse {
|
||||
// IPv6 data in AllowedIPs and SourcePrefixes wildcard expansion depends on
|
||||
// whether the target peer supports IPv6. Routes and firewall rules are already
|
||||
// filtered at the source (network map builder).
|
||||
@@ -146,8 +146,10 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
||||
}
|
||||
|
||||
nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
|
||||
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings)
|
||||
response.NetbirdConfig = extendedConfig
|
||||
if configExtender != nil {
|
||||
nbConfig = configExtender(peer.ID, peerGroups, nbConfig, extraSettings)
|
||||
}
|
||||
response.NetbirdConfig = nbConfig
|
||||
|
||||
response.NetworkMap.PeerConfig = response.PeerConfig
|
||||
|
||||
@@ -332,7 +334,6 @@ func toProtocolFirewallRules(rules []*types.FirewallRule, includeIPv6, useSource
|
||||
return result
|
||||
}
|
||||
|
||||
|
||||
// populateSourcePrefixes sets SourcePrefixes on fwRule and returns any
|
||||
// additional rules needed (e.g. a v6 wildcard clone when the peer IP is unspecified).
|
||||
func populateSourcePrefixes(fwRule *proto.FirewallRule, rule *types.FirewallRule, includeIPv6 bool) []*proto.FirewallRule {
|
||||
|
||||
@@ -86,6 +86,8 @@ type Server struct {
|
||||
|
||||
reverseProxyManager rpservice.Manager
|
||||
reverseProxyMu sync.RWMutex
|
||||
|
||||
configExtender ConfigExtender
|
||||
}
|
||||
|
||||
// NewServer creates a new Management server
|
||||
@@ -101,6 +103,7 @@ func NewServer(
|
||||
networkMapController network_map.Controller,
|
||||
oAuthConfigProvider idp.OAuthConfigProvider,
|
||||
sessionStore *auth.SessionStore,
|
||||
configExtender ConfigExtender,
|
||||
) (*Server, error) {
|
||||
if appMetrics != nil {
|
||||
// update gauge based on number of connected peers which is equal to open gRPC streams
|
||||
@@ -144,6 +147,7 @@ func NewServer(
|
||||
networkMapController: networkMapController,
|
||||
oAuthConfigProvider: oAuthConfigProvider,
|
||||
sessionStore: sessionStore,
|
||||
configExtender: configExtender,
|
||||
|
||||
loginFilter: newLoginFilter(),
|
||||
|
||||
@@ -932,7 +936,7 @@ func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer
|
||||
return status.Errorf(codes.Internal, "failed to get peer groups %s", err)
|
||||
}
|
||||
|
||||
plainResp := ToSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, networkMap, s.networkMapController.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort)
|
||||
plainResp := ToSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, networkMap, s.networkMapController.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort, s.configExtender)
|
||||
|
||||
key, err := s.secretsManager.GetWGKey()
|
||||
if err != nil {
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
@@ -46,11 +45,12 @@ type TimeBasedAuthSecretsManager struct {
|
||||
turnCancelMap map[string]chan struct{}
|
||||
relayCancelMap map[string]chan struct{}
|
||||
wgKey wgtypes.Key
|
||||
configExtender ConfigExtender
|
||||
}
|
||||
|
||||
type Token auth.Token
|
||||
|
||||
func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager, turnCfg *nbconfig.TURNConfig, relayCfg *nbconfig.Relay, settingsManager settings.Manager, groupsManager groups.Manager) (*TimeBasedAuthSecretsManager, error) {
|
||||
func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager, turnCfg *nbconfig.TURNConfig, relayCfg *nbconfig.Relay, settingsManager settings.Manager, groupsManager groups.Manager, configExtender ConfigExtender) (*TimeBasedAuthSecretsManager, error) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -65,6 +65,7 @@ func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager
|
||||
settingsManager: settingsManager,
|
||||
groupsManager: groupsManager,
|
||||
wgKey: key,
|
||||
configExtender: configExtender,
|
||||
}
|
||||
|
||||
if turnCfg != nil {
|
||||
@@ -286,6 +287,7 @@ func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, p
|
||||
log.WithContext(ctx).Errorf("failed to get peer groups: %v", err)
|
||||
}
|
||||
|
||||
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peerID, peerGroups, update.NetbirdConfig, extraSettings)
|
||||
update.NetbirdConfig = extendedConfig
|
||||
if m.configExtender != nil {
|
||||
update.NetbirdConfig = m.configExtender(peerID, peerGroups, update.NetbirdConfig, extraSettings)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -51,7 +51,7 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
|
||||
Secret: secret,
|
||||
Turns: []*config.Host{TurnTestHost},
|
||||
TimeBasedCredentials: true,
|
||||
}, rc, settingsMockManager, groupsManager)
|
||||
}, rc, settingsMockManager, groupsManager, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
turnCredentials, err := tested.GenerateTurnToken()
|
||||
@@ -104,7 +104,7 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
|
||||
Secret: secret,
|
||||
Turns: []*config.Host{TurnTestHost},
|
||||
TimeBasedCredentials: true,
|
||||
}, rc, settingsMockManager, groupsManager)
|
||||
}, rc, settingsMockManager, groupsManager, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
@@ -208,7 +208,7 @@ func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
|
||||
Secret: secret,
|
||||
Turns: []*config.Host{TurnTestHost},
|
||||
TimeBasedCredentials: true,
|
||||
}, rc, settingsMockManager, groupsManager)
|
||||
}, rc, settingsMockManager, groupsManager, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
tested.SetupRefresh(context.Background(), "someAccountID", peer)
|
||||
|
||||
@@ -3186,7 +3186,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
|
||||
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := NewAccountRequestBuffer(ctx, store)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}, nil)
|
||||
manager, err := BuildManager(ctx, &config.Config{}, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
||||
@@ -234,7 +234,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := NewAccountRequestBuffer(ctx, store)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}, nil)
|
||||
|
||||
return BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
|
||||
}
|
||||
|
||||
@@ -15,15 +15,13 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxytoken"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
|
||||
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
idpmanager "github.com/netbirdio/netbird/management/server/idp"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
||||
@@ -32,12 +30,10 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/proxy"
|
||||
|
||||
nbpeers "github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
nbgroups "github.com/netbirdio/netbird/management/server/groups"
|
||||
@@ -56,17 +52,14 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||
nbinstance "github.com/netbirdio/netbird/management/server/instance"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||
nbnetworks "github.com/netbirdio/netbird/management/server/networks"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
const apiPrefix = "/api"
|
||||
|
||||
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
||||
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter) (http.Handler, error) {
|
||||
func NewAPIHandler(ctx context.Context, router *mux.Router, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, permissionsManager permissions.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter, isValidChildAccount middleware.IsValidChildAccountFunc) (http.Handler, error) {
|
||||
|
||||
// Register bypass paths for unauthenticated endpoints
|
||||
if err := bypass.AddBypassPath("/api/instance"); err != nil {
|
||||
@@ -100,25 +93,16 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
||||
accountManager.GetUserFromUserAuth,
|
||||
rateLimiter,
|
||||
appMetrics.GetMeter(),
|
||||
isValidChildAccount,
|
||||
)
|
||||
|
||||
corsMiddleware := cors.AllowAll()
|
||||
|
||||
rootRouter := mux.NewRouter()
|
||||
metricsMiddleware := appMetrics.HTTPMiddleware()
|
||||
|
||||
prefix := apiPrefix
|
||||
router := rootRouter.PathPrefix(prefix).Subrouter()
|
||||
|
||||
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler)
|
||||
|
||||
if _, err := integrations.RegisterHandlers(ctx, prefix, router, accountManager, integratedValidator, appMetrics.GetMeter(), permissionsManager, peersManager, proxyController, settingsManager); err != nil {
|
||||
return nil, fmt.Errorf("register integrations endpoints: %w", err)
|
||||
}
|
||||
|
||||
// Check if embedded IdP is enabled for instance manager
|
||||
embeddedIdP, embeddedIdpEnabled := idpManager.(*idpmanager.EmbeddedIdPManager)
|
||||
instanceManager, err := nbinstance.NewManager(ctx, accountManager.GetStore(), embeddedIdP)
|
||||
instanceManager, err := nbinstance.NewManager(ctx, accountManager.GetStore(), idpManager)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create instance manager: %w", err)
|
||||
}
|
||||
@@ -154,10 +138,5 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
||||
oauthHandler.RegisterEndpoints(router)
|
||||
}
|
||||
|
||||
// Mount embedded IdP handler at /oauth2 path if configured
|
||||
if embeddedIdpEnabled {
|
||||
rootRouter.PathPrefix("/oauth2").Handler(corsMiddleware.Handler(embeddedIdP.Handler()))
|
||||
}
|
||||
|
||||
return rootRouter, nil
|
||||
return router, nil
|
||||
}
|
||||
|
||||
@@ -11,8 +11,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
serverauth "github.com/netbirdio/netbird/management/server/auth"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||
@@ -27,6 +25,8 @@ type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth auth.UserAuth) err
|
||||
|
||||
type GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||
|
||||
type IsValidChildAccountFunc func(ctx context.Context, userID, accountID, childAccountID string) bool
|
||||
|
||||
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
|
||||
type AuthMiddleware struct {
|
||||
authManager serverauth.Manager
|
||||
@@ -35,6 +35,7 @@ type AuthMiddleware struct {
|
||||
syncUserJWTGroups SyncUserJWTGroupsFunc
|
||||
rateLimiter *APIRateLimiter
|
||||
patUsageTracker *PATUsageTracker
|
||||
isValidChildAccount IsValidChildAccountFunc
|
||||
}
|
||||
|
||||
// NewAuthMiddleware instance constructor
|
||||
@@ -45,6 +46,7 @@ func NewAuthMiddleware(
|
||||
getUserFromUserAuth GetUserFromUserAuthFunc,
|
||||
rateLimiter *APIRateLimiter,
|
||||
meter metric.Meter,
|
||||
isValidChildAccount IsValidChildAccountFunc,
|
||||
) *AuthMiddleware {
|
||||
var patUsageTracker *PATUsageTracker
|
||||
if meter != nil {
|
||||
@@ -62,6 +64,7 @@ func NewAuthMiddleware(
|
||||
getUserFromUserAuth: getUserFromUserAuth,
|
||||
rateLimiter: rateLimiter,
|
||||
patUsageTracker: patUsageTracker,
|
||||
isValidChildAccount: isValidChildAccount,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -124,7 +127,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
|
||||
}
|
||||
|
||||
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
|
||||
if integrations.IsValidChildAccount(ctx, userAuth.UserId, userAuth.AccountId, impersonate[0]) {
|
||||
if m.isValidChildAccount(ctx, userAuth.UserId, userAuth.AccountId, impersonate[0]) {
|
||||
userAuth.AccountId = impersonate[0]
|
||||
userAuth.IsChild = true
|
||||
}
|
||||
@@ -203,7 +206,7 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
|
||||
}
|
||||
|
||||
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
|
||||
if integrations.IsValidChildAccount(r.Context(), userAuth.UserId, userAuth.AccountId, impersonate[0]) {
|
||||
if m.isValidChildAccount(r.Context(), userAuth.UserId, userAuth.AccountId, impersonate[0]) {
|
||||
userAuth.AccountId = impersonate[0]
|
||||
userAuth.IsChild = true
|
||||
}
|
||||
|
||||
@@ -211,6 +211,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
},
|
||||
disabledLimiter,
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handlerToTest := authMiddleware.Handler(nextHandler)
|
||||
@@ -270,6 +271,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -322,6 +324,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -365,6 +368,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -409,6 +413,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -473,6 +478,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -532,6 +538,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -587,6 +594,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
},
|
||||
NewAPIRateLimiter(rateLimitConfig),
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -687,6 +695,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
},
|
||||
disabledLimiter,
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
for _, tc := range tt {
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"go.opentelemetry.io/otel/metric/noop"
|
||||
@@ -95,7 +96,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
||||
}
|
||||
|
||||
requestBuffer := server.NewAccountRequestBuffer(ctx, store)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peersManager), &config.Config{})
|
||||
networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peersManager), &config.Config{}, nil)
|
||||
am, err := server.BuildManager(ctx, nil, store, networkMapController, jobManager, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false, cacheStore)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create manager: %v", err)
|
||||
@@ -135,7 +136,8 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
||||
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
||||
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
||||
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil)
|
||||
apiRouter := mux.NewRouter().PathPrefix("/api").Subrouter()
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create API handler: %v", err)
|
||||
}
|
||||
@@ -224,7 +226,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
|
||||
}
|
||||
|
||||
requestBuffer := server.NewAccountRequestBuffer(ctx, store)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peersManager), &config.Config{})
|
||||
networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peersManager), &config.Config{}, nil)
|
||||
am, err := server.BuildManager(ctx, nil, store, networkMapController, jobManager, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false, cacheStore)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create manager: %v", err)
|
||||
@@ -264,7 +266,8 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
|
||||
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
||||
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
||||
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil)
|
||||
apiRouter := mux.NewRouter().PathPrefix("/api").Subrouter()
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create API handler: %v", err)
|
||||
}
|
||||
|
||||
@@ -92,7 +92,7 @@ func createManagerWithEmbeddedIdP(t testing.TB) (*DefaultAccountManager, *update
|
||||
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := NewAccountRequestBuffer(ctx, testStore)
|
||||
networkMapController := controller.NewController(ctx, testStore, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(testStore, peersManager), &config.Config{})
|
||||
networkMapController := controller.NewController(ctx, testStore, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(testStore, peersManager), &config.Config{}, nil)
|
||||
manager, err := BuildManager(ctx, &config.Config{}, testStore, networkMapController, job.NewJobManager(nil, testStore, peersManager), idpManager, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
||||
@@ -376,7 +376,7 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config
|
||||
return nil, nil, "", cleanup, err
|
||||
}
|
||||
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeralMgr, config)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeralMgr, config, nil)
|
||||
accountManager, err := BuildManager(ctx, nil, store, networkMapController, jobManager, nil, "",
|
||||
eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
|
||||
|
||||
@@ -385,13 +385,13 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config
|
||||
return nil, nil, "", cleanup, err
|
||||
}
|
||||
|
||||
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager, nil)
|
||||
if err != nil {
|
||||
cleanup()
|
||||
return nil, nil, "", cleanup, err
|
||||
}
|
||||
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, MockIntegratedValidator{}, networkMapController, nil, nil)
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, MockIntegratedValidator{}, networkMapController, nil, nil, nil)
|
||||
if err != nil {
|
||||
return nil, nil, "", cleanup, err
|
||||
}
|
||||
|
||||
@@ -216,7 +216,7 @@ func startServer(
|
||||
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := server.NewAccountRequestBuffer(ctx, str)
|
||||
networkMapController := controller.NewController(ctx, str, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(str, peers.NewManager(str, permissionsManager)), config)
|
||||
networkMapController := controller.NewController(ctx, str, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(str, peers.NewManager(str, permissionsManager)), config, nil)
|
||||
|
||||
accountManager, err := server.BuildManager(
|
||||
context.Background(),
|
||||
@@ -241,7 +241,7 @@ func startServer(
|
||||
}
|
||||
|
||||
groupsManager := groups.NewManager(str, permissionsManager, accountManager)
|
||||
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed creating secrets manager: %v", err)
|
||||
}
|
||||
@@ -257,6 +257,7 @@ func startServer(
|
||||
networkMapController,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed creating management server: %v", err)
|
||||
|
||||
@@ -803,7 +803,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := NewAccountRequestBuffer(ctx, store)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}, nil)
|
||||
|
||||
return BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
|
||||
}
|
||||
|
||||
@@ -1179,7 +1179,7 @@ func TestToSyncResponse(t *testing.T) {
|
||||
}
|
||||
dnsCache := &cache.DNSConfigCache{}
|
||||
accountSettings := &types.Settings{RoutingPeerDNSResolutionEnabled: true}
|
||||
response := grpc.ToSyncResponse(context.Background(), config, config.HttpConfig, config.DeviceAuthorizationFlow, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, int64(dnsForwarderPort))
|
||||
response := grpc.ToSyncResponse(context.Background(), config, config.HttpConfig, config.DeviceAuthorizationFlow, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, int64(dnsForwarderPort), nil)
|
||||
|
||||
assert.NotNil(t, response)
|
||||
// assert peer config
|
||||
@@ -1299,7 +1299,7 @@ func Test_RegisterPeerByUser(t *testing.T) {
|
||||
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := NewAccountRequestBuffer(ctx, s)
|
||||
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{})
|
||||
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{}, nil)
|
||||
|
||||
am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
|
||||
assert.NoError(t, err)
|
||||
@@ -1390,7 +1390,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
|
||||
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := NewAccountRequestBuffer(ctx, s)
|
||||
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{})
|
||||
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{}, nil)
|
||||
|
||||
am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
|
||||
assert.NoError(t, err)
|
||||
@@ -1549,7 +1549,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
||||
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := NewAccountRequestBuffer(ctx, s)
|
||||
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{})
|
||||
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{}, nil)
|
||||
|
||||
am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
|
||||
assert.NoError(t, err)
|
||||
@@ -1634,7 +1634,7 @@ func Test_LoginPeer(t *testing.T) {
|
||||
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := NewAccountRequestBuffer(ctx, s)
|
||||
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{})
|
||||
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{}, nil)
|
||||
|
||||
am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -1299,7 +1299,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, *update_channel.
|
||||
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := NewAccountRequestBuffer(ctx, store)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}, nil)
|
||||
|
||||
am, err := BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
|
||||
if err != nil {
|
||||
|
||||
@@ -3463,49 +3463,6 @@ func (s *SqlStore) GetAllEphemeralPeers(ctx context.Context, lockStrength Lockin
|
||||
return allEphemeralPeers, nil
|
||||
}
|
||||
|
||||
// GetStaleEphemeralPeerIDsForAccount returns IDs of disconnected
|
||||
// ephemeral peers in the given account whose last_seen is strictly
|
||||
// older than olderThan.
|
||||
func (s *SqlStore) GetStaleEphemeralPeerIDsForAccount(ctx context.Context, accountID string, olderThan time.Time) ([]string, error) {
|
||||
var ids []string
|
||||
err := s.db.WithContext(ctx).
|
||||
Model(&nbpeer.Peer{}).
|
||||
Where("account_id = ? AND ephemeral = ? AND peer_status_connected = ? AND peer_status_last_seen < ?",
|
||||
accountID, true, false, olderThan).
|
||||
Pluck("id", &ids).Error
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to query stale ephemeral peers for account %s: %v", accountID, err)
|
||||
return nil, status.Errorf(status.Internal, "query stale ephemeral peers")
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
// GetEphemeralAccountsLastDisconnect returns the latest peer_status_last_seen
|
||||
// per account across disconnected ephemeral peers. Returns one entry per
|
||||
// account that has at least one such peer.
|
||||
func (s *SqlStore) GetEphemeralAccountsLastDisconnect(ctx context.Context) (map[string]time.Time, error) {
|
||||
type row struct {
|
||||
AccountID string
|
||||
LastSeen time.Time
|
||||
}
|
||||
var rows []row
|
||||
err := s.db.WithContext(ctx).
|
||||
Model(&nbpeer.Peer{}).
|
||||
Select("account_id, MAX(peer_status_last_seen) AS last_seen").
|
||||
Where("ephemeral = ? AND peer_status_connected = ?", true, false).
|
||||
Group("account_id").
|
||||
Scan(&rows).Error
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to load ephemeral-account last disconnect map: %v", err)
|
||||
return nil, status.Errorf(status.Internal, "load ephemeral accounts")
|
||||
}
|
||||
out := make(map[string]time.Time, len(rows))
|
||||
for _, r := range rows {
|
||||
out[r.AccountID] = r.LastSeen
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// DeletePeer removes a peer from the store.
|
||||
func (s *SqlStore) DeletePeer(ctx context.Context, accountID string, peerID string) error {
|
||||
result := s.db.Delete(&nbpeer.Peer{}, accountAndIDQueryCondition, accountID, peerID)
|
||||
|
||||
@@ -165,15 +165,6 @@ type Store interface {
|
||||
GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
|
||||
GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
|
||||
GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error)
|
||||
// GetStaleEphemeralPeerIDsForAccount returns the IDs of disconnected
|
||||
// ephemeral peers whose last_seen is strictly older than olderThan,
|
||||
// scoped to a single account. Used by the per-account cleanup sweep.
|
||||
GetStaleEphemeralPeerIDsForAccount(ctx context.Context, accountID string, olderThan time.Time) ([]string, error)
|
||||
// GetEphemeralAccountsLastDisconnect returns, for every account that
|
||||
// has at least one disconnected ephemeral peer, the most recent
|
||||
// last_seen across that account's disconnected ephemeral peers. Used
|
||||
// to reconstruct the per-account cleanup tracker after a restart.
|
||||
GetEphemeralAccountsLastDisconnect(ctx context.Context) (map[string]time.Time, error)
|
||||
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
||||
SavePeerStatus(ctx context.Context, accountID, peerID string, status nbpeer.PeerStatus) error
|
||||
// MarkPeerConnectedIfNewerSession sets the peer to connected with the
|
||||
|
||||
@@ -1376,36 +1376,6 @@ func (mr *MockStoreMockRecorder) GetAllEphemeralPeers(ctx, lockStrength interfac
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllEphemeralPeers", reflect.TypeOf((*MockStore)(nil).GetAllEphemeralPeers), ctx, lockStrength)
|
||||
}
|
||||
|
||||
// GetStaleEphemeralPeerIDsForAccount mocks base method.
|
||||
func (m *MockStore) GetStaleEphemeralPeerIDsForAccount(ctx context.Context, accountID string, olderThan time.Time) ([]string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetStaleEphemeralPeerIDsForAccount", ctx, accountID, olderThan)
|
||||
ret0, _ := ret[0].([]string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetStaleEphemeralPeerIDsForAccount indicates an expected call of GetStaleEphemeralPeerIDsForAccount.
|
||||
func (mr *MockStoreMockRecorder) GetStaleEphemeralPeerIDsForAccount(ctx, accountID, olderThan interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStaleEphemeralPeerIDsForAccount", reflect.TypeOf((*MockStore)(nil).GetStaleEphemeralPeerIDsForAccount), ctx, accountID, olderThan)
|
||||
}
|
||||
|
||||
// GetEphemeralAccountsLastDisconnect mocks base method.
|
||||
func (m *MockStore) GetEphemeralAccountsLastDisconnect(ctx context.Context) (map[string]time.Time, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetEphemeralAccountsLastDisconnect", ctx)
|
||||
ret0, _ := ret[0].(map[string]time.Time)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetEphemeralAccountsLastDisconnect indicates an expected call of GetEphemeralAccountsLastDisconnect.
|
||||
func (mr *MockStoreMockRecorder) GetEphemeralAccountsLastDisconnect(ctx interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEphemeralAccountsLastDisconnect", reflect.TypeOf((*MockStore)(nil).GetEphemeralAccountsLastDisconnect), ctx)
|
||||
}
|
||||
|
||||
// GetAllProxyAccessTokens mocks base method.
|
||||
func (m *MockStore) GetAllProxyAccessTokens(ctx context.Context, lockStrength LockingStrength) ([]*types2.ProxyAccessToken, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -7,9 +7,9 @@ import (
|
||||
)
|
||||
|
||||
// EphemeralPeersMetrics tracks the ephemeral peer cleanup pipeline: how
|
||||
// many accounts are currently being tracked for cleanup, how many sweep
|
||||
// runs deleted at least one peer, how many peers have been removed, and
|
||||
// how many delete batches failed.
|
||||
// many peers are currently scheduled for deletion, how many tick runs
|
||||
// the cleaner has performed, how many peers it has removed, and how
|
||||
// many delete batches failed.
|
||||
type EphemeralPeersMetrics struct {
|
||||
ctx context.Context
|
||||
|
||||
@@ -21,16 +21,16 @@ type EphemeralPeersMetrics struct {
|
||||
|
||||
// NewEphemeralPeersMetrics constructs the ephemeral cleanup counters.
|
||||
func NewEphemeralPeersMetrics(ctx context.Context, meter metric.Meter) (*EphemeralPeersMetrics, error) {
|
||||
pending, err := meter.Int64UpDownCounter("management.ephemeral.accounts.tracked",
|
||||
pending, err := meter.Int64UpDownCounter("management.ephemeral.peers.pending",
|
||||
metric.WithUnit("1"),
|
||||
metric.WithDescription("Number of accounts currently tracked for ephemeral peer cleanup"))
|
||||
metric.WithDescription("Number of ephemeral peers currently waiting to be cleaned up"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cleanupRuns, err := meter.Int64Counter("management.ephemeral.cleanup.runs.counter",
|
||||
metric.WithUnit("1"),
|
||||
metric.WithDescription("Number of ephemeral cleanup sweeps that deleted at least one peer"))
|
||||
metric.WithDescription("Number of ephemeral cleanup ticks that processed at least one peer"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -61,8 +61,7 @@ func NewEphemeralPeersMetrics(ctx context.Context, meter metric.Meter) (*Ephemer
|
||||
// All methods are nil-receiver safe so callers that haven't wired metrics
|
||||
// (tests, self-hosted with metrics off) can invoke them unconditionally.
|
||||
|
||||
// IncPending bumps the tracked-accounts gauge when a new account
|
||||
// becomes eligible for ephemeral cleanup tracking.
|
||||
// IncPending bumps the pending gauge when a peer is added to the cleanup list.
|
||||
func (m *EphemeralPeersMetrics) IncPending() {
|
||||
if m == nil {
|
||||
return
|
||||
@@ -70,8 +69,8 @@ func (m *EphemeralPeersMetrics) IncPending() {
|
||||
m.pending.Add(m.ctx, 1)
|
||||
}
|
||||
|
||||
// AddPending bumps the tracked-accounts gauge by n — used at startup
|
||||
// when the catch-up query seeds the tracker.
|
||||
// AddPending bumps the pending gauge by n — used at startup when the
|
||||
// initial set of ephemeral peers is loaded from the store.
|
||||
func (m *EphemeralPeersMetrics) AddPending(n int64) {
|
||||
if m == nil || n <= 0 {
|
||||
return
|
||||
@@ -79,8 +78,9 @@ func (m *EphemeralPeersMetrics) AddPending(n int64) {
|
||||
m.pending.Add(m.ctx, n)
|
||||
}
|
||||
|
||||
// DecPending decreases the tracked-accounts gauge when an account is
|
||||
// dropped from the tracker (no more disconnects to chase).
|
||||
// DecPending decreases the pending gauge — used both when a peer reconnects
|
||||
// before its deadline (removed from the list) and when a cleanup tick
|
||||
// actually deletes it.
|
||||
func (m *EphemeralPeersMetrics) DecPending(n int64) {
|
||||
if m == nil || n <= 0 {
|
||||
return
|
||||
|
||||
@@ -103,7 +103,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ia, _ := integrations.NewIntegratedValidator(ctx, peersManger, settingsManagerMock, eventStore, cacheStore)
|
||||
ia, _ := integrations.NewIntegratedValidator(ctx, nil, nil, peersManger, settingsManagerMock, eventStore, cacheStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(ctx)
|
||||
require.NoError(t, err)
|
||||
@@ -126,7 +126,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peersManger), config)
|
||||
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peersManger), config, nil)
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -134,11 +134,11 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, mgmt.MockIntegratedValidator{}, networkMapController, nil, nil)
|
||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, mgmt.MockIntegratedValidator{}, networkMapController, nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user