diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index 5477653a2..b9ff35945 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -15,6 +15,8 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" clientProto "github.com/netbirdio/netbird/client/proto" @@ -24,8 +26,6 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" - "github.com/netbirdio/netbird/management/server/peers" - "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -116,15 +116,18 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), config) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config) accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) if err != nil { t.Fatal(err) } - secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{}, networkMapController) + secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + if err != nil { + t.Fatal(err) + } + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController) if err != nil { t.Fatal(err) } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 9252ce13e..5ab21e3e1 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -30,11 +30,12 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/groups" - "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/configurer" @@ -54,7 +55,6 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" - "github.com/netbirdio/netbird/management/server/peers" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -1628,14 +1628,17 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := server.NewAccountRequestBuffer(context.Background(), store) - networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), config) + networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config) accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) if err != nil { return nil, "", err } - secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}, networkMapController) + secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + if err != nil { + return nil, "", err + } + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController) if err != nil { return nil, "", err } diff --git a/client/server/server_test.go b/client/server/server_test.go index f8592bc7a..5f28a2664 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -17,11 +17,12 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/groups" - "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" @@ -35,7 +36,6 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" - "github.com/netbirdio/netbird/management/server/peers" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -316,14 +316,17 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve requestBuffer := server.NewAccountRequestBuffer(context.Background(), store) peersUpdateManager := update_channel.NewPeersUpdateManager(metrics) - networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), config) + networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config) accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) if err != nil { return nil, "", err } - secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}, networkMapController) + secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + if err != nil { + return nil, "", err + } + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController) if err != nil { return nil, "", err } diff --git a/go.mod b/go.mod index a72d09dc3..91e587c32 100644 --- a/go.mod +++ b/go.mod @@ -64,7 +64,7 @@ require ( github.com/mdlayher/socket v0.5.1 github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 - github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48 + github.com/netbirdio/management-integrations/integrations v0.0.0-20251114143509-4eff2374da63 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 diff --git a/go.sum b/go.sum index 011a5f199..98d395ad1 100644 --- a/go.sum +++ b/go.sum @@ -368,8 +368,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= -github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48 h1:moJbL1uuaWR35yUgHZ6suijjqqW8/qGCuPPBXu5MeWQ= -github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48/go.mod h1:ifKa2jGPsOzZhJFo72v2AE5nMP3GYvlhoZ9JV6lHlJ8= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251114143509-4eff2374da63 h1:ecs4GMANgObopiy29zMmz2dIdOTJMwezUbrFy+zfSwE= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251114143509-4eff2374da63/go.mod h1:JIWpjbCgDvZIt45C9vYpikU2gRXeDWrN7SiyGYd3Qrc= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= diff --git a/management/internals/controllers/network_map/controller/controller.go b/management/internals/controllers/network_map/controller/controller.go index 49bb9cef3..022ea774c 100644 --- a/management/internals/controllers/network_map/controller/controller.go +++ b/management/internals/controllers/network_map/controller/controller.go @@ -19,6 +19,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/account" @@ -42,6 +43,7 @@ type Controller struct { accountManagerMetrics *telemetry.AccountManagerMetrics peersUpdateManager network_map.PeersUpdateManager settingsManager settings.Manager + EphemeralPeersManager ephemeral.Manager accountUpdateLocks sync.Map sendAccountUpdateLocks sync.Map @@ -70,7 +72,7 @@ type bufferUpdate struct { var _ network_map.Controller = (*Controller)(nil) -func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller, config *config.Config) *Controller { +func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller, ephemeralPeersManager ephemeral.Manager, config *config.Config) *Controller { nMetrics, err := newMetrics(metrics.UpdateChannelMetrics()) if err != nil { log.Fatal(fmt.Errorf("error creating metrics: %w", err)) @@ -99,7 +101,8 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App dnsDomain: dnsDomain, config: config, - proxyController: proxyController, + proxyController: proxyController, + EphemeralPeersManager: ephemeralPeersManager, holder: types.NewHolder(), expNewNetworkMap: newNetworkMapBuilder, @@ -107,6 +110,31 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App } } +func (c *Controller) OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *network_map.UpdateMessage, error) { + peer, err := c.repo.GetPeerByID(ctx, accountID, peerID) + if err != nil { + return nil, fmt.Errorf("failed to get peer %s: %v", peerID, err) + } + + c.EphemeralPeersManager.OnPeerConnected(ctx, peer) + + return c.peersUpdateManager.CreateChannel(ctx, peerID), nil +} + +func (c *Controller) OnPeerDisconnected(ctx context.Context, accountID string, peerID string) { + c.peersUpdateManager.CloseChannel(ctx, peerID) + peer, err := c.repo.GetPeerByID(ctx, accountID, peerID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get peer %s: %v", peerID, err) + return + } + c.EphemeralPeersManager.OnPeerDisconnected(ctx, peer) +} + +func (c *Controller) CountStreams() int { + return c.peersUpdateManager.CountStreams() +} + func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error { log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName()) var ( @@ -366,38 +394,6 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str return nil } -func (c *Controller) DeletePeer(ctx context.Context, accountId string, peerId string) error { - network, err := c.repo.GetAccountNetwork(ctx, accountId) - if err != nil { - return err - } - - peers, err := c.repo.GetAccountPeers(ctx, accountId) - if err != nil { - return err - } - - dnsFwdPort := computeForwarderPort(peers, network_map.DnsForwarderPortMinVersion) - c.peersUpdateManager.SendUpdate(ctx, peerId, &network_map.UpdateMessage{ - Update: &proto.SyncResponse{ - RemotePeers: []*proto.RemotePeerConfig{}, - RemotePeersIsEmpty: true, - NetworkMap: &proto.NetworkMap{ - Serial: network.CurrentSerial(), - RemotePeers: []*proto.RemotePeerConfig{}, - RemotePeersIsEmpty: true, - FirewallRules: []*proto.FirewallRule{}, - FirewallRulesIsEmpty: true, - DNSConfig: &proto.DNSConfig{ - ForwarderPort: dnsFwdPort, - }, - }, - }, - }) - c.peersUpdateManager.CloseChannel(ctx, peerId) - return nil -} - func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { if isRequiresApproval { network, err := c.repo.GetAccountNetwork(ctx, accountID) @@ -698,35 +694,83 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t return false, nil } -func (c *Controller) OnPeerUpdated(accountId string, peer *nbpeer.Peer) { - c.UpdatePeerInNetworkMapCache(accountId, peer) - _ = c.bufferSendUpdateAccountPeers(context.Background(), accountId) +func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error { + peers, err := c.repo.GetPeersByIDs(ctx, accountID, peerIDs) + if err != nil { + return fmt.Errorf("failed to get peers by ids: %w", err) + } + + for _, peer := range peers { + c.UpdatePeerInNetworkMapCache(accountID, peer) + } + + err = c.bufferSendUpdateAccountPeers(ctx, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err) + } + + return nil } -func (c *Controller) OnPeerAdded(ctx context.Context, accountID string, peerID string) error { - if c.experimentalNetworkMap(accountID) { - account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return err - } +func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error { + for _, peerID := range peerIDs { + if c.experimentalNetworkMap(accountID) { + account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return err + } - err = c.onPeerAddedUpdNetworkMapCache(account, peerID) - if err != nil { - return err + err = c.onPeerAddedUpdNetworkMapCache(account, peerID) + if err != nil { + return err + } } } return c.bufferSendUpdateAccountPeers(ctx, accountID) } -func (c *Controller) OnPeerDeleted(ctx context.Context, accountID string, peerID string) error { - if c.experimentalNetworkMap(accountID) { - account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return err - } - err = c.onPeerDeletedUpdNetworkMapCache(account, peerID) - if err != nil { - return err +func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error { + network, err := c.repo.GetAccountNetwork(ctx, accountID) + if err != nil { + return err + } + + peers, err := c.repo.GetAccountPeers(ctx, accountID) + if err != nil { + return err + } + + dnsFwdPort := computeForwarderPort(peers, network_map.DnsForwarderPortMinVersion) + for _, peerID := range peerIDs { + c.peersUpdateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{ + Update: &proto.SyncResponse{ + RemotePeers: []*proto.RemotePeerConfig{}, + RemotePeersIsEmpty: true, + NetworkMap: &proto.NetworkMap{ + Serial: network.CurrentSerial(), + RemotePeers: []*proto.RemotePeerConfig{}, + RemotePeersIsEmpty: true, + FirewallRules: []*proto.FirewallRule{}, + FirewallRulesIsEmpty: true, + DNSConfig: &proto.DNSConfig{ + ForwarderPort: dnsFwdPort, + }, + }, + }, + }) + c.peersUpdateManager.CloseChannel(ctx, peerID) + + if c.experimentalNetworkMap(accountID) { + account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err) + continue + } + err = c.onPeerDeletedUpdNetworkMapCache(account, peerID) + if err != nil { + log.WithContext(ctx).Errorf("failed to update network map cache for deleted peer %s in account %s: %v", peerID, accountID, err) + continue + } } } @@ -778,10 +822,6 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N return networkMap, nil } -func (c *Controller) DisconnectPeers(ctx context.Context, peerIDs []string) { +func (c *Controller) DisconnectPeers(ctx context.Context, accountId string, peerIDs []string) { c.peersUpdateManager.CloseChannels(ctx, peerIDs) } - -func (c *Controller) IsConnected(peerID string) bool { - return c.peersUpdateManager.HasChannel(peerID) -} diff --git a/management/internals/controllers/network_map/controller/repository.go b/management/internals/controllers/network_map/controller/repository.go index 44144263b..3ed51a5c3 100644 --- a/management/internals/controllers/network_map/controller/repository.go +++ b/management/internals/controllers/network_map/controller/repository.go @@ -12,6 +12,8 @@ type Repository interface { GetAccountNetwork(ctx context.Context, accountID string) (*types.Network, error) GetAccountPeers(ctx context.Context, accountID string) ([]*peer.Peer, error) GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error) + GetPeersByIDs(ctx context.Context, accountID string, peerIDs []string) (map[string]*peer.Peer, error) + GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error) } type repository struct { @@ -37,3 +39,11 @@ func (r *repository) GetAccountPeers(ctx context.Context, accountID string) ([]* func (r *repository) GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error) { return r.store.GetAccountByPeerID(ctx, peerID) } + +func (r *repository) GetPeersByIDs(ctx context.Context, accountID string, peerIDs []string) (map[string]*peer.Peer, error) { + return r.store.GetPeersByIDs(ctx, store.LockingStrengthNone, accountID, peerIDs) +} + +func (r *repository) GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error) { + return r.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) +} diff --git a/management/internals/controllers/network_map/interface.go b/management/internals/controllers/network_map/interface.go index 6f893ce79..b1de7d017 100644 --- a/management/internals/controllers/network_map/interface.go +++ b/management/internals/controllers/network_map/interface.go @@ -28,12 +28,12 @@ type Controller interface { GetDNSDomain(settings *types.Settings) string StartWarmup(context.Context) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) + CountStreams() int - DeletePeer(ctx context.Context, accountId string, peerId string) error - - OnPeerUpdated(accountId string, peer *nbpeer.Peer) - OnPeerAdded(ctx context.Context, accountID string, peerID string) error - OnPeerDeleted(ctx context.Context, accountID string, peerID string) error - DisconnectPeers(ctx context.Context, peerIDs []string) - IsConnected(peerID string) bool + OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string) error + OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error + OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error + DisconnectPeers(ctx context.Context, accountId string, peerIDs []string) + OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *UpdateMessage, error) + OnPeerDisconnected(ctx context.Context, accountID string, peerID string) } diff --git a/management/internals/controllers/network_map/interface_mock.go b/management/internals/controllers/network_map/interface_mock.go index aaa093e47..5a98eefa8 100644 --- a/management/internals/controllers/network_map/interface_mock.go +++ b/management/internals/controllers/network_map/interface_mock.go @@ -57,30 +57,30 @@ func (mr *MockControllerMockRecorder) BufferUpdateAccountPeers(ctx, accountID an return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAccountPeers), ctx, accountID) } -// DeletePeer mocks base method. -func (m *MockController) DeletePeer(ctx context.Context, accountId, peerId string) error { +// CountStreams mocks base method. +func (m *MockController) CountStreams() int { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeletePeer", ctx, accountId, peerId) - ret0, _ := ret[0].(error) + ret := m.ctrl.Call(m, "CountStreams") + ret0, _ := ret[0].(int) return ret0 } -// DeletePeer indicates an expected call of DeletePeer. -func (mr *MockControllerMockRecorder) DeletePeer(ctx, accountId, peerId any) *gomock.Call { +// CountStreams indicates an expected call of CountStreams. +func (mr *MockControllerMockRecorder) CountStreams() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePeer", reflect.TypeOf((*MockController)(nil).DeletePeer), ctx, accountId, peerId) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountStreams", reflect.TypeOf((*MockController)(nil).CountStreams)) } // DisconnectPeers mocks base method. -func (m *MockController) DisconnectPeers(ctx context.Context, peerIDs []string) { +func (m *MockController) DisconnectPeers(ctx context.Context, accountId string, peerIDs []string) { m.ctrl.T.Helper() - m.ctrl.Call(m, "DisconnectPeers", ctx, peerIDs) + m.ctrl.Call(m, "DisconnectPeers", ctx, accountId, peerIDs) } // DisconnectPeers indicates an expected call of DisconnectPeers. -func (mr *MockControllerMockRecorder) DisconnectPeers(ctx, peerIDs any) *gomock.Call { +func (mr *MockControllerMockRecorder) DisconnectPeers(ctx, accountId, peerIDs any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectPeers", reflect.TypeOf((*MockController)(nil).DisconnectPeers), ctx, peerIDs) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectPeers", reflect.TypeOf((*MockController)(nil).DisconnectPeers), ctx, accountId, peerIDs) } // GetDNSDomain mocks base method. @@ -130,58 +130,73 @@ func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApp return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p) } -// IsConnected mocks base method. -func (m *MockController) IsConnected(peerID string) bool { +// OnPeerConnected mocks base method. +func (m *MockController) OnPeerConnected(ctx context.Context, accountID, peerID string) (chan *UpdateMessage, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IsConnected", peerID) - ret0, _ := ret[0].(bool) - return ret0 + ret := m.ctrl.Call(m, "OnPeerConnected", ctx, accountID, peerID) + ret0, _ := ret[0].(chan *UpdateMessage) + ret1, _ := ret[1].(error) + return ret0, ret1 } -// IsConnected indicates an expected call of IsConnected. -func (mr *MockControllerMockRecorder) IsConnected(peerID any) *gomock.Call { +// OnPeerConnected indicates an expected call of OnPeerConnected. +func (mr *MockControllerMockRecorder) OnPeerConnected(ctx, accountID, peerID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsConnected", reflect.TypeOf((*MockController)(nil).IsConnected), peerID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerConnected", reflect.TypeOf((*MockController)(nil).OnPeerConnected), ctx, accountID, peerID) } -// OnPeerAdded mocks base method. -func (m *MockController) OnPeerAdded(ctx context.Context, accountID, peerID string) error { +// OnPeerDisconnected mocks base method. +func (m *MockController) OnPeerDisconnected(ctx context.Context, accountID, peerID string) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OnPeerAdded", ctx, accountID, peerID) + m.ctrl.Call(m, "OnPeerDisconnected", ctx, accountID, peerID) +} + +// OnPeerDisconnected indicates an expected call of OnPeerDisconnected. +func (mr *MockControllerMockRecorder) OnPeerDisconnected(ctx, accountID, peerID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerDisconnected", reflect.TypeOf((*MockController)(nil).OnPeerDisconnected), ctx, accountID, peerID) +} + +// OnPeersAdded mocks base method. +func (m *MockController) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OnPeersAdded", ctx, accountID, peerIDs) ret0, _ := ret[0].(error) return ret0 } -// OnPeerAdded indicates an expected call of OnPeerAdded. -func (mr *MockControllerMockRecorder) OnPeerAdded(ctx, accountID, peerID any) *gomock.Call { +// OnPeersAdded indicates an expected call of OnPeersAdded. +func (mr *MockControllerMockRecorder) OnPeersAdded(ctx, accountID, peerIDs any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerAdded", reflect.TypeOf((*MockController)(nil).OnPeerAdded), ctx, accountID, peerID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersAdded", reflect.TypeOf((*MockController)(nil).OnPeersAdded), ctx, accountID, peerIDs) } -// OnPeerDeleted mocks base method. -func (m *MockController) OnPeerDeleted(ctx context.Context, accountID, peerID string) error { +// OnPeersDeleted mocks base method. +func (m *MockController) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OnPeerDeleted", ctx, accountID, peerID) + ret := m.ctrl.Call(m, "OnPeersDeleted", ctx, accountID, peerIDs) ret0, _ := ret[0].(error) return ret0 } -// OnPeerDeleted indicates an expected call of OnPeerDeleted. -func (mr *MockControllerMockRecorder) OnPeerDeleted(ctx, accountID, peerID any) *gomock.Call { +// OnPeersDeleted indicates an expected call of OnPeersDeleted. +func (mr *MockControllerMockRecorder) OnPeersDeleted(ctx, accountID, peerIDs any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerDeleted", reflect.TypeOf((*MockController)(nil).OnPeerDeleted), ctx, accountID, peerID) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersDeleted", reflect.TypeOf((*MockController)(nil).OnPeersDeleted), ctx, accountID, peerIDs) } -// OnPeerUpdated mocks base method. -func (m *MockController) OnPeerUpdated(accountId string, peer *peer.Peer) { +// OnPeersUpdated mocks base method. +func (m *MockController) OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string) error { m.ctrl.T.Helper() - m.ctrl.Call(m, "OnPeerUpdated", accountId, peer) + ret := m.ctrl.Call(m, "OnPeersUpdated", ctx, accountId, peerIDs) + ret0, _ := ret[0].(error) + return ret0 } -// OnPeerUpdated indicates an expected call of OnPeerUpdated. -func (mr *MockControllerMockRecorder) OnPeerUpdated(accountId, peer any) *gomock.Call { +// OnPeersUpdated indicates an expected call of OnPeersUpdated. +func (mr *MockControllerMockRecorder) OnPeersUpdated(ctx, accountId, peerIDs any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerUpdated", reflect.TypeOf((*MockController)(nil).OnPeerUpdated), accountId, peer) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersUpdated", reflect.TypeOf((*MockController)(nil).OnPeersUpdated), ctx, accountId, peerIDs) } // StartWarmup mocks base method. diff --git a/management/server/peers/ephemeral/interface.go b/management/internals/modules/peers/ephemeral/interface.go similarity index 83% rename from management/server/peers/ephemeral/interface.go rename to management/internals/modules/peers/ephemeral/interface.go index a1605b3b9..8fe25435c 100644 --- a/management/server/peers/ephemeral/interface.go +++ b/management/internals/modules/peers/ephemeral/interface.go @@ -2,10 +2,15 @@ package ephemeral import ( "context" + "time" nbpeer "github.com/netbirdio/netbird/management/server/peer" ) +const ( + EphemeralLifeTime = 10 * time.Minute +) + type Manager interface { LoadInitialPeers(ctx context.Context) Stop() diff --git a/management/server/peers/ephemeral/manager/ephemeral.go b/management/internals/modules/peers/ephemeral/manager/ephemeral.go similarity index 85% rename from management/server/peers/ephemeral/manager/ephemeral.go rename to management/internals/modules/peers/ephemeral/manager/ephemeral.go index 062ba69d2..15119045b 100644 --- a/management/server/peers/ephemeral/manager/ephemeral.go +++ b/management/internals/modules/peers/ephemeral/manager/ephemeral.go @@ -7,14 +7,15 @@ import ( log "github.com/sirupsen/logrus" - nbAccount "github.com/netbirdio/netbird/management/server/account" + "github.com/netbirdio/netbird/management/internals/modules/peers" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral" "github.com/netbirdio/netbird/management/server/activity" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/store" ) const ( - ephemeralLifeTime = 10 * time.Minute // cleanupWindow is the time window to wait after nearest peer deadline to start the cleanup procedure. cleanupWindow = 1 * time.Minute ) @@ -33,11 +34,11 @@ type ephemeralPeer struct { // todo: consider to remove peer from ephemeral list when the peer has been deleted via API. If we do not do it // in worst case we will get invalid error message in this manager. -// EphemeralManager keep a list of ephemeral peers. After ephemeralLifeTime inactivity the peer will be deleted +// EphemeralManager keep a list of ephemeral peers. After EphemeralLifeTime inactivity the peer will be deleted // automatically. Inactivity means the peer disconnected from the Management server. type EphemeralManager struct { - store store.Store - accountManager nbAccount.Manager + store store.Store + peersManager peers.Manager headPeer *ephemeralPeer tailPeer *ephemeralPeer @@ -49,12 +50,12 @@ type EphemeralManager struct { } // NewEphemeralManager instantiate new EphemeralManager -func NewEphemeralManager(store store.Store, accountManager nbAccount.Manager) *EphemeralManager { +func NewEphemeralManager(store store.Store, peersManager peers.Manager) *EphemeralManager { return &EphemeralManager{ - store: store, - accountManager: accountManager, + store: store, + peersManager: peersManager, - lifeTime: ephemeralLifeTime, + lifeTime: ephemeral.EphemeralLifeTime, cleanupWindow: cleanupWindow, } } @@ -106,7 +107,7 @@ func (e *EphemeralManager) OnPeerConnected(ctx context.Context, peer *nbpeer.Pee } // OnPeerDisconnected add the peer to the linked list of ephemeral peers. Because of the peer -// is inactive it will be deleted after the ephemeralLifeTime period. +// is inactive it will be deleted after the EphemeralLifeTime period. func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.Peer) { if !peer.Ephemeral { return @@ -180,20 +181,18 @@ func (e *EphemeralManager) cleanup(ctx context.Context) { e.peersLock.Unlock() - bufferAccountCall := make(map[string]struct{}) - + peerIDsPerAccount := make(map[string][]string) for id, p := range deletePeers { - log.WithContext(ctx).Debugf("delete ephemeral peer: %s", id) - err := e.accountManager.DeletePeer(ctx, p.accountID, id, activity.SystemInitiator) + peerIDsPerAccount[p.accountID] = append(peerIDsPerAccount[p.accountID], id) + } + + for accountID, peerIDs := range peerIDsPerAccount { + log.WithContext(ctx).Debugf("delete ephemeral peers for account: %s", accountID) + err := e.peersManager.DeletePeers(ctx, accountID, peerIDs, activity.SystemInitiator, true) if err != nil { log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err) - } else { - bufferAccountCall[p.accountID] = struct{}{} } } - for accountID := range bufferAccountCall { - e.accountManager.BufferUpdateAccountPeers(ctx, accountID) - } } func (e *EphemeralManager) addPeer(accountID string, peerID string, deadline time.Time) { diff --git a/management/server/peers/ephemeral/manager/ephemeral_test.go b/management/internals/modules/peers/ephemeral/manager/ephemeral_test.go similarity index 69% rename from management/server/peers/ephemeral/manager/ephemeral_test.go rename to management/internals/modules/peers/ephemeral/manager/ephemeral_test.go index fc7525c29..9d3ed246a 100644 --- a/management/server/peers/ephemeral/manager/ephemeral_test.go +++ b/management/internals/modules/peers/ephemeral/manager/ephemeral_test.go @@ -7,10 +7,13 @@ import ( "testing" "time" + "github.com/golang/mock/gomock" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/modules/peers" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral" nbAccount "github.com/netbirdio/netbird/management/server/account" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/store" @@ -91,17 +94,27 @@ func TestNewManager(t *testing.T) { } store := &MockStore{} - am := MockAccountManager{ - store: store, - } + ctrl := gomock.NewController(t) + peersManager := peers.NewMockManager(ctrl) numberOfPeers := 5 numberOfEphemeralPeers := 3 seedPeers(store, numberOfPeers, numberOfEphemeralPeers) - mgr := NewEphemeralManager(store, &am) + // Expect DeletePeers to be called for ephemeral peers + peersManager.EXPECT(). + DeletePeers(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true). + DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error { + for _, peerID := range peerIDs { + delete(store.account.Peers, peerID) + } + return nil + }). + AnyTimes() + + mgr := NewEphemeralManager(store, peersManager) mgr.loadEphemeralPeers(context.Background()) - startTime = startTime.Add(ephemeralLifeTime + 1) + startTime = startTime.Add(ephemeral.EphemeralLifeTime + 1) mgr.cleanup(context.Background()) if len(store.account.Peers) != numberOfPeers { @@ -119,19 +132,29 @@ func TestNewManagerPeerConnected(t *testing.T) { } store := &MockStore{} - am := MockAccountManager{ - store: store, - } + ctrl := gomock.NewController(t) + peersManager := peers.NewMockManager(ctrl) numberOfPeers := 5 numberOfEphemeralPeers := 3 seedPeers(store, numberOfPeers, numberOfEphemeralPeers) - mgr := NewEphemeralManager(store, &am) + // Expect DeletePeers to be called for ephemeral peers (except the connected one) + peersManager.EXPECT(). + DeletePeers(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true). + DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error { + for _, peerID := range peerIDs { + delete(store.account.Peers, peerID) + } + return nil + }). + AnyTimes() + + mgr := NewEphemeralManager(store, peersManager) mgr.loadEphemeralPeers(context.Background()) mgr.OnPeerConnected(context.Background(), store.account.Peers["ephemeral_peer_0"]) - startTime = startTime.Add(ephemeralLifeTime + 1) + startTime = startTime.Add(ephemeral.EphemeralLifeTime + 1) mgr.cleanup(context.Background()) expected := numberOfPeers + 1 @@ -150,15 +173,25 @@ func TestNewManagerPeerDisconnected(t *testing.T) { } store := &MockStore{} - am := MockAccountManager{ - store: store, - } + ctrl := gomock.NewController(t) + peersManager := peers.NewMockManager(ctrl) numberOfPeers := 5 numberOfEphemeralPeers := 3 seedPeers(store, numberOfPeers, numberOfEphemeralPeers) - mgr := NewEphemeralManager(store, &am) + // Expect DeletePeers to be called for the one disconnected peer + peersManager.EXPECT(). + DeletePeers(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true). + DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error { + for _, peerID := range peerIDs { + delete(store.account.Peers, peerID) + } + return nil + }). + AnyTimes() + + mgr := NewEphemeralManager(store, peersManager) mgr.loadEphemeralPeers(context.Background()) for _, v := range store.account.Peers { mgr.OnPeerConnected(context.Background(), v) @@ -166,7 +199,7 @@ func TestNewManagerPeerDisconnected(t *testing.T) { } mgr.OnPeerDisconnected(context.Background(), store.account.Peers["ephemeral_peer_0"]) - startTime = startTime.Add(ephemeralLifeTime + 1) + startTime = startTime.Add(ephemeral.EphemeralLifeTime + 1) mgr.cleanup(context.Background()) expected := numberOfPeers + numberOfEphemeralPeers - 1 @@ -181,25 +214,63 @@ func TestCleanupSchedulingBehaviorIsBatched(t *testing.T) { testLifeTime = 1 * time.Second testCleanupWindow = 100 * time.Millisecond ) + + t.Cleanup(func() { + timeNow = time.Now + }) + startTime := time.Now() + timeNow = func() time.Time { + return startTime + } + mockStore := &MockStore{} + account := newAccountWithId(context.Background(), "account", "", "", false) + mockStore.account = account + + wg := &sync.WaitGroup{} + wg.Add(ephemeralPeers) mockAM := &MockAccountManager{ store: mockStore, + wg: wg, } - mockAM.wg = &sync.WaitGroup{} - mockAM.wg.Add(ephemeralPeers) - mgr := NewEphemeralManager(mockStore, mockAM) + + ctrl := gomock.NewController(t) + peersManager := peers.NewMockManager(ctrl) + + // Set up expectation that DeletePeers will be called once with all peer IDs + peersManager.EXPECT(). + DeletePeers(gomock.Any(), account.Id, gomock.Any(), gomock.Any(), true). + DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error { + // Simulate the actual deletion behavior + for _, peerID := range peerIDs { + err := mockAM.DeletePeer(ctx, accountID, peerID, userID) + if err != nil { + return err + } + } + mockAM.BufferUpdateAccountPeers(ctx, accountID) + return nil + }). + Times(1) + + mgr := NewEphemeralManager(mockStore, peersManager) mgr.lifeTime = testLifeTime mgr.cleanupWindow = testCleanupWindow - account := newAccountWithId(context.Background(), "account", "", "", false) - mockStore.account = account + // Add peers and disconnect them at slightly different times (within cleanup window) for i := range ephemeralPeers { p := &nbpeer.Peer{ID: fmt.Sprintf("peer-%d", i), AccountID: account.Id, Ephemeral: true} mockStore.account.Peers[p.ID] = p - time.Sleep(testCleanupWindow / ephemeralPeers) mgr.OnPeerDisconnected(context.Background(), p) + startTime = startTime.Add(testCleanupWindow / (ephemeralPeers * 2)) } - mockAM.wg.Wait() + + // Advance time past the lifetime to trigger cleanup + startTime = startTime.Add(testLifeTime + testCleanupWindow) + + // Wait for all deletions to complete + wg.Wait() + assert.Len(t, mockStore.account.Peers, 0, "all ephemeral peers should be cleaned up after the lifetime") assert.Equal(t, 1, mockAM.GetBufferUpdateCalls(account.Id), "buffer update should be called once") assert.Equal(t, ephemeralPeers, mockAM.GetDeletePeerCalls(), "should have deleted all peers") diff --git a/management/internals/modules/peers/manager.go b/management/internals/modules/peers/manager.go new file mode 100644 index 000000000..e82f19e63 --- /dev/null +++ b/management/internals/modules/peers/manager.go @@ -0,0 +1,162 @@ +package peers + +//go:generate go run github.com/golang/mock/mockgen -package peers -destination=manager_mock.go -source=./manager.go -build_flags=-mod=mod + +import ( + "context" + "fmt" + "time" + + "github.com/netbirdio/netbird/management/internals/controllers/network_map" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral" + "github.com/netbirdio/netbird/management/server/account" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" + "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/shared/management/status" +) + +type Manager interface { + GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error) + GetPeerAccountID(ctx context.Context, peerID string) (string, error) + GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) + GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) + DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error + SetNetworkMapController(networkMapController network_map.Controller) + SetIntegratedPeerValidator(integratedPeerValidator integrated_validator.IntegratedValidator) + SetAccountManager(accountManager account.Manager) +} + +type managerImpl struct { + store store.Store + permissionsManager permissions.Manager + integratedPeerValidator integrated_validator.IntegratedValidator + accountManager account.Manager + + networkMapController network_map.Controller +} + +func NewManager(store store.Store, permissionsManager permissions.Manager) Manager { + return &managerImpl{ + store: store, + permissionsManager: permissionsManager, + } +} + +func (m *managerImpl) SetNetworkMapController(networkMapController network_map.Controller) { + m.networkMapController = networkMapController +} + +func (m *managerImpl) SetIntegratedPeerValidator(integratedPeerValidator integrated_validator.IntegratedValidator) { + m.integratedPeerValidator = integratedPeerValidator +} + +func (m *managerImpl) SetAccountManager(accountManager account.Manager) { + m.accountManager = accountManager +} + +func (m *managerImpl) GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error) { + allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read) + if err != nil { + return nil, fmt.Errorf("failed to validate user permissions: %w", err) + } + + if !allowed { + return nil, status.NewPermissionDeniedError() + } + + return m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) +} + +func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) { + allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read) + if err != nil { + return nil, fmt.Errorf("failed to validate user permissions: %w", err) + } + + if !allowed { + return m.store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID) + } + + return m.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") +} + +func (m *managerImpl) GetPeerAccountID(ctx context.Context, peerID string) (string, error) { + return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID) +} + +func (m *managerImpl) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) { + return m.store.GetPeersByGroupIDs(ctx, accountID, groupsIDs) +} + +func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error { + settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) + if err != nil { + return err + } + dnsDomain := m.networkMapController.GetDNSDomain(settings) + + for _, peerID := range peerIDs { + var eventsToStore []func() + err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + peer, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) + if err != nil { + return err + } + + if checkConnected && (peer.Status.Connected || peer.Status.LastSeen.After(time.Now().Add(-(ephemeral.EphemeralLifeTime - 10*time.Second)))) { + return nil + } + + if err := transaction.RemovePeerFromAllGroups(ctx, peerID); err != nil { + return fmt.Errorf("failed to remove peer %s from groups", peerID) + } + + if err := m.integratedPeerValidator.PeerDeleted(ctx, accountID, peerID, settings.Extra); err != nil { + return err + } + + peerPolicyRules, err := transaction.GetPolicyRulesByResourceID(ctx, store.LockingStrengthNone, accountID, peerID) + if err != nil { + return err + } + for _, rule := range peerPolicyRules { + policy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, rule.PolicyID) + if err != nil { + return err + } + + err = transaction.DeletePolicy(ctx, accountID, rule.PolicyID) + if err != nil { + return err + } + + eventsToStore = append(eventsToStore, func() { + m.accountManager.StoreEvent(ctx, userID, peer.ID, accountID, activity.PolicyRemoved, policy.EventMeta()) + }) + } + + if err = transaction.DeletePeer(ctx, accountID, peerID); err != nil { + return err + } + + eventsToStore = append(eventsToStore, func() { + m.accountManager.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain)) + }) + + return nil + }) + if err != nil { + return err + } + for _, event := range eventsToStore { + event() + } + } + + return nil +} diff --git a/management/server/peers/manager_mock.go b/management/internals/modules/peers/manager_mock.go similarity index 55% rename from management/server/peers/manager_mock.go rename to management/internals/modules/peers/manager_mock.go index 994f8346b..2e3651e88 100644 --- a/management/server/peers/manager_mock.go +++ b/management/internals/modules/peers/manager_mock.go @@ -9,6 +9,9 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" + network_map "github.com/netbirdio/netbird/management/internals/controllers/network_map" + account "github.com/netbirdio/netbird/management/server/account" + integrated_validator "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" peer "github.com/netbirdio/netbird/management/server/peer" ) @@ -35,6 +38,20 @@ func (m *MockManager) EXPECT() *MockManagerMockRecorder { return m.recorder } +// DeletePeers mocks base method. +func (m *MockManager) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeletePeers", ctx, accountID, peerIDs, userID, checkConnected) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeletePeers indicates an expected call of DeletePeers. +func (mr *MockManagerMockRecorder) DeletePeers(ctx, accountID, peerIDs, userID, checkConnected interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePeers", reflect.TypeOf((*MockManager)(nil).DeletePeers), ctx, accountID, peerIDs, userID, checkConnected) +} + // GetAllPeers mocks base method. func (m *MockManager) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) { m.ctrl.T.Helper() @@ -94,3 +111,39 @@ func (mr *MockManagerMockRecorder) GetPeersByGroupIDs(ctx, accountID, groupsIDs mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeersByGroupIDs", reflect.TypeOf((*MockManager)(nil).GetPeersByGroupIDs), ctx, accountID, groupsIDs) } + +// SetAccountManager mocks base method. +func (m *MockManager) SetAccountManager(accountManager account.Manager) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetAccountManager", accountManager) +} + +// SetAccountManager indicates an expected call of SetAccountManager. +func (mr *MockManagerMockRecorder) SetAccountManager(accountManager interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetAccountManager", reflect.TypeOf((*MockManager)(nil).SetAccountManager), accountManager) +} + +// SetIntegratedPeerValidator mocks base method. +func (m *MockManager) SetIntegratedPeerValidator(integratedPeerValidator integrated_validator.IntegratedValidator) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetIntegratedPeerValidator", integratedPeerValidator) +} + +// SetIntegratedPeerValidator indicates an expected call of SetIntegratedPeerValidator. +func (mr *MockManagerMockRecorder) SetIntegratedPeerValidator(integratedPeerValidator interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetIntegratedPeerValidator", reflect.TypeOf((*MockManager)(nil).SetIntegratedPeerValidator), integratedPeerValidator) +} + +// SetNetworkMapController mocks base method. +func (m *MockManager) SetNetworkMapController(networkMapController network_map.Controller) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetNetworkMapController", networkMapController) +} + +// SetNetworkMapController indicates an expected call of SetNetworkMapController. +func (mr *MockManagerMockRecorder) SetNetworkMapController(networkMapController interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetworkMapController", reflect.TypeOf((*MockManager)(nil).SetNetworkMapController), networkMapController) +} diff --git a/management/internals/server/boot.go b/management/internals/server/boot.go index eadd16c2d..37788e80e 100644 --- a/management/internals/server/boot.go +++ b/management/internals/server/boot.go @@ -57,7 +57,7 @@ func (s *BaseServer) Metrics() telemetry.AppMetrics { func (s *BaseServer) Store() store.Store { return Create(s, func() store.Store { - store, err := store.NewStore(context.Background(), s.config.StoreConfig.Engine, s.config.Datadir, s.Metrics(), false) + store, err := store.NewStore(context.Background(), s.Config.StoreConfig.Engine, s.Config.Datadir, s.Metrics(), false) if err != nil { log.Fatalf("failed to create store: %v", err) } @@ -73,17 +73,17 @@ func (s *BaseServer) EventStore() activity.Store { log.Fatalf("failed to initialize integration metrics: %v", err) } - eventStore, key, err := integrations.InitEventStore(context.Background(), s.config.Datadir, s.config.DataStoreEncryptionKey, integrationMetrics) + eventStore, key, err := integrations.InitEventStore(context.Background(), s.Config.Datadir, s.Config.DataStoreEncryptionKey, integrationMetrics) if err != nil { log.Fatalf("failed to initialize event store: %v", err) } - if s.config.DataStoreEncryptionKey != key { - log.WithContext(context.Background()).Infof("update config with activity store key") - s.config.DataStoreEncryptionKey = key - err := updateMgmtConfig(context.Background(), nbconfig.MgmtConfigPath, s.config) + if s.Config.DataStoreEncryptionKey != key { + log.WithContext(context.Background()).Infof("update Config with activity store key") + s.Config.DataStoreEncryptionKey = key + err := updateMgmtConfig(context.Background(), nbconfig.MgmtConfigPath, s.Config) if err != nil { - log.Fatalf("failed to update config with activity store: %v", err) + log.Fatalf("failed to update Config with activity store: %v", err) } } @@ -103,14 +103,14 @@ func (s *BaseServer) APIHandler() http.Handler { func (s *BaseServer) GRPCServer() *grpc.Server { return Create(s, func() *grpc.Server { - trustedPeers := s.config.ReverseProxy.TrustedPeers + trustedPeers := s.Config.ReverseProxy.TrustedPeers defaultTrustedPeers := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0"), netip.MustParsePrefix("::/0")} if len(trustedPeers) == 0 || slices.Equal[[]netip.Prefix](trustedPeers, defaultTrustedPeers) { log.WithContext(context.Background()).Warn("TrustedPeers are configured to default value '0.0.0.0/0', '::/0'. This allows connection IP spoofing.") trustedPeers = defaultTrustedPeers } - trustedHTTPProxies := s.config.ReverseProxy.TrustedHTTPProxies - trustedProxiesCount := s.config.ReverseProxy.TrustedHTTPProxiesCount + trustedHTTPProxies := s.Config.ReverseProxy.TrustedHTTPProxies + trustedProxiesCount := s.Config.ReverseProxy.TrustedHTTPProxiesCount if len(trustedHTTPProxies) > 0 && trustedProxiesCount > 0 { log.WithContext(context.Background()).Warn("TrustedHTTPProxies and TrustedHTTPProxiesCount both are configured. " + "This is not recommended way to extract X-Forwarded-For. Consider using one of these options.") @@ -128,15 +128,15 @@ func (s *BaseServer) GRPCServer() *grpc.Server { grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor), } - if s.config.HttpConfig.LetsEncryptDomain != "" { - certManager, err := encryption.CreateCertManager(s.config.Datadir, s.config.HttpConfig.LetsEncryptDomain) + if s.Config.HttpConfig.LetsEncryptDomain != "" { + certManager, err := encryption.CreateCertManager(s.Config.Datadir, s.Config.HttpConfig.LetsEncryptDomain) if err != nil { log.Fatalf("failed to create certificate manager: %v", err) } transportCredentials := credentials.NewTLS(certManager.TLSConfig()) gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials)) - } else if s.config.HttpConfig.CertFile != "" && s.config.HttpConfig.CertKey != "" { - tlsConfig, err := loadTLSConfig(s.config.HttpConfig.CertFile, s.config.HttpConfig.CertKey) + } else if s.Config.HttpConfig.CertFile != "" && s.Config.HttpConfig.CertKey != "" { + tlsConfig, err := loadTLSConfig(s.Config.HttpConfig.CertFile, s.Config.HttpConfig.CertKey) if err != nil { log.Fatalf("cannot load TLS credentials: %v", err) } @@ -145,7 +145,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server { } gRPCAPIHandler := grpc.NewServer(gRPCOpts...) - srv, err := nbgrpc.NewServer(s.config, s.AccountManager(), s.SettingsManager(), s.PeersUpdateManager(), s.SecretsManager(), s.Metrics(), s.EphemeralManager(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController()) + srv, err := nbgrpc.NewServer(s.Config, s.AccountManager(), s.SettingsManager(), s.SecretsManager(), s.Metrics(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController()) if err != nil { log.Fatalf("failed to create management server: %v", err) } diff --git a/management/internals/server/controllers.go b/management/internals/server/controllers.go index 38ec6fde6..3442c7646 100644 --- a/management/internals/server/controllers.go +++ b/management/internals/server/controllers.go @@ -9,17 +9,17 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map" nmapcontroller "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/auth" "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" - "github.com/netbirdio/netbird/management/server/peers/ephemeral" - "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" ) func (s *BaseServer) PeersUpdateManager() network_map.PeersUpdateManager { - return Create(s, func() *update_channel.PeersUpdateManager { + return Create(s, func() network_map.PeersUpdateManager { return update_channel.NewPeersUpdateManager(s.Metrics()) }) } @@ -44,33 +44,37 @@ func (s *BaseServer) ProxyController() port_forwarding.Controller { }) } -func (s *BaseServer) SecretsManager() *grpc.TimeBasedAuthSecretsManager { - return Create(s, func() *grpc.TimeBasedAuthSecretsManager { - return grpc.NewTimeBasedAuthSecretsManager(s.PeersUpdateManager(), s.config.TURNConfig, s.config.Relay, s.SettingsManager(), s.GroupsManager()) +func (s *BaseServer) SecretsManager() grpc.SecretsManager { + return Create(s, func() grpc.SecretsManager { + secretsManager, err := grpc.NewTimeBasedAuthSecretsManager(s.PeersUpdateManager(), s.Config.TURNConfig, s.Config.Relay, s.SettingsManager(), s.GroupsManager()) + if err != nil { + log.Fatalf("failed to create secrets manager: %v", err) + } + return secretsManager }) } func (s *BaseServer) AuthManager() auth.Manager { return Create(s, func() auth.Manager { return auth.NewManager(s.Store(), - s.config.HttpConfig.AuthIssuer, - s.config.HttpConfig.AuthAudience, - s.config.HttpConfig.AuthKeysLocation, - s.config.HttpConfig.AuthUserIDClaim, - s.config.GetAuthAudiences(), - s.config.HttpConfig.IdpSignKeyRefreshEnabled) + s.Config.HttpConfig.AuthIssuer, + s.Config.HttpConfig.AuthAudience, + s.Config.HttpConfig.AuthKeysLocation, + s.Config.HttpConfig.AuthUserIDClaim, + s.Config.GetAuthAudiences(), + s.Config.HttpConfig.IdpSignKeyRefreshEnabled) }) } func (s *BaseServer) EphemeralManager() ephemeral.Manager { return Create(s, func() ephemeral.Manager { - return manager.NewEphemeralManager(s.Store(), s.AccountManager()) + return manager.NewEphemeralManager(s.Store(), s.PeersManager()) }) } func (s *BaseServer) NetworkMapController() network_map.Controller { - return Create(s, func() *nmapcontroller.Controller { - return nmapcontroller.NewController(context.Background(), s.Store(), s.Metrics(), s.PeersUpdateManager(), s.AccountRequestBuffer(), s.IntegratedValidator(), s.SettingsManager(), s.dnsDomain, s.ProxyController(), s.config) + return Create(s, func() network_map.Controller { + return nmapcontroller.NewController(context.Background(), s.Store(), s.Metrics(), s.PeersUpdateManager(), s.AccountRequestBuffer(), s.IntegratedValidator(), s.SettingsManager(), s.DNSDomain(), s.ProxyController(), s.EphemeralManager(), s.Config) }) } @@ -79,3 +83,7 @@ func (s *BaseServer) AccountRequestBuffer() *server.AccountRequestBuffer { return server.NewAccountRequestBuffer(context.Background(), s.Store()) }) } + +func (s *BaseServer) DNSDomain() string { + return s.dnsDomain +} diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index 18a8427be..91ce50a79 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -6,6 +6,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/modules/peers" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/geolocation" @@ -14,7 +15,7 @@ import ( "github.com/netbirdio/netbird/management/server/networks" "github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/routers" - "github.com/netbirdio/netbird/management/server/peers" + "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/users" @@ -22,12 +23,12 @@ import ( func (s *BaseServer) GeoLocationManager() geolocation.Geolocation { return Create(s, func() geolocation.Geolocation { - geo, err := geolocation.NewGeolocation(context.Background(), s.config.Datadir, !s.disableGeoliteUpdate) + geo, err := geolocation.NewGeolocation(context.Background(), s.Config.Datadir, !s.disableGeoliteUpdate) if err != nil { log.Fatalf("could not initialize geolocation service: %v", err) } - log.Infof("geolocation service has been initialized from %s", s.config.Datadir) + log.Infof("geolocation service has been initialized from %s", s.Config.Datadir) return geo }) @@ -60,20 +61,22 @@ func (s *BaseServer) SettingsManager() settings.Manager { func (s *BaseServer) PeersManager() peers.Manager { return Create(s, func() peers.Manager { - return peers.NewManager(s.Store(), s.PermissionsManager()) + manager := peers.NewManager(s.Store(), s.PermissionsManager()) + s.AfterInit(func(s *BaseServer) { + manager.SetNetworkMapController(s.NetworkMapController()) + manager.SetIntegratedPeerValidator(s.IntegratedValidator()) + manager.SetAccountManager(s.AccountManager()) + }) + return manager }) } func (s *BaseServer) AccountManager() account.Manager { return Create(s, func() account.Manager { - accountManager, err := server.BuildManager(context.Background(), s.config, s.Store(), s.NetworkMapController(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.config.DisableDefaultPolicy) + accountManager, err := server.BuildManager(context.Background(), s.Config, s.Store(), s.NetworkMapController(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.Config.DisableDefaultPolicy) if err != nil { log.Fatalf("failed to create account manager: %v", err) } - - s.AfterInit(func(s *BaseServer) { - accountManager.SetEphemeralManager(s.EphemeralManager()) - }) return accountManager }) } @@ -82,8 +85,8 @@ func (s *BaseServer) IdpManager() idp.Manager { return Create(s, func() idp.Manager { var idpManager idp.Manager var err error - if s.config.IdpManagerConfig != nil { - idpManager, err = idp.NewManager(context.Background(), *s.config.IdpManagerConfig, s.Metrics()) + if s.Config.IdpManagerConfig != nil { + idpManager, err = idp.NewManager(context.Background(), *s.Config.IdpManagerConfig, s.Metrics()) if err != nil { log.Fatalf("failed to create IDP manager: %v", err) } diff --git a/management/internals/server/server.go b/management/internals/server/server.go index ab1c2ebe7..a1b144dac 100644 --- a/management/internals/server/server.go +++ b/management/internals/server/server.go @@ -41,10 +41,10 @@ type Server interface { } // Server holds the HTTP BaseServer instance. -// Add any additional fields you need, such as database connections, config, etc. +// Add any additional fields you need, such as database connections, Config, etc. type BaseServer struct { - // config holds the server configuration - config *nbconfig.Config + // Config holds the server configuration + Config *nbconfig.Config // container of dependencies, each dependency is identified by a unique string. container map[string]any // AfterInit is a function that will be called after the server is initialized @@ -70,7 +70,7 @@ type BaseServer struct { // NewServer initializes and configures a new Server instance func NewServer(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool) *BaseServer { return &BaseServer{ - config: config, + Config: config, container: make(map[string]any), dnsDomain: dnsDomain, mgmtSingleAccModeDomain: mgmtSingleAccModeDomain, @@ -103,14 +103,14 @@ func (s *BaseServer) Start(ctx context.Context) error { var tlsConfig *tls.Config tlsEnabled := false - if s.config.HttpConfig.LetsEncryptDomain != "" { - s.certManager, err = encryption.CreateCertManager(s.config.Datadir, s.config.HttpConfig.LetsEncryptDomain) + if s.Config.HttpConfig.LetsEncryptDomain != "" { + s.certManager, err = encryption.CreateCertManager(s.Config.Datadir, s.Config.HttpConfig.LetsEncryptDomain) if err != nil { return fmt.Errorf("failed creating LetsEncrypt cert manager: %v", err) } tlsEnabled = true - } else if s.config.HttpConfig.CertFile != "" && s.config.HttpConfig.CertKey != "" { - tlsConfig, err = loadTLSConfig(s.config.HttpConfig.CertFile, s.config.HttpConfig.CertKey) + } else if s.Config.HttpConfig.CertFile != "" && s.Config.HttpConfig.CertKey != "" { + tlsConfig, err = loadTLSConfig(s.Config.HttpConfig.CertFile, s.Config.HttpConfig.CertKey) if err != nil { log.WithContext(srvCtx).Errorf("cannot load TLS credentials: %v", err) return err @@ -126,8 +126,8 @@ func (s *BaseServer) Start(ctx context.Context) error { if !s.disableMetrics { idpManager := "disabled" - if s.config.IdpManagerConfig != nil && s.config.IdpManagerConfig.ManagerType != "" { - idpManager = s.config.IdpManagerConfig.ManagerType + if s.Config.IdpManagerConfig != nil && s.Config.IdpManagerConfig.ManagerType != "" { + idpManager = s.Config.IdpManagerConfig.ManagerType } metricsWorker := metrics.NewWorker(srvCtx, installationID, s.Store(), s.PeersUpdateManager(), idpManager) go metricsWorker.Run(srvCtx) diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index 0aadadf84..62dc215d8 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -24,7 +24,6 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map" nbconfig "github.com/netbirdio/netbird/management/internals/server/config" - "github.com/netbirdio/netbird/management/server/peers/ephemeral" "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" "github.com/netbirdio/netbird/management/server/store" @@ -55,15 +54,12 @@ const ( type Server struct { accountManager account.Manager settingsManager settings.Manager - wgKey wgtypes.Key proto.UnimplementedManagementServiceServer - peersUpdateManager network_map.PeersUpdateManager - config *nbconfig.Config - secretsManager SecretsManager - appMetrics telemetry.AppMetrics - ephemeralManager ephemeral.Manager - peerLocks sync.Map - authManager auth.Manager + config *nbconfig.Config + secretsManager SecretsManager + appMetrics telemetry.AppMetrics + peerLocks sync.Map + authManager auth.Manager logBlockedPeers bool blockPeersWithSameConfig bool @@ -82,23 +78,16 @@ func NewServer( config *nbconfig.Config, accountManager account.Manager, settingsManager settings.Manager, - peersUpdateManager network_map.PeersUpdateManager, secretsManager SecretsManager, appMetrics telemetry.AppMetrics, - ephemeralManager ephemeral.Manager, authManager auth.Manager, integratedPeerValidator integrated_validator.IntegratedValidator, networkMapController network_map.Controller, ) (*Server, error) { - key, err := wgtypes.GeneratePrivateKey() - if err != nil { - return nil, err - } - if appMetrics != nil { // update gauge based on number of connected peers which is equal to open gRPC streams - err = appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 { - return int64(peersUpdateManager.CountStreams()) + err := appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 { + return int64(networkMapController.CountStreams()) }) if err != nil { return nil, err @@ -120,16 +109,12 @@ func NewServer( } return &Server{ - wgKey: key, - // peerKey -> event channel - peersUpdateManager: peersUpdateManager, accountManager: accountManager, settingsManager: settingsManager, config: config, secretsManager: secretsManager, authManager: authManager, appMetrics: appMetrics, - ephemeralManager: ephemeralManager, logBlockedPeers: logBlockedPeers, blockPeersWithSameConfig: blockPeersWithSameConfig, integratedPeerValidator: integratedPeerValidator, @@ -163,8 +148,14 @@ func (s *Server) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.Ser nanos := int32(now.Nanosecond()) expiresAt := ×tamp.Timestamp{Seconds: secs, Nanos: nanos} + key, err := s.secretsManager.GetWGKey() + if err != nil { + log.WithContext(ctx).Errorf("failed to get wireguard key: %v", err) + return nil, errors.New("failed to get wireguard key") + } + return &proto.ServerKeyResponse{ - Key: s.wgKey.PublicKey().String(), + Key: key.PublicKey().String(), ExpiresAt: expiresAt, }, nil } @@ -269,9 +260,13 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S return err } - updates := s.peersUpdateManager.CreateChannel(ctx, peer.ID) - - s.ephemeralManager.OnPeerConnected(ctx, peer) + updates, err := s.networkMapController.OnPeerConnected(ctx, accountID, peer.ID) + if err != nil { + log.WithContext(ctx).Debugf("error while notify peer connected for %s: %v", peerKey.String(), err) + s.syncSem.Add(-1) + s.cancelPeerRoutines(ctx, accountID, peer) + return err + } s.secretsManager.SetupRefresh(ctx, accountID, peer.ID) @@ -323,13 +318,19 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg // sendUpdate encrypts the update message using the peer key and the server's wireguard key, // then sends the encrypted message to the connected peer via the sync server. func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error { - encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update) + key, err := s.secretsManager.GetWGKey() + if err != nil { + s.cancelPeerRoutines(ctx, accountID, peer) + return status.Errorf(codes.Internal, "failed processing update message") + } + + encryptedResp, err := encryption.EncryptMessage(peerKey, key, update.Update) if err != nil { s.cancelPeerRoutines(ctx, accountID, peer) return status.Errorf(codes.Internal, "failed processing update message") } err = srv.SendMsg(&proto.EncryptedMessage{ - WgPubKey: s.wgKey.PublicKey().String(), + WgPubKey: key.PublicKey().String(), Body: encryptedResp, }) if err != nil { @@ -348,9 +349,8 @@ func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer if err != nil { log.WithContext(ctx).Errorf("failed to disconnect peer %s properly: %v", peer.Key, err) } - s.peersUpdateManager.CloseChannel(ctx, peer.ID) + s.networkMapController.OnPeerDisconnected(ctx, accountID, peer.ID) s.secretsManager.CancelRefresh(peer.ID) - s.ephemeralManager.OnPeerDisconnected(ctx, peer) log.WithContext(ctx).Tracef("peer %s has been disconnected", peer.Key) } @@ -504,7 +504,12 @@ func (s *Server) parseRequest(ctx context.Context, req *proto.EncryptedMessage, return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", req.WgPubKey) } - err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, parsed) + key, err := s.secretsManager.GetWGKey() + if err != nil { + return wgtypes.Key{}, status.Errorf(codes.Internal, "failed processing request") + } + + err = encryption.DecryptMessage(peerKey, key, req.Body, parsed) if err != nil { return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "invalid request message") } @@ -601,12 +606,6 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto log.WithContext(ctx).Debugf("Login: LoginPeer since start %v", time.Since(reqStart)) - // if the login request contains setup key then it is a registration request - if loginReq.GetSetupKey() != "" { - s.ephemeralManager.OnPeerDisconnected(ctx, peer) - log.WithContext(ctx).Debugf("Login: OnPeerDisconnected since start %v", time.Since(reqStart)) - } - loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks) if err != nil { log.WithContext(ctx).Warnf("failed preparing login response for peer %s: %s", peerKey, err) @@ -615,14 +614,20 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto log.WithContext(ctx).Debugf("Login: prepareLoginResponse since start %v", time.Since(reqStart)) - encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp) + key, err := s.secretsManager.GetWGKey() + if err != nil { + log.WithContext(ctx).Warnf("failed getting server's WireGuard private key: %s", err) + return nil, status.Errorf(codes.Internal, "failed logging in peer") + } + + encryptedResp, err := encryption.EncryptMessage(peerKey, key, loginResp) if err != nil { log.WithContext(ctx).Warnf("failed encrypting peer %s message", peer.ID) return nil, status.Errorf(codes.Internal, "failed logging in peer") } return &proto.EncryptedMessage{ - WgPubKey: s.wgKey.PublicKey().String(), + WgPubKey: key.PublicKey().String(), Body: encryptedResp, }, nil } @@ -715,14 +720,19 @@ func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer plainResp := ToSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, networkMap, s.networkMapController.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort) - encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp) + key, err := s.secretsManager.GetWGKey() + if err != nil { + return status.Errorf(codes.Internal, "failed getting server key") + } + + encryptedResp, err := encryption.EncryptMessage(peerKey, key, plainResp) if err != nil { return status.Errorf(codes.Internal, "error handling request") } sendStart := time.Now() err = srv.Send(&proto.EncryptedMessage{ - WgPubKey: s.wgKey.PublicKey().String(), + WgPubKey: key.PublicKey().String(), Body: encryptedResp, }) log.WithContext(ctx).Debugf("sendInitialSync: sending response took %s", time.Since(sendStart)) @@ -752,7 +762,12 @@ func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.Encr return nil, status.Error(codes.InvalidArgument, errMSG) } - err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.DeviceAuthorizationFlowRequest{}) + key, err := s.secretsManager.GetWGKey() + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get server key") + } + + err = encryption.DecryptMessage(peerKey, key, req.Body, &proto.DeviceAuthorizationFlowRequest{}) if err != nil { errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey) log.WithContext(ctx).Warn(errMSG) @@ -782,13 +797,13 @@ func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.Encr }, } - encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, flowInfoResp) + encryptedResp, err := encryption.EncryptMessage(peerKey, key, flowInfoResp) if err != nil { return nil, status.Error(codes.Internal, "failed to encrypt no device authorization flow information") } return &proto.EncryptedMessage{ - WgPubKey: s.wgKey.PublicKey().String(), + WgPubKey: key.PublicKey().String(), Body: encryptedResp, }, nil } @@ -810,7 +825,12 @@ func (s *Server) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.Encryp return nil, status.Error(codes.InvalidArgument, errMSG) } - err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.PKCEAuthorizationFlowRequest{}) + key, err := s.secretsManager.GetWGKey() + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get server key") + } + + err = encryption.DecryptMessage(peerKey, key, req.Body, &proto.PKCEAuthorizationFlowRequest{}) if err != nil { errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey) log.WithContext(ctx).Warn(errMSG) @@ -838,13 +858,13 @@ func (s *Server) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.Encryp flowInfoResp := s.integratedPeerValidator.ValidateFlowResponse(ctx, peerKey.String(), initInfoFlow) - encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, flowInfoResp) + encryptedResp, err := encryption.EncryptMessage(peerKey, key, flowInfoResp) if err != nil { return nil, status.Error(codes.Internal, "failed to encrypt no pkce authorization flow information") } return &proto.EncryptedMessage{ - WgPubKey: s.wgKey.PublicKey().String(), + WgPubKey: key.PublicKey().String(), Body: encryptedResp, }, nil } diff --git a/management/internals/shared/grpc/server_test.go b/management/internals/shared/grpc/server_test.go index 9867b38e3..d3a12e986 100644 --- a/management/internals/shared/grpc/server_test.go +++ b/management/internals/shared/grpc/server_test.go @@ -73,15 +73,17 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { mgmtServer := &Server{ - wgKey: testingServerKey, + secretsManager: &TimeBasedAuthSecretsManager{wgKey: testingServerKey}, config: &config.Config{ DeviceAuthorizationFlow: testCase.inputFlow, }, } message := &mgmtProto.DeviceAuthorizationFlowRequest{} + key, err := mgmtServer.secretsManager.GetWGKey() + require.NoError(t, err, "should be able to get server key") - encryptedMSG, err := encryption.EncryptMessage(testingClientKey.PublicKey(), mgmtServer.wgKey, message) + encryptedMSG, err := encryption.EncryptMessage(testingClientKey.PublicKey(), key, message) require.NoError(t, err, "should be able to encrypt message") resp, err := mgmtServer.GetDeviceAuthorizationFlow( @@ -95,7 +97,7 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) { if testCase.expectedComparisonFunc != nil { flowInfoResp := &mgmtProto.DeviceAuthorizationFlow{} - err = encryption.DecryptMessage(mgmtServer.wgKey.PublicKey(), testingClientKey, resp.Body, flowInfoResp) + err = encryption.DecryptMessage(key.PublicKey(), testingClientKey, resp.Body, flowInfoResp) require.NoError(t, err, "should be able to decrypt") testCase.expectedComparisonFunc(t, testCase.expectedFlow.Provider, flowInfoResp.Provider, testCase.expectedComparisonMSG) diff --git a/management/internals/shared/grpc/token_mgr.go b/management/internals/shared/grpc/token_mgr.go index e9770db41..0f893ae3a 100644 --- a/management/internals/shared/grpc/token_mgr.go +++ b/management/internals/shared/grpc/token_mgr.go @@ -10,6 +10,7 @@ import ( "time" log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" "github.com/netbirdio/netbird/management/internals/controllers/network_map" @@ -29,6 +30,7 @@ type SecretsManager interface { GenerateRelayToken() (*Token, error) SetupRefresh(ctx context.Context, accountID, peerKey string) CancelRefresh(peerKey string) + GetWGKey() (wgtypes.Key, error) } // TimeBasedAuthSecretsManager generates credentials with TTL and using pre-shared secret known to TURN server @@ -43,11 +45,17 @@ type TimeBasedAuthSecretsManager struct { groupsManager groups.Manager turnCancelMap map[string]chan struct{} relayCancelMap map[string]chan struct{} + wgKey wgtypes.Key } type Token auth.Token -func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager, turnCfg *nbconfig.TURNConfig, relayCfg *nbconfig.Relay, settingsManager settings.Manager, groupsManager groups.Manager) *TimeBasedAuthSecretsManager { +func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager, turnCfg *nbconfig.TURNConfig, relayCfg *nbconfig.Relay, settingsManager settings.Manager, groupsManager groups.Manager) (*TimeBasedAuthSecretsManager, error) { + key, err := wgtypes.GeneratePrivateKey() + if err != nil { + return nil, err + } + mgr := &TimeBasedAuthSecretsManager{ updateManager: updateManager, turnCfg: turnCfg, @@ -56,6 +64,7 @@ func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager relayCancelMap: make(map[string]chan struct{}), settingsManager: settingsManager, groupsManager: groupsManager, + wgKey: key, } if turnCfg != nil { @@ -81,7 +90,12 @@ func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager } } - return mgr + return mgr, nil +} + +// GetWGKey returns WireGuard private key used to generate peer keys +func (m *TimeBasedAuthSecretsManager) GetWGKey() (wgtypes.Key, error) { + return m.wgKey, nil } // GenerateTurnToken generates new time-based secret credentials for TURN diff --git a/management/internals/shared/grpc/token_mgr_test.go b/management/internals/shared/grpc/token_mgr_test.go index 06d28d05b..98eb66fb5 100644 --- a/management/internals/shared/grpc/token_mgr_test.go +++ b/management/internals/shared/grpc/token_mgr_test.go @@ -46,12 +46,13 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) { settingsMockManager := settings.NewMockManager(ctrl) groupsManager := groups.NewManagerMock() - tested := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{ + tested, err := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{ CredentialsTTL: ttl, Secret: secret, Turns: []*config.Host{TurnTestHost}, TimeBasedCredentials: true, }, rc, settingsMockManager, groupsManager) + require.NoError(t, err) turnCredentials, err := tested.GenerateTurnToken() require.NoError(t, err) @@ -98,12 +99,13 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) { settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), "someAccountID").Return(&types.ExtraSettings{}, nil).AnyTimes() groupsManager := groups.NewManagerMock() - tested := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{ + tested, err := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{ CredentialsTTL: ttl, Secret: secret, Turns: []*config.Host{TurnTestHost}, TimeBasedCredentials: true, }, rc, settingsMockManager, groupsManager) + require.NoError(t, err) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -201,12 +203,13 @@ func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) { settingsMockManager := settings.NewMockManager(ctrl) groupsManager := groups.NewManagerMock() - tested := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{ + tested, err := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{ CredentialsTTL: ttl, Secret: secret, Turns: []*config.Host{TurnTestHost}, TimeBasedCredentials: true, }, rc, settingsMockManager, groupsManager) + require.NoError(t, err) tested.SetupRefresh(context.Background(), "someAccountID", peer) if _, ok := tested.turnCancelMap[peer]; !ok { diff --git a/management/server/account.go b/management/server/account.go index 716d5ab5d..dac040db0 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -37,7 +37,6 @@ import ( "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/peers/ephemeral" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" @@ -77,7 +76,6 @@ type DefaultAccountManager struct { ctx context.Context eventStore activity.Store geo geolocation.Geolocation - ephemeralManager ephemeral.Manager requestBuffer *AccountRequestBuffer @@ -238,7 +236,7 @@ func BuildManager( log.WithContext(ctx).Infof("single account mode disabled, accounts number %d", accountsCounter) } - cacheStore, err := nbcache.NewStore(ctx, nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval) + cacheStore, err := nbcache.NewStore(ctx, nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval, nbcache.DefaultIDPCacheOpenConn) if err != nil { return nil, fmt.Errorf("getting cache store: %s", err) } @@ -263,10 +261,6 @@ func BuildManager( return am, nil } -func (am *DefaultAccountManager) SetEphemeralManager(em ephemeral.Manager) { - am.ephemeralManager = em -} - func (am *DefaultAccountManager) GetExternalCacheManager() account.ExternalCacheManager { return am.externalCacheManager } @@ -2076,7 +2070,10 @@ func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, us if err != nil { return err } - am.networkMapController.OnPeerUpdated(peer.AccountID, peer) + err = am.networkMapController.OnPeersUpdated(ctx, peer.AccountID, []string{peerID}) + if err != nil { + return fmt.Errorf("notify network map controller of peer update: %w", err) + } } return nil } diff --git a/management/server/account/manager.go b/management/server/account/manager.go index 9b3902d87..b5921ec7a 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -13,7 +13,6 @@ import ( nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/idp" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/peers/ephemeral" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" @@ -124,5 +123,4 @@ type Manager interface { UpdateToPrimaryAccount(ctx context.Context, accountId string) error GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error) GetCurrentUserInfo(ctx context.Context, userAuth auth.UserAuth) (*users.UserInfoWithPermissions, error) - SetEphemeralManager(em ephemeral.Manager) } diff --git a/management/server/account_test.go b/management/server/account_test.go index 340e8db18..8569f1b2f 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -25,6 +25,8 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" nbAccount "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" @@ -2959,8 +2961,8 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), &config.Config{}) - manager, err := BuildManager(ctx, nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}) + manager, err := BuildManager(ctx, &config.Config{}, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) if err != nil { return nil, nil, err } @@ -3371,7 +3373,7 @@ func TestDefaultAccountManager_IsCacheCold(t *testing.T) { t.Run("memory cache", func(t *testing.T) { t.Run("should always return true", func(t *testing.T) { - cacheStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond) + cacheStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100) require.NoError(t, err) cold, err := manager.isCacheCold(context.Background(), cacheStore) @@ -3386,7 +3388,7 @@ func TestDefaultAccountManager_IsCacheCold(t *testing.T) { t.Cleanup(cleanup) t.Setenv(cache.RedisStoreEnvVar, redisURL) - cacheStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond) + cacheStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100) require.NoError(t, err) t.Run("should return true when no account exists", func(t *testing.T) { diff --git a/management/server/cache/idp.go b/management/server/cache/idp.go index 1b31ff82a..19dfc0f38 100644 --- a/management/server/cache/idp.go +++ b/management/server/cache/idp.go @@ -18,6 +18,7 @@ const ( DefaultIDPCacheExpirationMax = 7 * 24 * time.Hour // 7 days DefaultIDPCacheExpirationMin = 3 * 24 * time.Hour // 3 days DefaultIDPCacheCleanupInterval = 30 * time.Minute + DefaultIDPCacheOpenConn = 100 ) // UserDataCache is an interface that wraps the basic Get, Set and Delete methods for idp.UserData objects. diff --git a/management/server/cache/idp_test.go b/management/server/cache/idp_test.go index 3fcfbb11a..0e8061e94 100644 --- a/management/server/cache/idp_test.go +++ b/management/server/cache/idp_test.go @@ -33,7 +33,7 @@ func TestNewIDPCacheManagers(t *testing.T) { t.Cleanup(cleanup) t.Setenv(cache.RedisStoreEnvVar, redisURL) } - cacheStore, err := cache.NewStore(context.Background(), cache.DefaultIDPCacheExpirationMax, cache.DefaultIDPCacheCleanupInterval) + cacheStore, err := cache.NewStore(context.Background(), cache.DefaultIDPCacheExpirationMax, cache.DefaultIDPCacheCleanupInterval, cache.DefaultIDPCacheOpenConn) if err != nil { t.Fatalf("couldn't create cache store: %s", err) } diff --git a/management/server/cache/store.go b/management/server/cache/store.go index 1c141a180..54b0242de 100644 --- a/management/server/cache/store.go +++ b/management/server/cache/store.go @@ -3,6 +3,7 @@ package cache import ( "context" "fmt" + "math" "os" "time" @@ -20,24 +21,27 @@ const RedisStoreEnvVar = "NB_IDP_CACHE_REDIS_ADDRESS" // NewStore creates a new cache store with the given max timeout and cleanup interval. It checks for the environment Variable RedisStoreEnvVar // to determine if a redis store should be used. If the environment variable is set, it will attempt to connect to the redis store. -func NewStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration) (store.StoreInterface, error) { +func NewStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (store.StoreInterface, error) { redisAddr := os.Getenv(RedisStoreEnvVar) if redisAddr != "" { - return getRedisStore(ctx, redisAddr) + return getRedisStore(ctx, redisAddr, maxConn) } goc := gocache.New(maxTimeout, cleanupInterval) return gocache_store.NewGoCache(goc), nil } -func getRedisStore(ctx context.Context, redisEnvAddr string) (store.StoreInterface, error) { +func getRedisStore(ctx context.Context, redisEnvAddr string, maxConn int) (store.StoreInterface, error) { options, err := redis.ParseURL(redisEnvAddr) if err != nil { return nil, fmt.Errorf("parsing redis cache url: %s", err) } - options.MaxIdleConns = 6 - options.MinIdleConns = 3 - options.MaxActiveConns = 100 + options.MaxIdleConns = int(math.Ceil(float64(maxConn) * 0.5)) // 50% of max conns + options.MinIdleConns = int(math.Ceil(float64(maxConn) * 0.1)) // 10% of max conns + options.MaxActiveConns = maxConn + options.ConnMaxIdleTime = 30 * time.Minute + options.ConnMaxLifetime = 0 + options.PoolTimeout = 10 * time.Second redisClient := redis.NewClient(options) subCtx, cancel := context.WithTimeout(ctx, 2*time.Second) defer cancel() diff --git a/management/server/cache/store_test.go b/management/server/cache/store_test.go index f49dd6bbd..1b64fd70d 100644 --- a/management/server/cache/store_test.go +++ b/management/server/cache/store_test.go @@ -15,7 +15,7 @@ import ( ) func TestMemoryStore(t *testing.T) { - memStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond) + memStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100) if err != nil { t.Fatalf("couldn't create memory store: %s", err) } @@ -42,7 +42,7 @@ func TestMemoryStore(t *testing.T) { func TestRedisStoreConnectionFailure(t *testing.T) { t.Setenv(cache.RedisStoreEnvVar, "redis://127.0.0.1:6379") - _, err := cache.NewStore(context.Background(), 10*time.Millisecond, 30*time.Millisecond) + _, err := cache.NewStore(context.Background(), 10*time.Millisecond, 30*time.Millisecond, 100) if err == nil { t.Fatal("getting redis cache store should return error") } @@ -65,7 +65,7 @@ func TestRedisStoreConnectionSuccess(t *testing.T) { } t.Setenv(cache.RedisStoreEnvVar, redisURL) - redisStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond) + redisStore, err := cache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100) if err != nil { t.Fatalf("couldn't create redis store: %s", err) } diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 99b09566a..b5e3f2b99 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -12,6 +12,8 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/permissions" @@ -223,7 +225,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock(), &config.Config{}) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}) return BuildManager(context.Background(), nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) } diff --git a/management/server/http/handler.go b/management/server/http/handler.go index c1a8c5885..7cf0b5765 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -21,6 +21,7 @@ import ( "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/permissions" + nbpeers "github.com/netbirdio/netbird/management/internals/modules/peers" "github.com/netbirdio/netbird/management/server/auth" "github.com/netbirdio/netbird/management/server/geolocation" nbgroups "github.com/netbirdio/netbird/management/server/groups" @@ -39,7 +40,6 @@ import ( nbnetworks "github.com/netbirdio/netbird/management/server/networks" "github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/routers" - nbpeers "github.com/netbirdio/netbird/management/server/peers" "github.com/netbirdio/netbird/management/server/telemetry" ) diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index c4c5ae165..f531f0cdb 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -45,19 +45,6 @@ func NewHandler(accountManager account.Manager, networkMapController network_map } } -func (h *Handler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) { - peerToReturn := peer.Copy() - if peer.Status.Connected { - // Although we have online status in store we do not yet have an updated channel so have to show it as disconnected - // This may happen after server restart when not all peers are yet connected - if !h.networkMapController.IsConnected(peer.ID) { - peerToReturn.Status.Connected = false - } - } - - return peerToReturn, nil -} - func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, w http.ResponseWriter) { peer, err := h.accountManager.GetPeer(ctx, accountID, peerID, userID) if err != nil { @@ -65,11 +52,6 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, return } - peerToReturn, err := h.checkPeerStatus(peer) - if err != nil { - util.WriteError(ctx, err, w) - return - } settings, err := h.accountManager.GetAccountSettings(ctx, accountID, activity.SystemInitiator) if err != nil { util.WriteError(ctx, err, w) @@ -91,7 +73,7 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, _, valid := validPeers[peer.ID] reason := invalidPeers[peer.ID] - util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, grpsInfoMap[peerID], dnsDomain, valid, reason)) + util.WriteJSONObject(ctx, w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason)) } func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) { @@ -237,13 +219,7 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { grpsInfoMap := groups.ToGroupsInfoMap(grps, len(peers)) respBody := make([]*api.PeerBatch, 0, len(peers)) for _, peer := range peers { - peerToReturn, err := h.checkPeerStatus(peer) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - respBody = append(respBody, toPeerListItemResponse(peerToReturn, grpsInfoMap[peer.ID], dnsDomain, 0)) + respBody = append(respBody, toPeerListItemResponse(peer, grpsInfoMap[peer.ID], dnsDomain, 0)) } validPeersMap, invalidPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) diff --git a/management/server/http/handlers/peers/peers_handler_test.go b/management/server/http/handlers/peers/peers_handler_test.go index ddf2e2a70..55e779ff0 100644 --- a/management/server/http/handlers/peers/peers_handler_test.go +++ b/management/server/http/handlers/peers/peers_handler_test.go @@ -109,14 +109,6 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler { GetDNSDomain(gomock.Any()). Return("domain"). AnyTimes() - networkMapController.EXPECT(). - IsConnected(noUpdateChannelTestPeerID). - Return(false). - AnyTimes() - networkMapController.EXPECT(). - IsConnected(gomock.Any()). - Return(true). - AnyTimes() return &Handler{ accountManager: &mock_server.MockAccountManager{ @@ -269,14 +261,6 @@ func TestGetPeers(t *testing.T) { expectedArray: false, expectedPeer: peer, }, - { - name: "GetPeer with no update channel", - requestType: http.MethodGet, - requestPath: "/api/peers/" + peer1.ID, - expectedStatus: http.StatusOK, - expectedArray: false, - expectedPeer: expectedPeer1, - }, { name: "PutPeer", requestType: http.MethodPut, @@ -336,8 +320,6 @@ func TestGetPeers(t *testing.T) { for _, peer := range respBody { if peer.Id == testPeerID { got = peer - } else { - assert.Equal(t, peer.Connected, false) } } @@ -351,14 +333,14 @@ func TestGetPeers(t *testing.T) { t.Log(got) - assert.Equal(t, got.Name, tc.expectedPeer.Name) - assert.Equal(t, got.Version, tc.expectedPeer.Meta.WtVersion) - assert.Equal(t, got.Ip, tc.expectedPeer.IP.String()) - assert.Equal(t, got.Os, "OS core") - assert.Equal(t, got.LoginExpirationEnabled, tc.expectedPeer.LoginExpirationEnabled) - assert.Equal(t, got.SshEnabled, tc.expectedPeer.SSHEnabled) - assert.Equal(t, got.Connected, tc.expectedPeer.Status.Connected) - assert.Equal(t, got.SerialNumber, tc.expectedPeer.Meta.SystemSerialNumber) + assert.Equal(t, tc.expectedPeer.Name, got.Name) + assert.Equal(t, tc.expectedPeer.Meta.WtVersion, got.Version) + assert.Equal(t, tc.expectedPeer.IP.String(), got.Ip) + assert.Equal(t, "OS core", got.Os) + assert.Equal(t, tc.expectedPeer.LoginExpirationEnabled, got.LoginExpirationEnabled) + assert.Equal(t, tc.expectedPeer.SSHEnabled, got.SshEnabled) + assert.Equal(t, tc.expectedPeer.Status.Connected, got.Connected) + assert.Equal(t, tc.expectedPeer.Meta.SystemSerialNumber, got.SerialNumber) }) } } diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index e292a7d6c..e8513feb5 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -15,6 +15,8 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server" @@ -28,7 +30,6 @@ import ( "github.com/netbirdio/netbird/management/server/networks" "github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/routers" - "github.com/netbirdio/netbird/management/server/peers" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -72,7 +73,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee ctx := context.Background() requestBuffer := server.NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock(), &config.Config{}) + networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}) am, err := server.BuildManager(ctx, nil, store, networkMapController, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false) if err != nil { t.Fatalf("Failed to create manager: %v", err) diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index 42311d944..42f192c0a 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -24,13 +24,14 @@ import ( "github.com/netbirdio/netbird/formatter/hook" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -363,7 +364,9 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), config) + ephemeralMgr := manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)) + + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeralMgr, config) accountManager, err := BuildManager(ctx, nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) @@ -372,10 +375,13 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config return nil, nil, "", cleanup, err } - secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + if err != nil { + cleanup() + return nil, nil, "", cleanup, err + } - ephemeralMgr := manager.NewEphemeralManager(store, accountManager) - mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{}, networkMapController) + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, MockIntegratedValidator{}, networkMapController) if err != nil { return nil, nil, "", cleanup, err } diff --git a/management/server/management_test.go b/management/server/management_test.go index 2350b225b..648201d4e 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -22,13 +22,14 @@ import ( "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" - "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -205,7 +206,7 @@ func startServer( ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := server.NewAccountRequestBuffer(ctx, str) - networkMapController := controller.NewController(ctx, str, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), config) + networkMapController := controller.NewController(ctx, str, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(str, peers.NewManager(str, permissionsManager)), config) accountManager, err := server.BuildManager( context.Background(), @@ -228,15 +229,16 @@ func startServer( } groupsManager := groups.NewManager(str, permissionsManager, accountManager) - secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + if err != nil { + t.Fatalf("failed creating secrets manager: %v", err) + } mgmtServer, err := nbgrpc.NewServer( config, accountManager, settingsMockManager, - updateManager, secretsManager, nil, - &manager.EphemeralManager{}, nil, server.MockIntegratedValidator{}, networkMapController, diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 0178e51f5..928098dbe 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -15,7 +15,6 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/idp" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/peers/ephemeral" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" @@ -976,11 +975,6 @@ func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth a return nil, status.Errorf(codes.Unimplemented, "method GetCurrentUserInfo is not implemented") } -// SetEphemeralManager mocks SetEphemeralManager of the AccountManager interface -func (am *MockAccountManager) SetEphemeralManager(em ephemeral.Manager) { - // Mock implementation - does nothing -} - func (am *MockAccountManager) AllowSync(key string, hash uint64) bool { if am.AllowSyncFunc != nil { return am.AllowSyncFunc(key, hash) diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index a4574c978..e3dd8b0b8 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -13,6 +13,8 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" @@ -792,7 +794,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), &config.Config{}) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}) return BuildManager(context.Background(), nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) } diff --git a/management/server/peer.go b/management/server/peer.go index cd9fbe4c8..f2de05f15 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -136,7 +136,10 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK } if expired { - am.networkMapController.OnPeerUpdated(accountID, peer) + err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}) + if err != nil { + return fmt.Errorf("notify network map controller of peer update: %w", err) + } } return nil @@ -309,7 +312,10 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user } } - am.networkMapController.OnPeerUpdated(accountID, peer) + err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}) + if err != nil { + return nil, fmt.Errorf("notify network map controller of peer update: %w", err) + } return peer, nil } @@ -365,13 +371,8 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer storeEvent() } - err = am.networkMapController.DeletePeer(ctx, accountID, peer.ID) - if err != nil { - log.WithContext(ctx).Errorf("failed to delete peer %s from network map: %v", peer.ID, err) - } - - if err := am.networkMapController.OnPeerDeleted(ctx, accountID, peerID); err != nil { - log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", peerID, err) + if err := am.networkMapController.OnPeersDeleted(ctx, accountID, []string{peerID}); err != nil { + log.WithContext(ctx).Errorf("failed to delete peer %s from network map: %v", peerID, err) } return nil @@ -583,11 +584,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe return fmt.Errorf("failed adding peer to All group: %w", err) } - if temporary { - // we are running the on disconnect handler so that it is considered not connected as we are adding the peer manually - am.ephemeralManager.OnPeerDisconnected(ctx, newPeer) - } - if addedByUser { err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin()) if err != nil { @@ -645,7 +641,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) - if err := am.networkMapController.OnPeerAdded(ctx, accountID, newPeer.ID); err != nil { + if err := am.networkMapController.OnPeersAdded(ctx, accountID, []string{newPeer.ID}); err != nil { log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err) } @@ -729,7 +725,10 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy } if isStatusChanged || sync.UpdateAccountPeers || (updated && (len(postureChecks) > 0 || versionChanged)) { - am.networkMapController.OnPeerUpdated(accountID, peer) + err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}) + if err != nil { + return nil, nil, nil, 0, fmt.Errorf("notify network map controller of peer update: %w", err) + } } return am.networkMapController.GetValidatedPeerWithMap(ctx, peerNotValid, accountID, peer) @@ -857,7 +856,10 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer log.WithContext(ctx).Debugf("LoginPeer: transaction took %v", time.Since(startTransaction)) if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) { - am.networkMapController.OnPeerUpdated(accountID, peer) + err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}) + if err != nil { + return nil, nil, nil, fmt.Errorf("notify network map controller of peer update: %w", err) + } } p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer) diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 2d09f5200..752563299 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -28,6 +28,8 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" @@ -1058,6 +1060,7 @@ func testUpdateAccountPeers(t *testing.T) { for _, channel := range peerChannels { update := <-channel + assert.Nil(t, update.Update.NetbirdConfig) assert.Equal(t, tc.peers, len(update.Update.NetworkMap.RemotePeers)) assert.Equal(t, tc.peers*2, len(update.Update.NetworkMap.FirewallRules)) } @@ -1290,7 +1293,7 @@ func Test_RegisterPeerByUser(t *testing.T) { ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, s) - networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), &config.Config{}) + networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{}) am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) @@ -1375,7 +1378,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, s) - networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), &config.Config{}) + networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{}) am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) @@ -1528,7 +1531,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, s) - networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), &config.Config{}) + networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{}) am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) @@ -1608,7 +1611,7 @@ func Test_LoginPeer(t *testing.T) { ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, s) - networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), &config.Config{}) + networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{}) am, err := BuildManager(context.Background(), nil, s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) diff --git a/management/server/peers/manager.go b/management/server/peers/manager.go deleted file mode 100644 index cb135f4ac..000000000 --- a/management/server/peers/manager.go +++ /dev/null @@ -1,68 +0,0 @@ -package peers - -//go:generate go run github.com/golang/mock/mockgen -package peers -destination=manager_mock.go -source=./manager.go -build_flags=-mod=mod - -import ( - "context" - "fmt" - - "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/management/server/permissions/modules" - "github.com/netbirdio/netbird/management/server/permissions/operations" - "github.com/netbirdio/netbird/management/server/store" - "github.com/netbirdio/netbird/shared/management/status" -) - -type Manager interface { - GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error) - GetPeerAccountID(ctx context.Context, peerID string) (string, error) - GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) - GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) -} - -type managerImpl struct { - store store.Store - permissionsManager permissions.Manager -} - -func NewManager(store store.Store, permissionsManager permissions.Manager) Manager { - return &managerImpl{ - store: store, - permissionsManager: permissionsManager, - } -} - -func (m *managerImpl) GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error) { - allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read) - if err != nil { - return nil, fmt.Errorf("failed to validate user permissions: %w", err) - } - - if !allowed { - return nil, status.NewPermissionDeniedError() - } - - return m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) -} - -func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) { - allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read) - if err != nil { - return nil, fmt.Errorf("failed to validate user permissions: %w", err) - } - - if !allowed { - return m.store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID) - } - - return m.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") -} - -func (m *managerImpl) GetPeerAccountID(ctx context.Context, peerID string) (string, error) { - return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID) -} - -func (m *managerImpl) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) { - return m.store.GetPeersByGroupIDs(ctx, accountID, groupsIDs) -} diff --git a/management/server/route_test.go b/management/server/route_test.go index 5c8b636bc..a413d545b 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -16,6 +16,8 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" @@ -1291,7 +1293,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, *update_channel. ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), &config.Config{}) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}) am, err := BuildManager(context.Background(), nil, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) if err != nil { diff --git a/management/server/user.go b/management/server/user.go index cefc4d1a5..ca02f91e6 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -263,15 +263,11 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init return err } - updateAccountPeers, err := am.deleteRegularUser(ctx, accountID, initiatorUserID, userInfo) + _, err = am.deleteRegularUser(ctx, accountID, initiatorUserID, userInfo) if err != nil { return err } - if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) - } - return nil } @@ -998,14 +994,17 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou peer.UserID, peer.ID, accountID, activity.PeerLoginExpired, peer.EventMeta(dnsDomain), ) + } - am.networkMapController.OnPeerUpdated(accountID, peer) + err = am.networkMapController.OnPeersUpdated(ctx, accountID, peerIDs) + if err != nil { + return fmt.Errorf("notify network map controller of peer update: %w", err) } if len(peerIDs) != 0 { // this will trigger peer disconnect from the management service log.Debugf("Expiring %d peers for account %s", len(peerIDs), accountID) - am.networkMapController.DisconnectPeers(ctx, peerIDs) + am.networkMapController.DisconnectPeers(ctx, accountID, peerIDs) } return nil } @@ -1051,7 +1050,6 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account } var allErrors error - var updateAccountPeers bool for _, targetUserID := range targetUserIDs { if initiatorUserID == targetUserID { @@ -1082,19 +1080,11 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account continue } - userHadPeers, err := am.deleteRegularUser(ctx, accountID, initiatorUserID, userInfo) + _, err = am.deleteRegularUser(ctx, accountID, initiatorUserID, userInfo) if err != nil { allErrors = errors.Join(allErrors, err) continue } - - if userHadPeers { - updateAccountPeers = true - } - } - - if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) } return allErrors @@ -1152,15 +1142,12 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI return false, err } + var peerIDs []string for _, peer := range userPeers { - err = am.networkMapController.DeletePeer(ctx, accountID, peer.ID) - if err != nil { - log.WithContext(ctx).Errorf("failed to delete peer %s from network map: %v", peer.ID, err) - } - - if err := am.networkMapController.OnPeerDeleted(ctx, accountID, peer.ID); err != nil { - log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", peer.ID, err) - } + peerIDs = append(peerIDs, peer.ID) + } + if err := am.networkMapController.OnPeersDeleted(ctx, accountID, peerIDs); err != nil { + log.WithContext(ctx).Errorf("failed to delete peers %s from network map: %v", peerIDs, err) } for _, addPeerRemovedEvent := range addPeerRemovedEvents { diff --git a/management/server/user_test.go b/management/server/user_test.go index 5ce15621e..0d778cfa2 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -8,8 +8,10 @@ import ( "time" "github.com/google/go-cmp/cmp" + "go.uber.org/mock/gomock" "golang.org/x/exp/maps" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions/modules" @@ -547,7 +549,7 @@ func TestUser_InviteNewUser(t *testing.T) { permissionsManager: permissionsManager, } - cs, err := nbcache.NewStore(context.Background(), nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval) + cs, err := nbcache.NewStore(context.Background(), nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval, nbcache.DefaultIDPCacheOpenConn) require.NoError(t, err) am.cacheManager = nbcache.NewAccountUserDataCache(am.loadAccount, cs) @@ -739,11 +741,18 @@ func TestUser_DeleteUser_regularUser(t *testing.T) { t.Fatalf("Error when saving account: %s", err) } + ctrl := gomock.NewController(t) + networkMapControllerMock := network_map.NewMockController(ctrl) + networkMapControllerMock.EXPECT(). + OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil) + permissionsManager := permissions.NewManager(store) am := DefaultAccountManager{ - Store: store, - eventStore: &activity.InMemoryEventStore{}, - permissionsManager: permissionsManager, + Store: store, + eventStore: &activity.InMemoryEventStore{}, + permissionsManager: permissionsManager, + networkMapController: networkMapControllerMock, } testCases := []struct { @@ -848,12 +857,20 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) { t.Fatalf("Error when saving account: %s", err) } + ctrl := gomock.NewController(t) + networkMapControllerMock := network_map.NewMockController(ctrl) + networkMapControllerMock.EXPECT(). + OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil). + AnyTimes() + permissionsManager := permissions.NewManager(store) am := DefaultAccountManager{ Store: store, eventStore: &activity.InMemoryEventStore{}, integratedPeerValidator: MockIntegratedValidator{}, permissionsManager: permissionsManager, + networkMapController: networkMapControllerMock, } testCases := []struct { @@ -1056,7 +1073,7 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) { permissionsManager: permissionsManager, } - cacheStore, err := nbcache.NewStore(context.Background(), nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval) + cacheStore, err := nbcache.NewStore(context.Background(), nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval, nbcache.DefaultIDPCacheOpenConn) assert.NoError(t, err) am.externalCacheManager = nbcache.NewUserDataCache(cacheStore) am.cacheManager = nbcache.NewAccountUserDataCache(am.loadAccount, cacheStore) @@ -1412,7 +1429,7 @@ func TestUserAccountPeersUpdate(t *testing.T) { t.Run("deleting user with no linked peers", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg) close(done) }() diff --git a/shared/management/client/client_test.go b/shared/management/client/client_test.go index 9e08317f6..9fbe70948 100644 --- a/shared/management/client/client_test.go +++ b/shared/management/client/client_test.go @@ -21,6 +21,8 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/modules/peers" + "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/client/system" @@ -31,8 +33,6 @@ import ( "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/mock_server" - "github.com/netbirdio/netbird/management/server/peers" - "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" @@ -117,7 +117,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store) - networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), config) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManger), config) accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) if err != nil { t.Fatal(err) @@ -125,8 +125,11 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { groupsManager := groups.NewManagerMock() - secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, mgmt.MockIntegratedValidator{}, networkMapController) + secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + if err != nil { + t.Fatal(err) + } + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, mgmt.MockIntegratedValidator{}, networkMapController) if err != nil { t.Fatal(err) }