diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index bd3209605..78bb0476b 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -12,6 +12,9 @@ import ( "google.golang.org/grpc" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" clientProto "github.com/netbirdio/netbird/client/proto" client "github.com/netbirdio/netbird/client/server" @@ -84,7 +87,6 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp } t.Cleanup(cleanUp) - peersUpdateManager := mgmt.NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} if err != nil { return nil, nil @@ -110,13 +112,18 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp Return(&types.Settings{}, nil). AnyTimes() - accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) + ctx := context.Background() + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock()) + + accountManager, err := mgmt.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) if err != nil { t.Fatal(err) } - secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{}) + secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &mgmt.MockIntegratedValidator{}, networkMapController) if err != nil { t.Fatal(err) } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 2f1098100..15ac0a947 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -26,6 +26,9 @@ import ( "google.golang.org/grpc/keepalive" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/groups" @@ -1556,7 +1559,6 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri } t.Cleanup(cleanUp) - peersUpdateManager := server.NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} if err != nil { return nil, "", err @@ -1584,13 +1586,16 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri groupsManager := groups.NewManagerMock() - accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := server.NewAccountRequestBuffer(context.Background(), store) + networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) + accountManager, err := server.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) if err != nil { return nil, "", err } - secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}) + secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}, networkMapController) if err != nil { return nil, "", err } diff --git a/client/server/server_test.go b/client/server/server_test.go index e0a4805f6..ae5f759ee 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -14,6 +14,9 @@ import ( "go.opentelemetry.io/otel" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/groups" @@ -290,7 +293,6 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve } t.Cleanup(cleanUp) - peersUpdateManager := server.NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} if err != nil { return nil, "", err @@ -311,13 +313,16 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve settingsMockManager := settings.NewMockManager(ctrl) groupsManager := groups.NewManagerMock() - accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) + requestBuffer := server.NewAccountRequestBuffer(context.Background(), store) + peersUpdateManager := update_channel.NewPeersUpdateManager(metrics) + networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) + accountManager, err := server.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) if err != nil { return nil, "", err } - secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}) + secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, &server.MockIntegratedValidator{}, networkMapController) if err != nil { return nil, "", err } diff --git a/go.mod b/go.mod index 7b9bae321..1e3177d7e 100644 --- a/go.mod +++ b/go.mod @@ -99,6 +99,7 @@ require ( go.opentelemetry.io/otel/exporters/prometheus v0.48.0 go.opentelemetry.io/otel/metric v1.35.0 go.opentelemetry.io/otel/sdk/metric v1.35.0 + go.uber.org/mock v0.5.0 go.uber.org/zap v1.27.0 goauthentik.io/api/v3 v3.2023051.3 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 @@ -242,7 +243,6 @@ require ( go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect go.opentelemetry.io/otel/sdk v1.35.0 // indirect go.opentelemetry.io/otel/trace v1.35.0 // indirect - go.uber.org/mock v0.5.0 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/image v0.18.0 // indirect golang.org/x/text v0.27.0 // indirect diff --git a/management/internals/controllers/network_map/controller/cache/dns_config_cache.go b/management/internals/controllers/network_map/controller/cache/dns_config_cache.go new file mode 100644 index 000000000..8cc634ef4 --- /dev/null +++ b/management/internals/controllers/network_map/controller/cache/dns_config_cache.go @@ -0,0 +1,31 @@ +package cache + +import ( + "sync" + + "github.com/netbirdio/netbird/shared/management/proto" +) + +// DNSConfigCache is a thread-safe cache for DNS configuration components +type DNSConfigCache struct { + NameServerGroups sync.Map +} + +// GetNameServerGroup retrieves a cached name server group +func (c *DNSConfigCache) GetNameServerGroup(key string) (*proto.NameServerGroup, bool) { + if c == nil { + return nil, false + } + if value, ok := c.NameServerGroups.Load(key); ok { + return value.(*proto.NameServerGroup), true + } + return nil, false +} + +// SetNameServerGroup stores a name server group in the cache +func (c *DNSConfigCache) SetNameServerGroup(key string, value *proto.NameServerGroup) { + if c == nil { + return + } + c.NameServerGroups.Store(key, value) +} diff --git a/management/internals/controllers/network_map/controller/controller.go b/management/internals/controllers/network_map/controller/controller.go new file mode 100644 index 000000000..ad25494c7 --- /dev/null +++ b/management/internals/controllers/network_map/controller/controller.go @@ -0,0 +1,784 @@ +package controller + +import ( + "context" + "errors" + "fmt" + "os" + "slices" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" + "golang.org/x/mod/semver" + + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache" + "github.com/netbirdio/netbird/management/internals/shared/grpc" + "github.com/netbirdio/netbird/management/server/account" + "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" + "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/proto" + "github.com/netbirdio/netbird/shared/management/status" + "github.com/netbirdio/netbird/util" +) + +type Controller struct { + repo Repository + metrics *metrics + // This should not be here, but we need to maintain it for the time being + accountManagerMetrics *telemetry.AccountManagerMetrics + peersUpdateManager network_map.PeersUpdateManager + settingsManager settings.Manager + + accountUpdateLocks sync.Map + sendAccountUpdateLocks sync.Map + updateAccountPeersBufferInterval atomic.Int64 + // dnsDomain is used for peer resolution. This is appended to the peer's name + dnsDomain string + + requestBuffer account.RequestBuffer + + proxyController port_forwarding.Controller + + integratedPeerValidator integrated_validator.IntegratedValidator + + holder *types.Holder + + expNewNetworkMap bool + expNewNetworkMapAIDs map[string]struct{} +} + +type bufferUpdate struct { + mu sync.Mutex + next *time.Timer + update atomic.Bool +} + +var _ network_map.Controller = (*Controller)(nil) + +func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller) *Controller { + nMetrics, err := newMetrics(metrics.UpdateChannelMetrics()) + if err != nil { + log.Fatal(fmt.Errorf("error creating metrics: %w", err)) + } + + newNetworkMapBuilder, err := strconv.ParseBool(os.Getenv(network_map.EnvNewNetworkMapBuilder)) + if err != nil { + log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", network_map.EnvNewNetworkMapBuilder, err) + newNetworkMapBuilder = false + } + + ids := strings.Split(os.Getenv(network_map.EnvNewNetworkMapAccounts), ",") + expIDs := make(map[string]struct{}, len(ids)) + for _, id := range ids { + expIDs[id] = struct{}{} + } + + return &Controller{ + repo: newRepository(store), + metrics: nMetrics, + accountManagerMetrics: metrics.AccountManagerMetrics(), + peersUpdateManager: peersUpdateManager, + requestBuffer: requestBuffer, + integratedPeerValidator: integratedPeerValidator, + settingsManager: settingsManager, + dnsDomain: dnsDomain, + + proxyController: proxyController, + + holder: types.NewHolder(), + expNewNetworkMap: newNetworkMapBuilder, + expNewNetworkMapAIDs: expIDs, + } +} + +func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error { + log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName()) + var ( + account *types.Account + err error + ) + if c.experimentalNetworkMap(accountID) { + account = c.getAccountFromHolderOrInit(accountID) + } else { + account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return fmt.Errorf("failed to get account: %v", err) + } + } + + globalStart := time.Now() + + hasPeersConnected := false + for _, peer := range account.Peers { + if c.peersUpdateManager.HasChannel(peer.ID) { + hasPeersConnected = true + break + } + + } + + if !hasPeersConnected { + return nil + } + + approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) + if err != nil { + return fmt.Errorf("failed to get validate peers: %v", err) + } + + var wg sync.WaitGroup + semaphore := make(chan struct{}, 10) + + dnsCache := &cache.DNSConfigCache{} + dnsDomain := c.GetDNSDomain(account.Settings) + customZone := account.GetPeersCustomZone(ctx, dnsDomain) + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + + if c.experimentalNetworkMap(accountID) { + c.initNetworkMapBuilderIfNeeded(account, approvedPeersMap) + } + + proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers) + if err != nil { + log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) + return fmt.Errorf("failed to get proxy network maps: %v", err) + } + + extraSetting, err := c.settingsManager.GetExtraSettings(ctx, accountID) + if err != nil { + return fmt.Errorf("failed to get flow enabled status: %v", err) + } + + dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion) + + for _, peer := range account.Peers { + if !c.peersUpdateManager.HasChannel(peer.ID) { + log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID) + continue + } + + wg.Add(1) + semaphore <- struct{}{} + go func(p *nbpeer.Peer) { + defer wg.Done() + defer func() { <-semaphore }() + + start := time.Now() + + postureChecks, err := c.getPeerPostureChecks(account, p.ID) + if err != nil { + log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", p.ID, err) + return + } + + c.metrics.CountCalcPostureChecksDuration(time.Since(start)) + start = time.Now() + + var remotePeerNetworkMap *types.NetworkMap + + if c.experimentalNetworkMap(accountID) { + remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, c.accountManagerMetrics) + } else { + remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics) + } + + c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start)) + + proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] + if ok { + remotePeerNetworkMap.Merge(proxyNetworkMap) + } + + peerGroups := account.GetPeerGroups(p.ID) + start = time.Now() + update := grpc.ToSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort) + c.metrics.CountToSyncResponseDuration(time.Since(start)) + + c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{Update: update}) + }(peer) + } + + wg.Wait() + if c.accountManagerMetrics != nil { + c.accountManagerMetrics.CountUpdateAccountPeersDuration(time.Since(globalStart)) + } + + return nil +} + +func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID string) error { + log.WithContext(ctx).Tracef("buffer sending update peers for account %s from %s", accountID, util.GetCallerName()) + + bufUpd, _ := c.sendAccountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{}) + b := bufUpd.(*bufferUpdate) + + if !b.mu.TryLock() { + b.update.Store(true) + return nil + } + + if b.next != nil { + b.next.Stop() + } + + go func() { + defer b.mu.Unlock() + _ = c.sendUpdateAccountPeers(ctx, accountID) + if !b.update.Load() { + return + } + b.update.Store(false) + if b.next == nil { + b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() { + _ = c.sendUpdateAccountPeers(ctx, accountID) + }) + return + } + b.next.Reset(time.Duration(c.updateAccountPeersBufferInterval.Load())) + }() + + return nil +} + +// UpdatePeers updates all peers that belong to an account. +// Should be called when changes have to be synced to peers. +func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string) error { + if err := c.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return fmt.Errorf("recalculate network map cache: %v", err) + } + + return c.sendUpdateAccountPeers(ctx, accountID) +} + +func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error { + if !c.peersUpdateManager.HasChannel(peerId) { + return fmt.Errorf("peer %s doesn't have a channel, skipping network map update", peerId) + } + + account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId) + if err != nil { + return fmt.Errorf("failed to send out updates to peer %s: %v", peerId, err) + } + + peer := account.GetPeer(peerId) + if peer == nil { + return fmt.Errorf("peer %s doesn't exists in account %s", peerId, accountId) + } + + approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) + if err != nil { + return fmt.Errorf("failed to get validated peers: %v", err) + } + + dnsCache := &cache.DNSConfigCache{} + dnsDomain := c.GetDNSDomain(account.Settings) + customZone := account.GetPeersCustomZone(ctx, dnsDomain) + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + + postureChecks, err := c.getPeerPostureChecks(account, peerId) + if err != nil { + log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to get posture checks: %v", peerId, err) + return fmt.Errorf("failed to get posture checks for peer %s: %v", peerId, err) + } + + proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers) + if err != nil { + log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) + return err + } + + var remotePeerNetworkMap *types.NetworkMap + + if c.experimentalNetworkMap(accountId) { + remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics) + } else { + remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics) + } + + proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] + if ok { + remotePeerNetworkMap.Merge(proxyNetworkMap) + } + + extraSettings, err := c.settingsManager.GetExtraSettings(ctx, peer.AccountID) + if err != nil { + return fmt.Errorf("failed to get extra settings: %v", err) + } + + peerGroups := account.GetPeerGroups(peerId) + dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion) + + update := grpc.ToSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort) + c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{Update: update}) + + return nil +} + +func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID string) error { + log.WithContext(ctx).Tracef("buffer updating peers for account %s from %s", accountID, util.GetCallerName()) + + bufUpd, _ := c.accountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{}) + b := bufUpd.(*bufferUpdate) + + if !b.mu.TryLock() { + b.update.Store(true) + return nil + } + + if b.next != nil { + b.next.Stop() + } + + go func() { + defer b.mu.Unlock() + _ = c.UpdateAccountPeers(ctx, accountID) + if !b.update.Load() { + return + } + b.update.Store(false) + if b.next == nil { + b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() { + _ = c.UpdateAccountPeers(ctx, accountID) + }) + return + } + b.next.Reset(time.Duration(c.updateAccountPeersBufferInterval.Load())) + }() + + return nil +} + +func (c *Controller) DeletePeer(ctx context.Context, accountId string, peerId string) error { + network, err := c.repo.GetAccountNetwork(ctx, accountId) + if err != nil { + return err + } + + peers, err := c.repo.GetAccountPeers(ctx, accountId) + if err != nil { + return err + } + + dnsFwdPort := computeForwarderPort(peers, network_map.DnsForwarderPortMinVersion) + c.peersUpdateManager.SendUpdate(ctx, peerId, &network_map.UpdateMessage{ + Update: &proto.SyncResponse{ + RemotePeers: []*proto.RemotePeerConfig{}, + RemotePeersIsEmpty: true, + NetworkMap: &proto.NetworkMap{ + Serial: network.CurrentSerial(), + RemotePeers: []*proto.RemotePeerConfig{}, + RemotePeersIsEmpty: true, + FirewallRules: []*proto.FirewallRule{}, + FirewallRulesIsEmpty: true, + DNSConfig: &proto.DNSConfig{ + ForwarderPort: dnsFwdPort, + }, + }, + }, + }) + c.peersUpdateManager.CloseChannel(ctx, peerId) + return nil +} + +func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { + if isRequiresApproval { + network, err := c.repo.GetAccountNetwork(ctx, accountID) + if err != nil { + return nil, nil, nil, 0, err + } + + emptyMap := &types.NetworkMap{ + Network: network.Copy(), + } + return peer, emptyMap, nil, 0, nil + } + + var ( + account *types.Account + err error + ) + if c.experimentalNetworkMap(accountID) { + account = c.getAccountFromHolderOrInit(accountID) + } else { + account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return nil, nil, nil, 0, err + } + } + + approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) + if err != nil { + return nil, nil, nil, 0, err + } + + startPosture := time.Now() + postureChecks, err := c.getPeerPostureChecks(account, peer.ID) + if err != nil { + return nil, nil, nil, 0, err + } + log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture)) + + customZone := account.GetPeersCustomZone(ctx, c.GetDNSDomain(account.Settings)) + + proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers) + if err != nil { + log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) + return nil, nil, nil, 0, err + } + + var networkMap *types.NetworkMap + + if c.experimentalNetworkMap(accountID) { + networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics) + } else { + networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), c.accountManagerMetrics) + } + + proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] + if ok { + networkMap.Merge(proxyNetworkMap) + } + + dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion) + + return peer, networkMap, postureChecks, dnsFwdPort, nil +} + +func (c *Controller) initNetworkMapBuilderIfNeeded(account *types.Account, validatedPeers map[string]struct{}) { + c.enrichAccountFromHolder(account) + account.InitNetworkMapBuilderIfNeeded(validatedPeers) +} + +func (c *Controller) getPeerNetworkMapExp( + ctx context.Context, + accountId string, + peerId string, + validatedPeers map[string]struct{}, + customZone nbdns.CustomZone, + metrics *telemetry.AccountManagerMetrics, +) *types.NetworkMap { + account := c.getAccountFromHolderOrInit(accountId) + if account == nil { + log.WithContext(ctx).Warnf("account %s not found in holder when getting peer network map", accountId) + return &types.NetworkMap{ + Network: &types.Network{}, + } + } + return account.GetPeerNetworkMapExp(ctx, peerId, customZone, validatedPeers, metrics) +} + +func (c *Controller) onPeerAddedUpdNetworkMapCache(account *types.Account, peerId string) error { + c.enrichAccountFromHolder(account) + return account.OnPeerAddedUpdNetworkMapCache(peerId) +} + +func (c *Controller) onPeerDeletedUpdNetworkMapCache(account *types.Account, peerId string) error { + c.enrichAccountFromHolder(account) + return account.OnPeerDeletedUpdNetworkMapCache(peerId) +} + +func (c *Controller) UpdatePeerInNetworkMapCache(accountId string, peer *nbpeer.Peer) { + account := c.getAccountFromHolder(accountId) + if account == nil { + return + } + account.UpdatePeerInNetworkMapCache(peer) +} + +func (c *Controller) recalculateNetworkMapCache(account *types.Account, validatedPeers map[string]struct{}) { + account.RecalculateNetworkMapCache(validatedPeers) + c.updateAccountInHolder(account) +} + +func (c *Controller) RecalculateNetworkMapCache(ctx context.Context, accountId string) error { + if c.experimentalNetworkMap(accountId) { + account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId) + if err != nil { + return err + } + validatedPeers, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) + if err != nil { + log.WithContext(ctx).Errorf("failed to get validate peers: %v", err) + return err + } + c.recalculateNetworkMapCache(account, validatedPeers) + } + return nil +} + +func (c *Controller) experimentalNetworkMap(accountId string) bool { + _, ok := c.expNewNetworkMapAIDs[accountId] + return c.expNewNetworkMap || ok +} + +func (c *Controller) enrichAccountFromHolder(account *types.Account) { + a := c.holder.GetAccount(account.Id) + if a == nil { + c.holder.AddAccount(account) + return + } + account.NetworkMapCache = a.NetworkMapCache + if account.NetworkMapCache == nil { + return + } + account.NetworkMapCache.UpdateAccountPointer(account) + c.holder.AddAccount(account) +} + +func (c *Controller) getAccountFromHolder(accountID string) *types.Account { + return c.holder.GetAccount(accountID) +} + +func (c *Controller) getAccountFromHolderOrInit(accountID string) *types.Account { + a := c.holder.GetAccount(accountID) + if a != nil { + return a + } + account, err := c.holder.LoadOrStoreFunc(accountID, c.requestBuffer.GetAccountWithBackpressure) + if err != nil { + return nil + } + return account +} + +func (c *Controller) updateAccountInHolder(account *types.Account) { + c.holder.AddAccount(account) +} + +// GetDNSDomain returns the configured dnsDomain +func (c *Controller) GetDNSDomain(settings *types.Settings) string { + if settings == nil { + return c.dnsDomain + } + if settings.DNSDomain == "" { + return c.dnsDomain + } + + return settings.DNSDomain +} + +// getPeerPostureChecks returns the posture checks applied for a given peer. +func (c *Controller) getPeerPostureChecks(account *types.Account, peerID string) ([]*posture.Checks, error) { + peerPostureChecks := make(map[string]*posture.Checks) + + if len(account.PostureChecks) == 0 { + return nil, nil + } + + for _, policy := range account.Policies { + if !policy.Enabled || len(policy.SourcePostureChecks) == 0 { + continue + } + + if err := addPolicyPostureChecks(account, peerID, policy, peerPostureChecks); err != nil { + return nil, err + } + } + + return maps.Values(peerPostureChecks), nil +} + +func (c *Controller) StartWarmup(ctx context.Context) { + var initialInterval int64 + intervalStr := os.Getenv("NB_PEER_UPDATE_INTERVAL_MS") + interval, err := strconv.Atoi(intervalStr) + if err != nil { + initialInterval = 1 + log.WithContext(ctx).Warnf("failed to parse peer update interval, using default value %dms: %v", initialInterval, err) + } else { + initialInterval = int64(interval) * 10 + go func() { + startupPeriodStr := os.Getenv("NB_PEER_UPDATE_STARTUP_PERIOD_S") + startupPeriod, err := strconv.Atoi(startupPeriodStr) + if err != nil { + startupPeriod = 1 + log.WithContext(ctx).Warnf("failed to parse peer update startup period, using default value %ds: %v", startupPeriod, err) + } + time.Sleep(time.Duration(startupPeriod) * time.Second) + c.updateAccountPeersBufferInterval.Store(int64(time.Duration(interval) * time.Millisecond)) + log.WithContext(ctx).Infof("set peer update buffer interval to %dms", interval) + }() + } + c.updateAccountPeersBufferInterval.Store(int64(time.Duration(initialInterval) * time.Millisecond)) + log.WithContext(ctx).Infof("set peer update buffer interval to %dms", initialInterval) + +} + +// computeForwarderPort checks if all peers in the account have updated to a specific version or newer. +// If all peers have the required version, it returns the new well-known port (22054), otherwise returns 0. +func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 { + if len(peers) == 0 { + return int64(network_map.OldForwarderPort) + } + + reqVer := semver.Canonical(requiredVersion) + + // Check if all peers have the required version or newer + for _, peer := range peers { + + // Development version is always supported + if peer.Meta.WtVersion == "development" { + continue + } + peerVersion := semver.Canonical("v" + peer.Meta.WtVersion) + if peerVersion == "" { + // If any peer doesn't have version info, return 0 + return int64(network_map.OldForwarderPort) + } + + // Compare versions + if semver.Compare(peerVersion, reqVer) < 0 { + return int64(network_map.OldForwarderPort) + } + } + + // All peers have the required version or newer + return int64(network_map.DnsForwarderPort) +} + +// addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups. +func addPolicyPostureChecks(account *types.Account, peerID string, policy *types.Policy, peerPostureChecks map[string]*posture.Checks) error { + isInGroup, err := isPeerInPolicySourceGroups(account, peerID, policy) + if err != nil { + return err + } + + if !isInGroup { + return nil + } + + for _, sourcePostureCheckID := range policy.SourcePostureChecks { + postureCheck := account.GetPostureChecks(sourcePostureCheckID) + if postureCheck == nil { + return errors.New("failed to add policy posture checks: posture checks not found") + } + peerPostureChecks[sourcePostureCheckID] = postureCheck + } + + return nil +} + +// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups. +func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *types.Policy) (bool, error) { + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + for _, sourceGroup := range rule.Sources { + group := account.GetGroup(sourceGroup) + if group == nil { + return false, fmt.Errorf("failed to check peer in policy source group: group not found") + } + + if slices.Contains(group.Peers, peerID) { + return true, nil + } + } + } + + return false, nil +} + +func (c *Controller) OnPeerUpdated(accountId string, peer *nbpeer.Peer) { + c.UpdatePeerInNetworkMapCache(accountId, peer) + _ = c.bufferSendUpdateAccountPeers(context.Background(), accountId) +} + +func (c *Controller) OnPeerAdded(ctx context.Context, accountID string, peerID string) error { + if c.experimentalNetworkMap(accountID) { + account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return err + } + + err = c.onPeerAddedUpdNetworkMapCache(account, peerID) + if err != nil { + return err + } + } + return c.bufferSendUpdateAccountPeers(ctx, accountID) +} + +func (c *Controller) OnPeerDeleted(ctx context.Context, accountID string, peerID string) error { + if c.experimentalNetworkMap(accountID) { + account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return err + } + err = c.onPeerDeletedUpdNetworkMapCache(account, peerID) + if err != nil { + return err + } + } + + return c.bufferSendUpdateAccountPeers(ctx, accountID) +} + +// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result) +func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) { + account, err := c.repo.GetAccountByPeerID(ctx, peerID) + if err != nil { + return nil, err + } + + peer := account.GetPeer(peerID) + if peer == nil { + return nil, status.Errorf(status.NotFound, "peer with ID %s not found", peerID) + } + + groups := make(map[string][]string) + for groupID, group := range account.Groups { + groups[groupID] = group.Peers + } + + validatedPeers, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) + if err != nil { + return nil, err + } + customZone := account.GetPeersCustomZone(ctx, c.GetDNSDomain(account.Settings)) + + proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peerID, account.Peers) + if err != nil { + log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) + return nil, err + } + + var networkMap *types.NetworkMap + + if c.experimentalNetworkMap(peer.AccountID) { + networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil) + } else { + networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) + } + + proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] + if ok { + networkMap.Merge(proxyNetworkMap) + } + + return networkMap, nil +} + +func (c *Controller) DisconnectPeers(ctx context.Context, peerIDs []string) { + c.peersUpdateManager.CloseChannels(ctx, peerIDs) +} + +func (c *Controller) IsConnected(peerID string) bool { + return c.peersUpdateManager.HasChannel(peerID) +} diff --git a/management/internals/controllers/network_map/controller/controller_test.go b/management/internals/controllers/network_map/controller/controller_test.go new file mode 100644 index 000000000..baaffe677 --- /dev/null +++ b/management/internals/controllers/network_map/controller/controller_test.go @@ -0,0 +1,244 @@ +package controller + +import ( + "context" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/internals/controllers/network_map" + "github.com/netbirdio/netbird/management/server/mock_server" + nbpeer "github.com/netbirdio/netbird/management/server/peer" +) + +func TestComputeForwarderPort(t *testing.T) { + // Test with empty peers list + peers := []*nbpeer.Peer{} + result := computeForwarderPort(peers, "v0.59.0") + if result != int64(network_map.OldForwarderPort) { + t.Errorf("Expected %d for empty peers list, got %d", network_map.OldForwarderPort, result) + } + + // Test with peers that have old versions + peers = []*nbpeer.Peer{ + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.57.0", + }, + }, + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.26.0", + }, + }, + } + result = computeForwarderPort(peers, "v0.59.0") + if result != int64(network_map.OldForwarderPort) { + t.Errorf("Expected %d for peers with old versions, got %d", network_map.OldForwarderPort, result) + } + + // Test with peers that have new versions + peers = []*nbpeer.Peer{ + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.59.0", + }, + }, + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.59.0", + }, + }, + } + result = computeForwarderPort(peers, "v0.59.0") + if result != int64(network_map.DnsForwarderPort) { + t.Errorf("Expected %d for peers with new versions, got %d", network_map.DnsForwarderPort, result) + } + + // Test with peers that have mixed versions + peers = []*nbpeer.Peer{ + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.59.0", + }, + }, + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "0.57.0", + }, + }, + } + result = computeForwarderPort(peers, "v0.59.0") + if result != int64(network_map.OldForwarderPort) { + t.Errorf("Expected %d for peers with mixed versions, got %d", network_map.OldForwarderPort, result) + } + + // Test with peers that have empty version + peers = []*nbpeer.Peer{ + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "", + }, + }, + } + result = computeForwarderPort(peers, "v0.59.0") + if result != int64(network_map.OldForwarderPort) { + t.Errorf("Expected %d for peers with empty version, got %d", network_map.OldForwarderPort, result) + } + + peers = []*nbpeer.Peer{ + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "development", + }, + }, + } + result = computeForwarderPort(peers, "v0.59.0") + if result == int64(network_map.OldForwarderPort) { + t.Errorf("Expected %d for peers with dev version, got %d", network_map.DnsForwarderPort, result) + } + + // Test with peers that have unknown version string + peers = []*nbpeer.Peer{ + { + Meta: nbpeer.PeerSystemMeta{ + WtVersion: "unknown", + }, + }, + } + result = computeForwarderPort(peers, "v0.59.0") + if result != int64(network_map.OldForwarderPort) { + t.Errorf("Expected %d for peers with unknown version, got %d", network_map.OldForwarderPort, result) + } +} + +func TestBufferUpdateAccountPeers(t *testing.T) { + const ( + peersCount = 1000 + updateAccountInterval = 50 * time.Millisecond + ) + + var ( + deletedPeers, updatePeersDeleted, updatePeersRuns atomic.Int32 + uapLastRun, dpLastRun atomic.Int64 + + totalNewRuns, totalOldRuns int + ) + + uap := func(ctx context.Context, accountID string) { + updatePeersDeleted.Store(deletedPeers.Load()) + updatePeersRuns.Add(1) + uapLastRun.Store(time.Now().UnixMilli()) + time.Sleep(100 * time.Millisecond) + } + + t.Run("new approach", func(t *testing.T) { + updatePeersRuns.Store(0) + updatePeersDeleted.Store(0) + deletedPeers.Store(0) + + var mustore sync.Map + bufupd := func(ctx context.Context, accountID string) { + mu, _ := mustore.LoadOrStore(accountID, &bufferUpdate{}) + b := mu.(*bufferUpdate) + + if !b.mu.TryLock() { + b.update.Store(true) + return + } + + if b.next != nil { + b.next.Stop() + } + + go func() { + defer b.mu.Unlock() + uap(ctx, accountID) + if !b.update.Load() { + return + } + b.update.Store(false) + b.next = time.AfterFunc(updateAccountInterval, func() { + uap(ctx, accountID) + }) + }() + } + dp := func(ctx context.Context, accountID, peerID, userID string) error { + deletedPeers.Add(1) + dpLastRun.Store(time.Now().UnixMilli()) + time.Sleep(10 * time.Millisecond) + bufupd(ctx, accountID) + return nil + } + + am := mock_server.MockAccountManager{ + UpdateAccountPeersFunc: uap, + BufferUpdateAccountPeersFunc: bufupd, + DeletePeerFunc: dp, + } + empty := "" + for range peersCount { + //nolint + am.DeletePeer(context.Background(), empty, empty, empty) + } + time.Sleep(100 * time.Millisecond) + + assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted") + assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer") + assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer") + + totalNewRuns = int(updatePeersRuns.Load()) + }) + + t.Run("old approach", func(t *testing.T) { + updatePeersRuns.Store(0) + updatePeersDeleted.Store(0) + deletedPeers.Store(0) + + var mustore sync.Map + bufupd := func(ctx context.Context, accountID string) { + mu, _ := mustore.LoadOrStore(accountID, &sync.Mutex{}) + b := mu.(*sync.Mutex) + + if !b.TryLock() { + return + } + + go func() { + time.Sleep(updateAccountInterval) + b.Unlock() + uap(ctx, accountID) + }() + } + dp := func(ctx context.Context, accountID, peerID, userID string) error { + deletedPeers.Add(1) + dpLastRun.Store(time.Now().UnixMilli()) + time.Sleep(10 * time.Millisecond) + bufupd(ctx, accountID) + return nil + } + + am := mock_server.MockAccountManager{ + UpdateAccountPeersFunc: uap, + BufferUpdateAccountPeersFunc: bufupd, + DeletePeerFunc: dp, + } + empty := "" + for range peersCount { + //nolint + am.DeletePeer(context.Background(), empty, empty, empty) + } + time.Sleep(100 * time.Millisecond) + + assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted") + assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer") + assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer") + + totalOldRuns = int(updatePeersRuns.Load()) + }) + assert.Less(t, totalNewRuns, totalOldRuns, "Expected new approach to run less than old approach. New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns) + t.Logf("New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns) +} diff --git a/management/internals/controllers/network_map/controller/metrics.go b/management/internals/controllers/network_map/controller/metrics.go new file mode 100644 index 000000000..5832d2130 --- /dev/null +++ b/management/internals/controllers/network_map/controller/metrics.go @@ -0,0 +1,15 @@ +package controller + +import ( + "github.com/netbirdio/netbird/management/server/telemetry" +) + +type metrics struct { + *telemetry.UpdateChannelMetrics +} + +func newMetrics(updateChannelMetrics *telemetry.UpdateChannelMetrics) (*metrics, error) { + return &metrics{ + updateChannelMetrics, + }, nil +} diff --git a/management/internals/controllers/network_map/controller/repository.go b/management/internals/controllers/network_map/controller/repository.go new file mode 100644 index 000000000..44144263b --- /dev/null +++ b/management/internals/controllers/network_map/controller/repository.go @@ -0,0 +1,39 @@ +package controller + +import ( + "context" + + "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" +) + +type Repository interface { + GetAccountNetwork(ctx context.Context, accountID string) (*types.Network, error) + GetAccountPeers(ctx context.Context, accountID string) ([]*peer.Peer, error) + GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error) +} + +type repository struct { + store store.Store +} + +var _ Repository = (*repository)(nil) + +func newRepository(s store.Store) Repository { + return &repository{ + store: s, + } +} + +func (r *repository) GetAccountNetwork(ctx context.Context, accountID string) (*types.Network, error) { + return r.store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID) +} + +func (r *repository) GetAccountPeers(ctx context.Context, accountID string) ([]*peer.Peer, error) { + return r.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") +} + +func (r *repository) GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error) { + return r.store.GetAccountByPeerID(ctx, peerID) +} diff --git a/management/internals/controllers/network_map/interface.go b/management/internals/controllers/network_map/interface.go new file mode 100644 index 000000000..6f893ce79 --- /dev/null +++ b/management/internals/controllers/network_map/interface.go @@ -0,0 +1,39 @@ +package network_map + +//go:generate go run go.uber.org/mock/mockgen -package network_map -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod + +import ( + "context" + + nbdns "github.com/netbirdio/netbird/dns" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/types" +) + +const ( + EnvNewNetworkMapBuilder = "NB_EXPERIMENT_NETWORK_MAP" + EnvNewNetworkMapAccounts = "NB_EXPERIMENT_NETWORK_MAP_ACCOUNTS" + + DnsForwarderPort = nbdns.ForwarderServerPort + OldForwarderPort = nbdns.ForwarderClientPort + DnsForwarderPortMinVersion = "v0.59.0" +) + +type Controller interface { + UpdateAccountPeers(ctx context.Context, accountID string) error + UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error + BufferUpdateAccountPeers(ctx context.Context, accountID string) error + GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) + GetDNSDomain(settings *types.Settings) string + StartWarmup(context.Context) + GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) + + DeletePeer(ctx context.Context, accountId string, peerId string) error + + OnPeerUpdated(accountId string, peer *nbpeer.Peer) + OnPeerAdded(ctx context.Context, accountID string, peerID string) error + OnPeerDeleted(ctx context.Context, accountID string, peerID string) error + DisconnectPeers(ctx context.Context, peerIDs []string) + IsConnected(peerID string) bool +} diff --git a/management/internals/controllers/network_map/interface_mock.go b/management/internals/controllers/network_map/interface_mock.go new file mode 100644 index 000000000..aaa093e47 --- /dev/null +++ b/management/internals/controllers/network_map/interface_mock.go @@ -0,0 +1,225 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ./interface.go +// +// Generated by this command: +// +// mockgen -package network_map -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod +// + +// Package network_map is a generated GoMock package. +package network_map + +import ( + context "context" + reflect "reflect" + + peer "github.com/netbirdio/netbird/management/server/peer" + posture "github.com/netbirdio/netbird/management/server/posture" + types "github.com/netbirdio/netbird/management/server/types" + gomock "go.uber.org/mock/gomock" +) + +// MockController is a mock of Controller interface. +type MockController struct { + ctrl *gomock.Controller + recorder *MockControllerMockRecorder + isgomock struct{} +} + +// MockControllerMockRecorder is the mock recorder for MockController. +type MockControllerMockRecorder struct { + mock *MockController +} + +// NewMockController creates a new mock instance. +func NewMockController(ctrl *gomock.Controller) *MockController { + mock := &MockController{ctrl: ctrl} + mock.recorder = &MockControllerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockController) EXPECT() *MockControllerMockRecorder { + return m.recorder +} + +// BufferUpdateAccountPeers mocks base method. +func (m *MockController) BufferUpdateAccountPeers(ctx context.Context, accountID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BufferUpdateAccountPeers", ctx, accountID) + ret0, _ := ret[0].(error) + return ret0 +} + +// BufferUpdateAccountPeers indicates an expected call of BufferUpdateAccountPeers. +func (mr *MockControllerMockRecorder) BufferUpdateAccountPeers(ctx, accountID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAccountPeers), ctx, accountID) +} + +// DeletePeer mocks base method. +func (m *MockController) DeletePeer(ctx context.Context, accountId, peerId string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeletePeer", ctx, accountId, peerId) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeletePeer indicates an expected call of DeletePeer. +func (mr *MockControllerMockRecorder) DeletePeer(ctx, accountId, peerId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePeer", reflect.TypeOf((*MockController)(nil).DeletePeer), ctx, accountId, peerId) +} + +// DisconnectPeers mocks base method. +func (m *MockController) DisconnectPeers(ctx context.Context, peerIDs []string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "DisconnectPeers", ctx, peerIDs) +} + +// DisconnectPeers indicates an expected call of DisconnectPeers. +func (mr *MockControllerMockRecorder) DisconnectPeers(ctx, peerIDs any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectPeers", reflect.TypeOf((*MockController)(nil).DisconnectPeers), ctx, peerIDs) +} + +// GetDNSDomain mocks base method. +func (m *MockController) GetDNSDomain(settings *types.Settings) string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetDNSDomain", settings) + ret0, _ := ret[0].(string) + return ret0 +} + +// GetDNSDomain indicates an expected call of GetDNSDomain. +func (mr *MockControllerMockRecorder) GetDNSDomain(settings any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDNSDomain", reflect.TypeOf((*MockController)(nil).GetDNSDomain), settings) +} + +// GetNetworkMap mocks base method. +func (m *MockController) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetNetworkMap", ctx, peerID) + ret0, _ := ret[0].(*types.NetworkMap) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetNetworkMap indicates an expected call of GetNetworkMap. +func (mr *MockControllerMockRecorder) GetNetworkMap(ctx, peerID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNetworkMap", reflect.TypeOf((*MockController)(nil).GetNetworkMap), ctx, peerID) +} + +// GetValidatedPeerWithMap mocks base method. +func (m *MockController) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *peer.Peer) (*peer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetValidatedPeerWithMap", ctx, isRequiresApproval, accountID, p) + ret0, _ := ret[0].(*peer.Peer) + ret1, _ := ret[1].(*types.NetworkMap) + ret2, _ := ret[2].([]*posture.Checks) + ret3, _ := ret[3].(int64) + ret4, _ := ret[4].(error) + return ret0, ret1, ret2, ret3, ret4 +} + +// GetValidatedPeerWithMap indicates an expected call of GetValidatedPeerWithMap. +func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, p any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p) +} + +// IsConnected mocks base method. +func (m *MockController) IsConnected(peerID string) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsConnected", peerID) + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsConnected indicates an expected call of IsConnected. +func (mr *MockControllerMockRecorder) IsConnected(peerID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsConnected", reflect.TypeOf((*MockController)(nil).IsConnected), peerID) +} + +// OnPeerAdded mocks base method. +func (m *MockController) OnPeerAdded(ctx context.Context, accountID, peerID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OnPeerAdded", ctx, accountID, peerID) + ret0, _ := ret[0].(error) + return ret0 +} + +// OnPeerAdded indicates an expected call of OnPeerAdded. +func (mr *MockControllerMockRecorder) OnPeerAdded(ctx, accountID, peerID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerAdded", reflect.TypeOf((*MockController)(nil).OnPeerAdded), ctx, accountID, peerID) +} + +// OnPeerDeleted mocks base method. +func (m *MockController) OnPeerDeleted(ctx context.Context, accountID, peerID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OnPeerDeleted", ctx, accountID, peerID) + ret0, _ := ret[0].(error) + return ret0 +} + +// OnPeerDeleted indicates an expected call of OnPeerDeleted. +func (mr *MockControllerMockRecorder) OnPeerDeleted(ctx, accountID, peerID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerDeleted", reflect.TypeOf((*MockController)(nil).OnPeerDeleted), ctx, accountID, peerID) +} + +// OnPeerUpdated mocks base method. +func (m *MockController) OnPeerUpdated(accountId string, peer *peer.Peer) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "OnPeerUpdated", accountId, peer) +} + +// OnPeerUpdated indicates an expected call of OnPeerUpdated. +func (mr *MockControllerMockRecorder) OnPeerUpdated(accountId, peer any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerUpdated", reflect.TypeOf((*MockController)(nil).OnPeerUpdated), accountId, peer) +} + +// StartWarmup mocks base method. +func (m *MockController) StartWarmup(arg0 context.Context) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "StartWarmup", arg0) +} + +// StartWarmup indicates an expected call of StartWarmup. +func (mr *MockControllerMockRecorder) StartWarmup(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartWarmup", reflect.TypeOf((*MockController)(nil).StartWarmup), arg0) +} + +// UpdateAccountPeer mocks base method. +func (m *MockController) UpdateAccountPeer(ctx context.Context, accountId, peerId string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateAccountPeer", ctx, accountId, peerId) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateAccountPeer indicates an expected call of UpdateAccountPeer. +func (mr *MockControllerMockRecorder) UpdateAccountPeer(ctx, accountId, peerId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeer", reflect.TypeOf((*MockController)(nil).UpdateAccountPeer), ctx, accountId, peerId) +} + +// UpdateAccountPeers mocks base method. +func (m *MockController) UpdateAccountPeers(ctx context.Context, accountID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateAccountPeers", ctx, accountID) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateAccountPeers indicates an expected call of UpdateAccountPeers. +func (mr *MockControllerMockRecorder) UpdateAccountPeers(ctx, accountID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockController)(nil).UpdateAccountPeers), ctx, accountID) +} diff --git a/management/internals/controllers/network_map/network_map.go b/management/internals/controllers/network_map/network_map.go new file mode 100644 index 000000000..e915c2193 --- /dev/null +++ b/management/internals/controllers/network_map/network_map.go @@ -0,0 +1 @@ +package network_map diff --git a/management/internals/controllers/network_map/update_channel.go b/management/internals/controllers/network_map/update_channel.go new file mode 100644 index 000000000..0b085b85f --- /dev/null +++ b/management/internals/controllers/network_map/update_channel.go @@ -0,0 +1,13 @@ +package network_map + +import "context" + +type PeersUpdateManager interface { + SendUpdate(ctx context.Context, peerID string, update *UpdateMessage) + CreateChannel(ctx context.Context, peerID string) chan *UpdateMessage + CloseChannel(ctx context.Context, peerID string) + CountStreams() int + HasChannel(peerID string) bool + CloseChannels(ctx context.Context, peerIDs []string) + GetAllConnectedPeers() map[string]struct{} +} diff --git a/management/server/updatechannel.go b/management/internals/controllers/network_map/update_channel/updatechannel.go similarity index 87% rename from management/server/updatechannel.go rename to management/internals/controllers/network_map/update_channel/updatechannel.go index adf64592a..5f7db5300 100644 --- a/management/server/updatechannel.go +++ b/management/internals/controllers/network_map/update_channel/updatechannel.go @@ -1,4 +1,4 @@ -package server +package update_channel import ( "context" @@ -7,36 +7,34 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/server/telemetry" - "github.com/netbirdio/netbird/shared/management/proto" ) const channelBufferSize = 100 -type UpdateMessage struct { - Update *proto.SyncResponse -} - type PeersUpdateManager struct { // peerChannels is an update channel indexed by Peer.ID - peerChannels map[string]chan *UpdateMessage + peerChannels map[string]chan *network_map.UpdateMessage // channelsMux keeps the mutex to access peerChannels channelsMux *sync.RWMutex // metrics provides method to collect application metrics metrics telemetry.AppMetrics } +var _ network_map.PeersUpdateManager = (*PeersUpdateManager)(nil) + // NewPeersUpdateManager returns a new instance of PeersUpdateManager func NewPeersUpdateManager(metrics telemetry.AppMetrics) *PeersUpdateManager { return &PeersUpdateManager{ - peerChannels: make(map[string]chan *UpdateMessage), + peerChannels: make(map[string]chan *network_map.UpdateMessage), channelsMux: &sync.RWMutex{}, metrics: metrics, } } // SendUpdate sends update message to the peer's channel -func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, update *UpdateMessage) { +func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, update *network_map.UpdateMessage) { start := time.Now() var found, dropped bool @@ -64,7 +62,7 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda } // CreateChannel creates a go channel for a given peer used to deliver updates relevant to the peer. -func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) chan *UpdateMessage { +func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) chan *network_map.UpdateMessage { start := time.Now() closed := false @@ -83,7 +81,7 @@ func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) c close(channel) } // mbragin: todo shouldn't it be more? or configurable? - channel := make(chan *UpdateMessage, channelBufferSize) + channel := make(chan *network_map.UpdateMessage, channelBufferSize) p.peerChannels[peerID] = channel log.WithContext(ctx).Debugf("opened updates channel for a peer %s", peerID) @@ -174,3 +172,9 @@ func (p *PeersUpdateManager) HasChannel(peerID string) bool { return ok } + +func (p *PeersUpdateManager) CountStreams() int { + p.channelsMux.RLock() + defer p.channelsMux.RUnlock() + return len(p.peerChannels) +} diff --git a/management/server/updatechannel_test.go b/management/internals/controllers/network_map/update_channel/updatechannel_test.go similarity index 89% rename from management/server/updatechannel_test.go rename to management/internals/controllers/network_map/update_channel/updatechannel_test.go index 0dc86563d..afc1e2c32 100644 --- a/management/server/updatechannel_test.go +++ b/management/internals/controllers/network_map/update_channel/updatechannel_test.go @@ -1,10 +1,11 @@ -package server +package update_channel import ( "context" "testing" "time" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/shared/management/proto" ) @@ -24,7 +25,7 @@ func TestCreateChannel(t *testing.T) { func TestSendUpdate(t *testing.T) { peer := "test-sendupdate" peersUpdater := NewPeersUpdateManager(nil) - update1 := &UpdateMessage{Update: &proto.SyncResponse{ + update1 := &network_map.UpdateMessage{Update: &proto.SyncResponse{ NetworkMap: &proto.NetworkMap{ Serial: 0, }, @@ -44,7 +45,7 @@ func TestSendUpdate(t *testing.T) { peersUpdater.SendUpdate(context.Background(), peer, update1) } - update2 := &UpdateMessage{Update: &proto.SyncResponse{ + update2 := &network_map.UpdateMessage{Update: &proto.SyncResponse{ NetworkMap: &proto.NetworkMap{ Serial: 10, }, diff --git a/management/internals/controllers/network_map/update_message.go b/management/internals/controllers/network_map/update_message.go new file mode 100644 index 000000000..33643bcbd --- /dev/null +++ b/management/internals/controllers/network_map/update_message.go @@ -0,0 +1,9 @@ +package network_map + +import ( + "github.com/netbirdio/netbird/shared/management/proto" +) + +type UpdateMessage struct { + Update *proto.SyncResponse +} diff --git a/management/internals/server/boot.go b/management/internals/server/boot.go index 16e93a549..eadd16c2d 100644 --- a/management/internals/server/boot.go +++ b/management/internals/server/boot.go @@ -22,7 +22,7 @@ import ( "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/formatter/hook" nbconfig "github.com/netbirdio/netbird/management/internals/server/config" - "github.com/netbirdio/netbird/management/server" + nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/activity" nbContext "github.com/netbirdio/netbird/management/server/context" nbhttp "github.com/netbirdio/netbird/management/server/http" @@ -93,7 +93,7 @@ func (s *BaseServer) EventStore() activity.Store { func (s *BaseServer) APIHandler() http.Handler { return Create(s, func() http.Handler { - httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager()) + httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.NetworkMapController()) if err != nil { log.Fatalf("failed to create API handler: %v", err) } @@ -145,7 +145,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server { } gRPCAPIHandler := grpc.NewServer(gRPCOpts...) - srv, err := server.NewServer(context.Background(), s.config, s.AccountManager(), s.SettingsManager(), s.PeersUpdateManager(), s.SecretsManager(), s.Metrics(), s.EphemeralManager(), s.AuthManager(), s.IntegratedValidator()) + srv, err := nbgrpc.NewServer(s.config, s.AccountManager(), s.SettingsManager(), s.PeersUpdateManager(), s.SecretsManager(), s.Metrics(), s.EphemeralManager(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController()) if err != nil { log.Fatalf("failed to create management server: %v", err) } diff --git a/management/internals/server/controllers.go b/management/internals/server/controllers.go index ddd81daa2..b61e33688 100644 --- a/management/internals/server/controllers.go +++ b/management/internals/server/controllers.go @@ -6,6 +6,10 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" + nmapcontroller "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/auth" "github.com/netbirdio/netbird/management/server/integrations/integrated_validator" @@ -14,9 +18,9 @@ import ( "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" ) -func (s *BaseServer) PeersUpdateManager() *server.PeersUpdateManager { - return Create(s, func() *server.PeersUpdateManager { - return server.NewPeersUpdateManager(s.Metrics()) +func (s *BaseServer) PeersUpdateManager() network_map.PeersUpdateManager { + return Create(s, func() *update_channel.PeersUpdateManager { + return update_channel.NewPeersUpdateManager(s.Metrics()) }) } @@ -40,9 +44,9 @@ func (s *BaseServer) ProxyController() port_forwarding.Controller { }) } -func (s *BaseServer) SecretsManager() *server.TimeBasedAuthSecretsManager { - return Create(s, func() *server.TimeBasedAuthSecretsManager { - return server.NewTimeBasedAuthSecretsManager(s.PeersUpdateManager(), s.config.TURNConfig, s.config.Relay, s.SettingsManager(), s.GroupsManager()) +func (s *BaseServer) SecretsManager() *grpc.TimeBasedAuthSecretsManager { + return Create(s, func() *grpc.TimeBasedAuthSecretsManager { + return grpc.NewTimeBasedAuthSecretsManager(s.PeersUpdateManager(), s.config.TURNConfig, s.config.Relay, s.SettingsManager(), s.GroupsManager()) }) } @@ -63,3 +67,15 @@ func (s *BaseServer) EphemeralManager() ephemeral.Manager { return manager.NewEphemeralManager(s.Store(), s.AccountManager()) }) } + +func (s *BaseServer) NetworkMapController() network_map.Controller { + return Create(s, func() *nmapcontroller.Controller { + return nmapcontroller.NewController(context.Background(), s.Store(), s.Metrics(), s.PeersUpdateManager(), s.AccountRequestBuffer(), s.IntegratedValidator(), s.SettingsManager(), s.dnsDomain, s.ProxyController()) + }) +} + +func (s *BaseServer) AccountRequestBuffer() *server.AccountRequestBuffer { + return Create(s, func() *server.AccountRequestBuffer { + return server.NewAccountRequestBuffer(context.Background(), s.Store()) + }) +} diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index 209a20065..409bdaaba 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -66,8 +66,7 @@ func (s *BaseServer) PeersManager() peers.Manager { func (s *BaseServer) AccountManager() account.Manager { return Create(s, func() account.Manager { - accountManager, err := server.BuildManager(context.Background(), s.Store(), s.PeersUpdateManager(), s.IdpManager(), s.mgmtSingleAccModeDomain, - s.dnsDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.config.DisableDefaultPolicy) + accountManager, err := server.BuildManager(context.Background(), s.Store(), s.NetworkMapController(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.config.DisableDefaultPolicy) if err != nil { log.Fatalf("failed to create account manager: %v", err) } diff --git a/management/internals/shared/grpc/conversion.go b/management/internals/shared/grpc/conversion.go new file mode 100644 index 000000000..9a4681eae --- /dev/null +++ b/management/internals/shared/grpc/conversion.go @@ -0,0 +1,352 @@ +package grpc + +import ( + "context" + "fmt" + + integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache" + nbconfig "github.com/netbirdio/netbird/management/internals/server/config" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/proto" +) + +func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig { + if config == nil { + return nil + } + + var stuns []*proto.HostConfig + for _, stun := range config.Stuns { + stuns = append(stuns, &proto.HostConfig{ + Uri: stun.URI, + Protocol: ToResponseProto(stun.Proto), + }) + } + + var turns []*proto.ProtectedHostConfig + if config.TURNConfig != nil { + for _, turn := range config.TURNConfig.Turns { + var username string + var password string + if turnCredentials != nil { + username = turnCredentials.Payload + password = turnCredentials.Signature + } else { + username = turn.Username + password = turn.Password + } + turns = append(turns, &proto.ProtectedHostConfig{ + HostConfig: &proto.HostConfig{ + Uri: turn.URI, + Protocol: ToResponseProto(turn.Proto), + }, + User: username, + Password: password, + }) + } + } + + var relayCfg *proto.RelayConfig + if config.Relay != nil && len(config.Relay.Addresses) > 0 { + relayCfg = &proto.RelayConfig{ + Urls: config.Relay.Addresses, + } + + if relayToken != nil { + relayCfg.TokenPayload = relayToken.Payload + relayCfg.TokenSignature = relayToken.Signature + } + } + + var signalCfg *proto.HostConfig + if config.Signal != nil { + signalCfg = &proto.HostConfig{ + Uri: config.Signal.URI, + Protocol: ToResponseProto(config.Signal.Proto), + } + } + + nbConfig := &proto.NetbirdConfig{ + Stuns: stuns, + Turns: turns, + Signal: signalCfg, + Relay: relayCfg, + } + + return nbConfig +} + +func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings) *proto.PeerConfig { + netmask, _ := network.Net.Mask.Size() + fqdn := peer.FQDN(dnsName) + return &proto.PeerConfig{ + Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask), // take it from the network + SshConfig: &proto.SSHConfig{SshEnabled: peer.SSHEnabled}, + Fqdn: fqdn, + RoutingPeerDnsResolutionEnabled: settings.RoutingPeerDNSResolutionEnabled, + LazyConnectionEnabled: settings.LazyConnectionEnabled, + } +} + +func ToSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *cache.DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse { + response := &proto.SyncResponse{ + PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings), + NetworkMap: &proto.NetworkMap{ + Serial: networkMap.Network.CurrentSerial(), + Routes: toProtocolRoutes(networkMap.Routes), + DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort), + }, + Checks: toProtocolChecks(ctx, checks), + } + + nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings) + extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings) + response.NetbirdConfig = extendedConfig + + response.NetworkMap.PeerConfig = response.PeerConfig + + remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers)) + remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName) + response.RemotePeers = remotePeers + response.NetworkMap.RemotePeers = remotePeers + response.RemotePeersIsEmpty = len(remotePeers) == 0 + response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty + + response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName) + + firewallRules := toProtocolFirewallRules(networkMap.FirewallRules) + response.NetworkMap.FirewallRules = firewallRules + response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0 + + routesFirewallRules := toProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules) + response.NetworkMap.RoutesFirewallRules = routesFirewallRules + response.NetworkMap.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0 + + if networkMap.ForwardingRules != nil { + forwardingRules := make([]*proto.ForwardingRule, 0, len(networkMap.ForwardingRules)) + for _, rule := range networkMap.ForwardingRules { + forwardingRules = append(forwardingRules, rule.ToProto()) + } + response.NetworkMap.ForwardingRules = forwardingRules + } + + return response +} + +func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig { + for _, rPeer := range peers { + dst = append(dst, &proto.RemotePeerConfig{ + WgPubKey: rPeer.Key, + AllowedIps: []string{rPeer.IP.String() + "/32"}, + SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)}, + Fqdn: rPeer.FQDN(dnsName), + AgentVersion: rPeer.Meta.WtVersion, + }) + } + return dst +} + +// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache +func toProtocolDNSConfig(update nbdns.Config, cache *cache.DNSConfigCache, forwardPort int64) *proto.DNSConfig { + protoUpdate := &proto.DNSConfig{ + ServiceEnable: update.ServiceEnable, + CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)), + NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)), + ForwarderPort: forwardPort, + } + + for _, zone := range update.CustomZones { + protoZone := convertToProtoCustomZone(zone) + protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone) + } + + for _, nsGroup := range update.NameServerGroups { + cacheKey := nsGroup.ID + if cachedGroup, exists := cache.GetNameServerGroup(cacheKey); exists { + protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup) + } else { + protoGroup := convertToProtoNameServerGroup(nsGroup) + cache.SetNameServerGroup(cacheKey, protoGroup) + protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup) + } + } + + return protoUpdate +} + +func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol { + switch configProto { + case nbconfig.UDP: + return proto.HostConfig_UDP + case nbconfig.DTLS: + return proto.HostConfig_DTLS + case nbconfig.HTTP: + return proto.HostConfig_HTTP + case nbconfig.HTTPS: + return proto.HostConfig_HTTPS + case nbconfig.TCP: + return proto.HostConfig_TCP + default: + panic(fmt.Errorf("unexpected config protocol type %v", configProto)) + } +} + +func toProtocolRoutes(routes []*route.Route) []*proto.Route { + protoRoutes := make([]*proto.Route, 0, len(routes)) + for _, r := range routes { + protoRoutes = append(protoRoutes, toProtocolRoute(r)) + } + return protoRoutes +} + +func toProtocolRoute(route *route.Route) *proto.Route { + return &proto.Route{ + ID: string(route.ID), + NetID: string(route.NetID), + Network: route.Network.String(), + Domains: route.Domains.ToPunycodeList(), + NetworkType: int64(route.NetworkType), + Peer: route.Peer, + Metric: int64(route.Metric), + Masquerade: route.Masquerade, + KeepRoute: route.KeepRoute, + SkipAutoApply: route.SkipAutoApply, + } +} + +// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules. +func toProtocolFirewallRules(rules []*types.FirewallRule) []*proto.FirewallRule { + result := make([]*proto.FirewallRule, len(rules)) + for i := range rules { + rule := rules[i] + + fwRule := &proto.FirewallRule{ + PolicyID: []byte(rule.PolicyID), + PeerIP: rule.PeerIP, + Direction: getProtoDirection(rule.Direction), + Action: getProtoAction(rule.Action), + Protocol: getProtoProtocol(rule.Protocol), + Port: rule.Port, + } + + if shouldUsePortRange(fwRule) { + fwRule.PortInfo = rule.PortRange.ToProto() + } + + result[i] = fwRule + } + return result +} + +// getProtoDirection converts the direction to proto.RuleDirection. +func getProtoDirection(direction int) proto.RuleDirection { + if direction == types.FirewallRuleDirectionOUT { + return proto.RuleDirection_OUT + } + return proto.RuleDirection_IN +} + +func toProtocolRoutesFirewallRules(rules []*types.RouteFirewallRule) []*proto.RouteFirewallRule { + result := make([]*proto.RouteFirewallRule, len(rules)) + for i := range rules { + rule := rules[i] + result[i] = &proto.RouteFirewallRule{ + SourceRanges: rule.SourceRanges, + Action: getProtoAction(rule.Action), + Destination: rule.Destination, + Protocol: getProtoProtocol(rule.Protocol), + PortInfo: getProtoPortInfo(rule), + IsDynamic: rule.IsDynamic, + Domains: rule.Domains.ToPunycodeList(), + PolicyID: []byte(rule.PolicyID), + RouteID: string(rule.RouteID), + } + } + + return result +} + +// getProtoAction converts the action to proto.RuleAction. +func getProtoAction(action string) proto.RuleAction { + if action == string(types.PolicyTrafficActionDrop) { + return proto.RuleAction_DROP + } + return proto.RuleAction_ACCEPT +} + +// getProtoProtocol converts the protocol to proto.RuleProtocol. +func getProtoProtocol(protocol string) proto.RuleProtocol { + switch types.PolicyRuleProtocolType(protocol) { + case types.PolicyRuleProtocolALL: + return proto.RuleProtocol_ALL + case types.PolicyRuleProtocolTCP: + return proto.RuleProtocol_TCP + case types.PolicyRuleProtocolUDP: + return proto.RuleProtocol_UDP + case types.PolicyRuleProtocolICMP: + return proto.RuleProtocol_ICMP + default: + return proto.RuleProtocol_UNKNOWN + } +} + +// getProtoPortInfo converts the port info to proto.PortInfo. +func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo { + var portInfo proto.PortInfo + if rule.Port != 0 { + portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)} + } else if portRange := rule.PortRange; portRange.Start != 0 && portRange.End != 0 { + portInfo.PortSelection = &proto.PortInfo_Range_{ + Range: &proto.PortInfo_Range{ + Start: uint32(portRange.Start), + End: uint32(portRange.End), + }, + } + } + return &portInfo +} + +func shouldUsePortRange(rule *proto.FirewallRule) bool { + return rule.Port == "" && (rule.Protocol == proto.RuleProtocol_UDP || rule.Protocol == proto.RuleProtocol_TCP) +} + +// Helper function to convert nbdns.CustomZone to proto.CustomZone +func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone { + protoZone := &proto.CustomZone{ + Domain: zone.Domain, + Records: make([]*proto.SimpleRecord, 0, len(zone.Records)), + } + for _, record := range zone.Records { + protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{ + Name: record.Name, + Type: int64(record.Type), + Class: record.Class, + TTL: int64(record.TTL), + RData: record.RData, + }) + } + return protoZone +} + +// Helper function to convert nbdns.NameServerGroup to proto.NameServerGroup +func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup { + protoGroup := &proto.NameServerGroup{ + Primary: nsGroup.Primary, + Domains: nsGroup.Domains, + SearchDomainsEnabled: nsGroup.SearchDomainsEnabled, + NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)), + } + for _, ns := range nsGroup.NameServers { + protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{ + IP: ns.IP.String(), + Port: int64(ns.Port), + NSType: int64(ns.NSType), + }) + } + return protoGroup +} diff --git a/management/internals/shared/grpc/conversion_test.go b/management/internals/shared/grpc/conversion_test.go new file mode 100644 index 000000000..701271345 --- /dev/null +++ b/management/internals/shared/grpc/conversion_test.go @@ -0,0 +1,150 @@ +package grpc + +import ( + "fmt" + "net/netip" + "reflect" + "testing" + + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache" +) + +func TestToProtocolDNSConfigWithCache(t *testing.T) { + var cache cache.DNSConfigCache + + // Create two different configs + config1 := nbdns.Config{ + ServiceEnable: true, + CustomZones: []nbdns.CustomZone{ + { + Domain: "example.com", + Records: []nbdns.SimpleRecord{ + {Name: "www", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.1"}, + }, + }, + }, + NameServerGroups: []*nbdns.NameServerGroup{ + { + ID: "group1", + Name: "Group 1", + NameServers: []nbdns.NameServer{ + {IP: netip.MustParseAddr("8.8.8.8"), Port: 53}, + }, + }, + }, + } + + config2 := nbdns.Config{ + ServiceEnable: true, + CustomZones: []nbdns.CustomZone{ + { + Domain: "example.org", + Records: []nbdns.SimpleRecord{ + {Name: "mail", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.2"}, + }, + }, + }, + NameServerGroups: []*nbdns.NameServerGroup{ + { + ID: "group2", + Name: "Group 2", + NameServers: []nbdns.NameServer{ + {IP: netip.MustParseAddr("8.8.4.4"), Port: 53}, + }, + }, + }, + } + + // First run with config1 + result1 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort)) + + // Second run with config2 + result2 := toProtocolDNSConfig(config2, &cache, int64(network_map.DnsForwarderPort)) + + // Third run with config1 again + result3 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort)) + + // Verify that result1 and result3 are identical + if !reflect.DeepEqual(result1, result3) { + t.Errorf("Results are not identical when run with the same input. Expected %v, got %v", result1, result3) + } + + // Verify that result2 is different from result1 and result3 + if reflect.DeepEqual(result1, result2) || reflect.DeepEqual(result2, result3) { + t.Errorf("Results should be different for different inputs") + } + + if _, exists := cache.GetNameServerGroup("group1"); !exists { + t.Errorf("Cache should contain name server group 'group1'") + } + + if _, exists := cache.GetNameServerGroup("group2"); !exists { + t.Errorf("Cache should contain name server group 'group2'") + } +} + +func BenchmarkToProtocolDNSConfig(b *testing.B) { + sizes := []int{10, 100, 1000} + + for _, size := range sizes { + testData := generateTestData(size) + + b.Run(fmt.Sprintf("WithCache-Size%d", size), func(b *testing.B) { + cache := &cache.DNSConfigCache{} + + b.ResetTimer() + for i := 0; i < b.N; i++ { + toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort)) + } + }) + + b.Run(fmt.Sprintf("WithoutCache-Size%d", size), func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + cache := &cache.DNSConfigCache{} + toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort)) + } + }) + } +} + +func generateTestData(size int) nbdns.Config { + config := nbdns.Config{ + ServiceEnable: true, + CustomZones: make([]nbdns.CustomZone, size), + NameServerGroups: make([]*nbdns.NameServerGroup, size), + } + + for i := 0; i < size; i++ { + config.CustomZones[i] = nbdns.CustomZone{ + Domain: fmt.Sprintf("domain%d.com", i), + Records: []nbdns.SimpleRecord{ + { + Name: fmt.Sprintf("record%d", i), + Type: 1, + Class: "IN", + TTL: 3600, + RData: "192.168.1.1", + }, + }, + } + + config.NameServerGroups[i] = &nbdns.NameServerGroup{ + ID: fmt.Sprintf("group%d", i), + Primary: i == 0, + Domains: []string{fmt.Sprintf("domain%d.com", i)}, + SearchDomainsEnabled: true, + NameServers: []nbdns.NameServer{ + { + IP: netip.MustParseAddr("8.8.8.8"), + Port: 53, + NSType: 1, + }, + }, + } + } + + return config +} diff --git a/management/server/loginfilter.go b/management/internals/shared/grpc/loginfilter.go similarity index 99% rename from management/server/loginfilter.go rename to management/internals/shared/grpc/loginfilter.go index 8604af6e2..59f69dd90 100644 --- a/management/server/loginfilter.go +++ b/management/internals/shared/grpc/loginfilter.go @@ -1,4 +1,4 @@ -package server +package grpc import ( "hash/fnv" diff --git a/management/server/loginfilter_test.go b/management/internals/shared/grpc/loginfilter_test.go similarity index 99% rename from management/server/loginfilter_test.go rename to management/internals/shared/grpc/loginfilter_test.go index 65782dd9d..8b26e14ab 100644 --- a/management/server/loginfilter_test.go +++ b/management/internals/shared/grpc/loginfilter_test.go @@ -1,4 +1,4 @@ -package server +package grpc import ( "hash/fnv" diff --git a/management/server/grpcserver.go b/management/internals/shared/grpc/server.go similarity index 76% rename from management/server/grpcserver.go rename to management/internals/shared/grpc/server.go index d3d94443a..08a840316 100644 --- a/management/server/grpcserver.go +++ b/management/internals/shared/grpc/server.go @@ -1,4 +1,4 @@ -package server +package grpc import ( "context" @@ -22,7 +22,7 @@ import ( "google.golang.org/grpc/peer" "google.golang.org/grpc/status" - integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" nbconfig "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/peers/ephemeral" @@ -51,13 +51,13 @@ const ( defaultSyncLim = 1000 ) -// GRPCServer an instance of a Management gRPC API server -type GRPCServer struct { +// Server an instance of a Management gRPC API server +type Server struct { accountManager account.Manager settingsManager settings.Manager wgKey wgtypes.Key proto.UnimplementedManagementServiceServer - peersUpdateManager *PeersUpdateManager + peersUpdateManager network_map.PeersUpdateManager config *nbconfig.Config secretsManager SecretsManager appMetrics telemetry.AppMetrics @@ -69,23 +69,27 @@ type GRPCServer struct { blockPeersWithSameConfig bool integratedPeerValidator integrated_validator.IntegratedValidator + loginFilter *loginFilter + + networkMapController network_map.Controller + syncSem atomic.Int32 syncLim int32 } // NewServer creates a new Management server func NewServer( - ctx context.Context, config *nbconfig.Config, accountManager account.Manager, settingsManager settings.Manager, - peersUpdateManager *PeersUpdateManager, + peersUpdateManager network_map.PeersUpdateManager, secretsManager SecretsManager, appMetrics telemetry.AppMetrics, ephemeralManager ephemeral.Manager, authManager auth.Manager, integratedPeerValidator integrated_validator.IntegratedValidator, -) (*GRPCServer, error) { + networkMapController network_map.Controller, +) (*Server, error) { key, err := wgtypes.GeneratePrivateKey() if err != nil { return nil, err @@ -94,7 +98,7 @@ func NewServer( if appMetrics != nil { // update gauge based on number of connected peers which is equal to open gRPC streams err = appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 { - return int64(len(peersUpdateManager.peerChannels)) + return int64(peersUpdateManager.CountStreams()) }) if err != nil { return nil, err @@ -115,7 +119,7 @@ func NewServer( } } - return &GRPCServer{ + return &Server{ wgKey: key, // peerKey -> event channel peersUpdateManager: peersUpdateManager, @@ -129,12 +133,15 @@ func NewServer( logBlockedPeers: logBlockedPeers, blockPeersWithSameConfig: blockPeersWithSameConfig, integratedPeerValidator: integratedPeerValidator, + networkMapController: networkMapController, + + loginFilter: newLoginFilter(), syncLim: syncLim, }, nil } -func (s *GRPCServer) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.ServerKeyResponse, error) { +func (s *Server) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.ServerKeyResponse, error) { ip := "" p, ok := peer.FromContext(ctx) if ok { @@ -171,7 +178,7 @@ func getRealIP(ctx context.Context) net.IP { // Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and // notifies the connected peer of any updates (e.g. new peers under the same account) -func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error { +func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error { if s.syncSem.Load() >= s.syncLim { return status.Errorf(codes.ResourceExhausted, "too many concurrent sync requests, please try again later") } @@ -191,7 +198,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi sRealIP := realIP.String() peerMeta := extractPeerMeta(ctx, syncReq.GetMeta()) metahashed := metaHash(peerMeta, sRealIP) - if !s.accountManager.AllowSync(peerKey.String(), metahashed) { + if !s.loginFilter.allowLogin(peerKey.String(), metahashed) { if s.appMetrics != nil { s.appMetrics.GRPCMetrics().CountSyncRequestBlocked() } @@ -245,35 +252,29 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP) } - peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP) + metahash := metaHash(peerMeta, realIP.String()) + s.loginFilter.addLogin(peerKey.String(), metahash) + + peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP) if err != nil { log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err) s.syncSem.Add(-1) return mapError(ctx, err) } - log.WithContext(ctx).Debugf("Sync: SyncAndMarkPeer since start %v", time.Since(reqStart)) - - err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv) + err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv, dnsFwdPort) if err != nil { log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err) s.syncSem.Add(-1) return err } - log.WithContext(ctx).Debugf("Sync: sendInitialSync since start %v", time.Since(reqStart)) updates := s.peersUpdateManager.CreateChannel(ctx, peer.ID) - log.WithContext(ctx).Debugf("Sync: CreateChannel since start %v", time.Since(reqStart)) - s.ephemeralManager.OnPeerConnected(ctx, peer) - log.WithContext(ctx).Debugf("Sync: OnPeerConnected since start %v", time.Since(reqStart)) - s.secretsManager.SetupRefresh(ctx, accountID, peer.ID) - log.WithContext(ctx).Debugf("Sync: SetupRefresh since start %v", time.Since(reqStart)) - if s.appMetrics != nil { s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart), accountID) } @@ -281,15 +282,13 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi unlock() unlock = nil - log.WithContext(ctx).Debugf("Sync: took %v", time.Since(reqStart)) - s.syncSem.Add(-1) return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv) } // handleUpdates sends updates to the connected peer until the updates channel is closed. -func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error { +func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error { log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String()) for { select { @@ -323,7 +322,7 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKe // sendUpdate encrypts the update message using the peer key and the server's wireguard key, // then sends the encrypted message to the connected peer via the sync server. -func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error { +func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error { encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update) if err != nil { s.cancelPeerRoutines(ctx, accountID, peer) @@ -341,7 +340,7 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey w return nil } -func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) { +func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) { unlock := s.acquirePeerLockByUID(ctx, peer.Key) defer unlock() @@ -356,7 +355,7 @@ func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, p log.WithContext(ctx).Tracef("peer %s has been disconnected", peer.Key) } -func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string, error) { +func (s *Server) validateToken(ctx context.Context, jwtToken string) (string, error) { if s.authManager == nil { return "", status.Errorf(codes.Internal, "missing auth manager") } @@ -390,7 +389,7 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string return userAuth.UserId, nil } -func (s *GRPCServer) acquirePeerLockByUID(ctx context.Context, uniqueID string) (unlock func()) { +func (s *Server) acquirePeerLockByUID(ctx context.Context, uniqueID string) (unlock func()) { log.WithContext(ctx).Tracef("acquiring peer lock for ID %s", uniqueID) start := time.Now() @@ -498,7 +497,7 @@ func extractPeerMeta(ctx context.Context, meta *proto.PeerSystemMeta) nbpeer.Pee } } -func (s *GRPCServer) parseRequest(ctx context.Context, req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) { +func (s *Server) parseRequest(ctx context.Context, req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) { peerKey, err := wgtypes.ParseKey(req.GetWgPubKey()) if err != nil { log.WithContext(ctx).Warnf("error while parsing peer's WireGuard public key %s.", req.WgPubKey) @@ -517,7 +516,7 @@ func (s *GRPCServer) parseRequest(ctx context.Context, req *proto.EncryptedMessa // In case it is, the login is successful // In case it isn't, the endpoint checks whether setup key is provided within the request and tries to register a peer. // In case of the successful registration login is also successful -func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { +func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { reqStart := time.Now() realIP := getRealIP(ctx) sRealIP := realIP.String() @@ -531,7 +530,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p peerMeta := extractPeerMeta(ctx, loginReq.GetMeta()) metahashed := metaHash(peerMeta, sRealIP) - if !s.accountManager.AllowSync(peerKey.String(), metahashed) { + if !s.loginFilter.allowLogin(peerKey.String(), metahashed) { if s.logBlockedPeers { log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed) } @@ -628,7 +627,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p }, nil } -func (s *GRPCServer) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, netMap *types.NetworkMap, postureChecks []*posture.Checks) (*proto.LoginResponse, error) { +func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, netMap *types.NetworkMap, postureChecks []*posture.Checks) (*proto.LoginResponse, error) { var relayToken *Token var err error if s.config.Relay != nil && len(s.config.Relay.Addresses) > 0 { @@ -647,7 +646,7 @@ func (s *GRPCServer) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer // if peer has reached this point then it has logged in loginResp := &proto.LoginResponse{ NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil), - PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain(settings), settings), + PeerConfig: toPeerConfig(peer, netMap.Network, s.networkMapController.GetDNSDomain(settings), settings), Checks: toProtocolChecks(ctx, postureChecks), } @@ -659,7 +658,7 @@ func (s *GRPCServer) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer // // The user ID can be empty if the token is not provided, which is acceptable if the peer is already // registered or if it uses a setup key to register. -func (s *GRPCServer) processJwtToken(ctx context.Context, loginReq *proto.LoginRequest, peerKey wgtypes.Key) (string, error) { +func (s *Server) processJwtToken(ctx context.Context, loginReq *proto.LoginRequest, peerKey wgtypes.Key) (string, error) { userID := "" if loginReq.GetJwtToken() != "" { var err error @@ -679,166 +678,13 @@ func (s *GRPCServer) processJwtToken(ctx context.Context, loginReq *proto.LoginR return userID, nil } -func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol { - switch configProto { - case nbconfig.UDP: - return proto.HostConfig_UDP - case nbconfig.DTLS: - return proto.HostConfig_DTLS - case nbconfig.HTTP: - return proto.HostConfig_HTTP - case nbconfig.HTTPS: - return proto.HostConfig_HTTPS - case nbconfig.TCP: - return proto.HostConfig_TCP - default: - panic(fmt.Errorf("unexpected config protocol type %v", configProto)) - } -} - -func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig { - if config == nil { - return nil - } - - var stuns []*proto.HostConfig - for _, stun := range config.Stuns { - stuns = append(stuns, &proto.HostConfig{ - Uri: stun.URI, - Protocol: ToResponseProto(stun.Proto), - }) - } - - var turns []*proto.ProtectedHostConfig - if config.TURNConfig != nil { - for _, turn := range config.TURNConfig.Turns { - var username string - var password string - if turnCredentials != nil { - username = turnCredentials.Payload - password = turnCredentials.Signature - } else { - username = turn.Username - password = turn.Password - } - turns = append(turns, &proto.ProtectedHostConfig{ - HostConfig: &proto.HostConfig{ - Uri: turn.URI, - Protocol: ToResponseProto(turn.Proto), - }, - User: username, - Password: password, - }) - } - } - - var relayCfg *proto.RelayConfig - if config.Relay != nil && len(config.Relay.Addresses) > 0 { - relayCfg = &proto.RelayConfig{ - Urls: config.Relay.Addresses, - } - - if relayToken != nil { - relayCfg.TokenPayload = relayToken.Payload - relayCfg.TokenSignature = relayToken.Signature - } - } - - var signalCfg *proto.HostConfig - if config.Signal != nil { - signalCfg = &proto.HostConfig{ - Uri: config.Signal.URI, - Protocol: ToResponseProto(config.Signal.Proto), - } - } - - nbConfig := &proto.NetbirdConfig{ - Stuns: stuns, - Turns: turns, - Signal: signalCfg, - Relay: relayCfg, - } - - return nbConfig -} - -func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings) *proto.PeerConfig { - netmask, _ := network.Net.Mask.Size() - fqdn := peer.FQDN(dnsName) - return &proto.PeerConfig{ - Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask), // take it from the network - SshConfig: &proto.SSHConfig{SshEnabled: peer.SSHEnabled}, - Fqdn: fqdn, - RoutingPeerDnsResolutionEnabled: settings.RoutingPeerDNSResolutionEnabled, - LazyConnectionEnabled: settings.LazyConnectionEnabled, - } -} - -func toSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse { - response := &proto.SyncResponse{ - PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings), - NetworkMap: &proto.NetworkMap{ - Serial: networkMap.Network.CurrentSerial(), - Routes: toProtocolRoutes(networkMap.Routes), - DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort), - }, - Checks: toProtocolChecks(ctx, checks), - } - - nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings) - extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings) - response.NetbirdConfig = extendedConfig - - response.NetworkMap.PeerConfig = response.PeerConfig - - remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers)) - remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName) - response.RemotePeers = remotePeers - response.NetworkMap.RemotePeers = remotePeers - response.RemotePeersIsEmpty = len(remotePeers) == 0 - response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty - - response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName) - - firewallRules := toProtocolFirewallRules(networkMap.FirewallRules) - response.NetworkMap.FirewallRules = firewallRules - response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0 - - routesFirewallRules := toProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules) - response.NetworkMap.RoutesFirewallRules = routesFirewallRules - response.NetworkMap.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0 - - if networkMap.ForwardingRules != nil { - forwardingRules := make([]*proto.ForwardingRule, 0, len(networkMap.ForwardingRules)) - for _, rule := range networkMap.ForwardingRules { - forwardingRules = append(forwardingRules, rule.ToProto()) - } - response.NetworkMap.ForwardingRules = forwardingRules - } - - return response -} - -func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig { - for _, rPeer := range peers { - dst = append(dst, &proto.RemotePeerConfig{ - WgPubKey: rPeer.Key, - AllowedIps: []string{rPeer.IP.String() + "/32"}, - SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)}, - Fqdn: rPeer.FQDN(dnsName), - AgentVersion: rPeer.Meta.WtVersion, - }) - } - return dst -} - // IsHealthy indicates whether the service is healthy -func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Empty, error) { +func (s *Server) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Empty, error) { return &proto.Empty{}, nil } // sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization -func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *types.NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error { +func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *types.NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer, dnsFwdPort int64) error { var err error var turnToken *Token @@ -862,19 +708,12 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p return status.Errorf(codes.Internal, "error handling request") } - peerGroups, err := getPeerGroupIDs(ctx, s.accountManager.GetStore(), peer.AccountID, peer.ID) + peerGroups, err := s.accountManager.GetStore().GetPeerGroupIDs(ctx, store.LockingStrengthNone, peer.AccountID, peer.ID) if err != nil { return status.Errorf(codes.Internal, "failed to get peer groups %s", err) } - // Get all peers in the account for forwarder port computation - allPeers, err := s.accountManager.GetStore().GetAccountPeers(ctx, store.LockingStrengthNone, peer.AccountID, "", "") - if err != nil { - return fmt.Errorf("get account peers: %w", err) - } - dnsFwdPort := computeForwarderPort(allPeers, dnsForwarderPortMinVersion) - - plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort) + plainResp := ToSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.networkMapController.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp) if err != nil { @@ -899,7 +738,7 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p // GetDeviceAuthorizationFlow returns a device authorization flow information // This is used for initiating an Oauth 2 device authorization grant flow // which will be used by our clients to Login -func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { +func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { log.WithContext(ctx).Tracef("GetDeviceAuthorizationFlow request for pubKey: %s", req.WgPubKey) start := time.Now() defer func() { @@ -957,7 +796,7 @@ func (s *GRPCServer) GetDeviceAuthorizationFlow(ctx context.Context, req *proto. // GetPKCEAuthorizationFlow returns a pkce authorization flow information // This is used for initiating an Oauth 2 pkce authorization grant flow // which will be used by our clients to Login -func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { +func (s *Server) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { log.WithContext(ctx).Tracef("GetPKCEAuthorizationFlow request for pubKey: %s", req.WgPubKey) start := time.Now() defer func() { @@ -1012,7 +851,7 @@ func (s *GRPCServer) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.En // SyncMeta endpoint is used to synchronize peer's system metadata and notifies the connected, // peer's under the same account of any updates. -func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) { +func (s *Server) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) { realIP := getRealIP(ctx) log.WithContext(ctx).Debugf("Sync meta request from peer [%s] [%s]", req.WgPubKey, realIP.String()) @@ -1037,7 +876,7 @@ func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) return &proto.Empty{}, nil } -func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) { +func (s *Server) Logout(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) { log.WithContext(ctx).Debugf("Logout request from peer [%s]", req.WgPubKey) start := time.Now() diff --git a/management/internals/shared/grpc/server_test.go b/management/internals/shared/grpc/server_test.go new file mode 100644 index 000000000..9867b38e3 --- /dev/null +++ b/management/internals/shared/grpc/server_test.go @@ -0,0 +1,106 @@ +package grpc + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/encryption" + "github.com/netbirdio/netbird/management/internals/server/config" + mgmtProto "github.com/netbirdio/netbird/shared/management/proto" +) + +func TestServer_GetDeviceAuthorizationFlow(t *testing.T) { + testingServerKey, err := wgtypes.GeneratePrivateKey() + if err != nil { + t.Errorf("unable to generate server wg key for testing GetDeviceAuthorizationFlow, error: %v", err) + } + + testingClientKey, err := wgtypes.GeneratePrivateKey() + if err != nil { + t.Errorf("unable to generate client wg key for testing GetDeviceAuthorizationFlow, error: %v", err) + } + + testCases := []struct { + name string + inputFlow *config.DeviceAuthorizationFlow + expectedFlow *mgmtProto.DeviceAuthorizationFlow + expectedErrFunc require.ErrorAssertionFunc + expectedErrMSG string + expectedComparisonFunc require.ComparisonAssertionFunc + expectedComparisonMSG string + }{ + { + name: "Testing No Device Flow Config", + inputFlow: nil, + expectedErrFunc: require.Error, + expectedErrMSG: "should return error", + }, + { + name: "Testing Invalid Device Flow Provider Config", + inputFlow: &config.DeviceAuthorizationFlow{ + Provider: "NoNe", + ProviderConfig: config.ProviderConfig{ + ClientID: "test", + }, + }, + expectedErrFunc: require.Error, + expectedErrMSG: "should return error", + }, + { + name: "Testing Full Device Flow Config", + inputFlow: &config.DeviceAuthorizationFlow{ + Provider: "hosted", + ProviderConfig: config.ProviderConfig{ + ClientID: "test", + }, + }, + expectedFlow: &mgmtProto.DeviceAuthorizationFlow{ + Provider: 0, + ProviderConfig: &mgmtProto.ProviderConfig{ + ClientID: "test", + }, + }, + expectedErrFunc: require.NoError, + expectedErrMSG: "should not return error", + expectedComparisonFunc: require.Equal, + expectedComparisonMSG: "should match", + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + mgmtServer := &Server{ + wgKey: testingServerKey, + config: &config.Config{ + DeviceAuthorizationFlow: testCase.inputFlow, + }, + } + + message := &mgmtProto.DeviceAuthorizationFlowRequest{} + + encryptedMSG, err := encryption.EncryptMessage(testingClientKey.PublicKey(), mgmtServer.wgKey, message) + require.NoError(t, err, "should be able to encrypt message") + + resp, err := mgmtServer.GetDeviceAuthorizationFlow( + context.TODO(), + &mgmtProto.EncryptedMessage{ + WgPubKey: testingClientKey.PublicKey().String(), + Body: encryptedMSG, + }, + ) + testCase.expectedErrFunc(t, err, testCase.expectedErrMSG) + if testCase.expectedComparisonFunc != nil { + flowInfoResp := &mgmtProto.DeviceAuthorizationFlow{} + + err = encryption.DecryptMessage(mgmtServer.wgKey.PublicKey(), testingClientKey, resp.Body, flowInfoResp) + require.NoError(t, err, "should be able to decrypt") + + testCase.expectedComparisonFunc(t, testCase.expectedFlow.Provider, flowInfoResp.Provider, testCase.expectedComparisonMSG) + testCase.expectedComparisonFunc(t, testCase.expectedFlow.ProviderConfig.ClientID, flowInfoResp.ProviderConfig.ClientID, testCase.expectedComparisonMSG) + } + }) + } +} diff --git a/management/server/token_mgr.go b/management/internals/shared/grpc/token_mgr.go similarity index 93% rename from management/server/token_mgr.go rename to management/internals/shared/grpc/token_mgr.go index f9293e7a8..e9770db41 100644 --- a/management/server/token_mgr.go +++ b/management/internals/shared/grpc/token_mgr.go @@ -1,4 +1,4 @@ -package server +package grpc import ( "context" @@ -12,6 +12,7 @@ import ( log "github.com/sirupsen/logrus" integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" nbconfig "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/settings" @@ -37,7 +38,7 @@ type TimeBasedAuthSecretsManager struct { relayCfg *nbconfig.Relay turnHmacToken *auth.TimedHMAC relayHmacToken *authv2.Generator - updateManager *PeersUpdateManager + updateManager network_map.PeersUpdateManager settingsManager settings.Manager groupsManager groups.Manager turnCancelMap map[string]chan struct{} @@ -46,7 +47,7 @@ type TimeBasedAuthSecretsManager struct { type Token auth.Token -func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *nbconfig.TURNConfig, relayCfg *nbconfig.Relay, settingsManager settings.Manager, groupsManager groups.Manager) *TimeBasedAuthSecretsManager { +func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager, turnCfg *nbconfig.TURNConfig, relayCfg *nbconfig.Relay, settingsManager settings.Manager, groupsManager groups.Manager) *TimeBasedAuthSecretsManager { mgr := &TimeBasedAuthSecretsManager{ updateManager: updateManager, turnCfg: turnCfg, @@ -227,7 +228,7 @@ func (m *TimeBasedAuthSecretsManager) pushNewTURNAndRelayTokens(ctx context.Cont m.extendNetbirdConfig(ctx, peerID, accountID, update) log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID) - m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update}) + m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update}) } func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, accountID, peerID string) { @@ -251,7 +252,7 @@ func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, ac m.extendNetbirdConfig(ctx, peerID, accountID, update) log.WithContext(ctx).Debugf("sending new relay credentials to peer %s", peerID) - m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update}) + m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update}) } func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, peerID, accountID string, update *proto.SyncResponse) { diff --git a/management/server/token_mgr_test.go b/management/internals/shared/grpc/token_mgr_test.go similarity index 94% rename from management/server/token_mgr_test.go rename to management/internals/shared/grpc/token_mgr_test.go index 5c956dc31..06d28d05b 100644 --- a/management/server/token_mgr_test.go +++ b/management/internals/shared/grpc/token_mgr_test.go @@ -1,4 +1,4 @@ -package server +package grpc import ( "context" @@ -13,6 +13,8 @@ import ( "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/settings" @@ -31,7 +33,7 @@ var TurnTestHost = &config.Host{ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) { ttl := util.Duration{Duration: time.Hour} secret := "some_secret" - peersManager := NewPeersUpdateManager(nil) + peersManager := update_channel.NewPeersUpdateManager(nil) rc := &config.Relay{ Addresses: []string{"localhost:0"}, @@ -80,7 +82,7 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) { func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) { ttl := util.Duration{Duration: 2 * time.Second} secret := "some_secret" - peersManager := NewPeersUpdateManager(nil) + peersManager := update_channel.NewPeersUpdateManager(nil) peer := "some_peer" updateChannel := peersManager.CreateChannel(context.Background(), peer) @@ -116,7 +118,7 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) { t.Errorf("expecting peer to be present in the relay cancel map, got not present") } - var updates []*UpdateMessage + var updates []*network_map.UpdateMessage loop: for timeout := time.After(5 * time.Second); ; { @@ -185,7 +187,7 @@ loop: func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) { ttl := util.Duration{Duration: time.Hour} secret := "some_secret" - peersManager := NewPeersUpdateManager(nil) + peersManager := update_channel.NewPeersUpdateManager(nil) peer := "some_peer" rc := &config.Relay{ diff --git a/management/server/account.go b/management/server/account.go index f5a5c7b7a..a4b2a752b 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -11,10 +11,8 @@ import ( "reflect" "regexp" "slices" - "strconv" "strings" "sync" - "sync/atomic" "time" cacheStore "github.com/eko/gocache/lib/v4/store" @@ -26,6 +24,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/formatter/hook" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" nbcache "github.com/netbirdio/netbird/management/server/cache" @@ -53,9 +52,6 @@ const ( peerSchedulerRetryInterval = 3 * time.Second emptyUserID = "empty user ID in claims" errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v" - - envNewNetworkMapBuilder = "NB_EXPERIMENT_NETWORK_MAP" - envNewNetworkMapAccounts = "NB_EXPERIMENT_NETWORK_MAP_ACCOUNTS" ) type userLoggedInOnce bool @@ -71,7 +67,7 @@ type DefaultAccountManager struct { cacheMux sync.Mutex // cacheLoading keeps the accountIDs that are currently reloading. The accountID has to be removed once cache has been reloaded cacheLoading map[string]chan struct{} - peersUpdateManager *PeersUpdateManager + networkMapController network_map.Controller idpManager idp.Manager cacheManager *nbcache.AccountUserDataCache externalCacheManager nbcache.UserDataCache @@ -91,8 +87,7 @@ type DefaultAccountManager struct { singleAccountMode bool // singleAccountModeDomain is a domain to use in singleAccountMode setup singleAccountModeDomain string - // dnsDomain is used for peer resolution. This is appended to the peer's name - dnsDomain string + peerLoginExpiry Scheduler peerInactivityExpiry Scheduler @@ -106,19 +101,11 @@ type DefaultAccountManager struct { permissionsManager permissions.Manager - accountUpdateLocks sync.Map - updateAccountPeersBufferInterval atomic.Int64 - - loginFilter *loginFilter - disableDefaultPolicy bool - - holder *types.Holder - - expNewNetworkMap bool - expNewNetworkMapAIDs map[string]struct{} } +var _ account.Manager = (*DefaultAccountManager)(nil) + func isUniqueConstraintError(err error) bool { switch { case strings.Contains(err.Error(), "(SQLSTATE 23505)"), @@ -185,10 +172,9 @@ func (am *DefaultAccountManager) getJWTGroupsChanges(user *types.User, groups [] func BuildManager( ctx context.Context, store store.Store, - peersUpdateManager *PeersUpdateManager, + networkMapController network_map.Controller, idpManager idp.Manager, singleAccountModeDomain string, - dnsDomain string, eventStore activity.Store, geo geolocation.Geolocation, userDeleteFromIDPEnabled bool, @@ -204,27 +190,14 @@ func BuildManager( log.WithContext(ctx).Debugf("took %v to instantiate account manager", time.Since(start)) }() - newNetworkMapBuilder, err := strconv.ParseBool(os.Getenv(envNewNetworkMapBuilder)) - if err != nil { - log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", envNewNetworkMapBuilder, err) - newNetworkMapBuilder = false - } - - ids := strings.Split(os.Getenv(envNewNetworkMapAccounts), ",") - expIDs := make(map[string]struct{}, len(ids)) - for _, id := range ids { - expIDs[id] = struct{}{} - } - am := &DefaultAccountManager{ Store: store, geo: geo, - peersUpdateManager: peersUpdateManager, + networkMapController: networkMapController, idpManager: idpManager, ctx: context.Background(), cacheMux: sync.Mutex{}, cacheLoading: map[string]chan struct{}{}, - dnsDomain: dnsDomain, eventStore: eventStore, peerLoginExpiry: NewDefaultScheduler(), peerInactivityExpiry: NewDefaultScheduler(), @@ -235,15 +208,10 @@ func BuildManager( proxyController: proxyController, settingsManager: settingsManager, permissionsManager: permissionsManager, - loginFilter: newLoginFilter(), disableDefaultPolicy: disableDefaultPolicy, - holder: types.NewHolder(), - - expNewNetworkMap: newNetworkMapBuilder, - expNewNetworkMapAIDs: expIDs, } - am.startWarmup(ctx) + am.networkMapController.StartWarmup(ctx) accountsCounter, err := store.GetAccountsCounter(ctx) if err != nil { @@ -291,32 +259,6 @@ func (am *DefaultAccountManager) SetEphemeralManager(em ephemeral.Manager) { am.ephemeralManager = em } -func (am *DefaultAccountManager) startWarmup(ctx context.Context) { - var initialInterval int64 - intervalStr := os.Getenv("NB_PEER_UPDATE_INTERVAL_MS") - interval, err := strconv.Atoi(intervalStr) - if err != nil { - initialInterval = 1 - log.WithContext(ctx).Warnf("failed to parse peer update interval, using default value %dms: %v", initialInterval, err) - } else { - initialInterval = int64(interval) * 10 - go func() { - startupPeriodStr := os.Getenv("NB_PEER_UPDATE_STARTUP_PERIOD_S") - startupPeriod, err := strconv.Atoi(startupPeriodStr) - if err != nil { - startupPeriod = 1 - log.WithContext(ctx).Warnf("failed to parse peer update startup period, using default value %ds: %v", startupPeriod, err) - } - time.Sleep(time.Duration(startupPeriod) * time.Second) - am.updateAccountPeersBufferInterval.Store(int64(time.Duration(interval) * time.Millisecond)) - log.WithContext(ctx).Infof("set peer update buffer interval to %dms", interval) - }() - } - am.updateAccountPeersBufferInterval.Store(initialInterval) - log.WithContext(ctx).Infof("set peer update buffer interval to %dms", initialInterval) - -} - func (am *DefaultAccountManager) GetExternalCacheManager() account.ExternalCacheManager { return am.externalCacheManager } @@ -419,9 +361,6 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco } if updateAccountPeers || extraSettingsChanged || groupChangesAffectPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return nil, err - } go am.UpdateAccountPeers(ctx, accountID) } @@ -1504,10 +1443,6 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth } if removedGroupAffectsPeers || newGroupsAffectsPeers { - if err := am.RecalculateNetworkMapCache(ctx, userAuth.AccountId); err != nil { - return err - } - log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId) am.BufferUpdateAccountPeers(ctx, userAuth.AccountId) } @@ -1667,14 +1602,10 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth nbcontext.U return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain } -func (am *DefaultAccountManager) AllowSync(wgPubKey string, metahash uint64) bool { - return am.loginFilter.allowLogin(wgPubKey, metahash) -} - -func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { - peer, netMap, postureChecks, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID) +func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { + peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID) if err != nil { - return nil, nil, nil, fmt.Errorf("error syncing peer: %w", err) + return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err) } err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID) @@ -1682,10 +1613,7 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err) } - metahash := metaHash(meta, realIP.String()) - am.loginFilter.addLogin(peerPubKey, metahash) - - return peer, netMap, postureChecks, nil + return peer, netMap, postureChecks, dnsfwdPort, nil } func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error { @@ -1702,41 +1630,19 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st return err } - _, _, _, err = am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID) + _, _, _, _, err = am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID) if err != nil { - return mapError(ctx, err) + return err } return nil } -// GetAllConnectedPeers returns connected peers based on peersUpdateManager.GetAllConnectedPeers() -func (am *DefaultAccountManager) GetAllConnectedPeers() (map[string]struct{}, error) { - return am.peersUpdateManager.GetAllConnectedPeers(), nil -} - -// HasConnectedChannel returns true if peers has channel in update manager, otherwise false -func (am *DefaultAccountManager) HasConnectedChannel(peerID string) bool { - return am.peersUpdateManager.HasChannel(peerID) -} - var invalidDomainRegexp = regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`) func isDomainValid(domain string) bool { return invalidDomainRegexp.MatchString(domain) } -// GetDNSDomain returns the configured dnsDomain -func (am *DefaultAccountManager) GetDNSDomain(settings *types.Settings) string { - if settings == nil { - return am.dnsDomain - } - if settings.DNSDomain == "" { - return am.dnsDomain - } - - return settings.DNSDomain -} - func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string, peerIDs []string) { peers := []*nbpeer.Peer{} log.WithContext(ctx).Debugf("invalidating peers %v for account %s", peerIDs, accountID) @@ -2159,8 +2065,7 @@ func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, us if err != nil { return err } - am.updatePeerInNetworkMapCache(peer.AccountID, peer) - am.BufferUpdateAccountPeers(ctx, accountID) + am.networkMapController.OnPeerUpdated(peer.AccountID, peer) } return nil } @@ -2208,7 +2113,7 @@ func (am *DefaultAccountManager) savePeerIPUpdate(ctx context.Context, transacti if err != nil { return fmt.Errorf("get account settings: %w", err) } - dnsDomain := am.GetDNSDomain(settings) + dnsDomain := am.networkMapController.GetDNSDomain(settings) eventMeta := peer.EventMeta(dnsDomain) oldIP := peer.IP.String() diff --git a/management/server/account/manager.go b/management/server/account/manager.go index db377865a..7c174a481 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -89,7 +89,6 @@ type Manager interface { SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) - GetDNSDomain(settings *types.Settings) string StoreEvent(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) @@ -97,10 +96,8 @@ type Manager interface { GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) - LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API - SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API - GetAllConnectedPeers() (map[string]struct{}, error) - HasConnectedChannel(peerID string) bool + LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API + SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) // used by peer gRPC API GetExternalCacheManager() ExternalCacheManager GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, create bool) (*posture.Checks, error) @@ -110,7 +107,7 @@ type Manager interface { UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) - SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) @@ -127,6 +124,4 @@ type Manager interface { GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) SetEphemeralManager(em ephemeral.Manager) - AllowSync(string, uint64) bool - RecalculateNetworkMapCache(ctx context.Context, accountId string) error } diff --git a/management/server/account/request_buffer.go b/management/server/account/request_buffer.go new file mode 100644 index 000000000..eced1929f --- /dev/null +++ b/management/server/account/request_buffer.go @@ -0,0 +1,11 @@ +package account + +import ( + "context" + + "github.com/netbirdio/netbird/management/server/types" +) + +type RequestBuffer interface { + GetAccountWithBackpressure(ctx context.Context, accountID string) (*types.Account, error) +} diff --git a/management/server/account_test.go b/management/server/account_test.go index 200ba6b98..ee9950796 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -22,6 +22,9 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" nbAccount "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/cache" @@ -406,7 +409,7 @@ func TestNewAccount(t *testing.T) { } func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -603,7 +606,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") accountID, err := manager.GetAccountIDByUserID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.Domain) @@ -644,7 +647,7 @@ func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) { userId := "user-id" domain := "test.domain" _ = newAccountWithId(context.Background(), "", userId, domain, false) - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain) require.NoError(t, err, "create init user failed") @@ -705,7 +708,7 @@ func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) { } func TestAccountManager_PrivateAccount(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -731,7 +734,7 @@ func TestAccountManager_PrivateAccount(t *testing.T) { } func TestAccountManager_SetOrUpdateDomain(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -768,7 +771,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) { } func TestAccountManager_GetAccountByUserID(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -805,7 +808,7 @@ func createAccount(am *DefaultAccountManager, accountID, userID, domain string) } func TestAccountManager_GetAccount(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -843,7 +846,7 @@ func TestAccountManager_GetAccount(t *testing.T) { } func TestAccountManager_DeleteAccount(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -924,7 +927,7 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) { DomainCategory: types.PublicCategory, } - am, err := createManager(b) + am, _, err := createManager(b) if err != nil { b.Fatal(err) return @@ -1016,7 +1019,7 @@ func genUsers(p string, n int) map[string]*types.User { } func TestAccountManager_AddPeer(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -1086,7 +1089,7 @@ func TestAccountManager_AddPeer(t *testing.T) { } func TestAccountManager_AddPeerWithUserID(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -1155,7 +1158,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { } func TestAccountManager_NetworkUpdates_SaveGroup_Experimental(t *testing.T) { - t.Setenv(envNewNetworkMapBuilder, "true") + t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") testAccountManager_NetworkUpdates_SaveGroup(t) } @@ -1164,7 +1167,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { } func testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { - manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) group := types.Group{ ID: "groupA", @@ -1190,8 +1193,8 @@ func testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { }, true) require.NoError(t, err) - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) - defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) + defer updateManager.CloseChannel(context.Background(), peer1.ID) wg := sync.WaitGroup{} wg.Add(1) @@ -1215,7 +1218,7 @@ func testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { } func TestAccountManager_NetworkUpdates_DeletePolicy_Experimental(t *testing.T) { - t.Setenv(envNewNetworkMapBuilder, "true") + t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") testAccountManager_NetworkUpdates_DeletePolicy(t) } @@ -1224,10 +1227,10 @@ func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { } func testAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { - manager, account, peer1, _, _ := setupNetworkMapTest(t) + manager, updateManager, account, peer1, _, _ := setupNetworkMapTest(t) - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) - defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) + defer updateManager.CloseChannel(context.Background(), peer1.ID) // Ensure that we do not receive an update message before the policy is deleted time.Sleep(time.Second) @@ -1258,7 +1261,7 @@ func testAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { } func TestAccountManager_NetworkUpdates_SavePolicy_Experimental(t *testing.T) { - t.Setenv(envNewNetworkMapBuilder, "true") + t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") testAccountManager_NetworkUpdates_SavePolicy(t) } @@ -1267,7 +1270,7 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { } func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { - manager, account, peer1, peer2, _ := setupNetworkMapTest(t) + manager, updateManager, account, peer1, peer2, _ := setupNetworkMapTest(t) group := types.Group{ AccountID: account.Id, @@ -1280,8 +1283,8 @@ func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { return } - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) - defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) + defer updateManager.CloseChannel(context.Background(), peer1.ID) wg := sync.WaitGroup{} wg.Add(1) @@ -1316,7 +1319,7 @@ func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { } func TestAccountManager_NetworkUpdates_DeletePeer_Experimental(t *testing.T) { - t.Setenv(envNewNetworkMapBuilder, "true") + t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") testAccountManager_NetworkUpdates_DeletePeer(t) } @@ -1325,7 +1328,7 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { } func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { - manager, account, peer1, _, peer3 := setupNetworkMapTest(t) + manager, updateManager, account, peer1, _, peer3 := setupNetworkMapTest(t) group := types.Group{ ID: "groupA", @@ -1354,8 +1357,11 @@ func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { return } - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) - defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + // We need to sleep to wait for the buffer peer update + time.Sleep(300 * time.Millisecond) + + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) + defer updateManager.CloseChannel(context.Background(), peer1.ID) wg := sync.WaitGroup{} wg.Add(1) @@ -1378,7 +1384,7 @@ func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { } func TestAccountManager_NetworkUpdates_DeleteGroup_Experimental(t *testing.T) { - t.Setenv(envNewNetworkMapBuilder, "true") + t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") testAccountManager_NetworkUpdates_DeleteGroup(t) } @@ -1387,10 +1393,10 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { } func testAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { - manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) - defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) + defer updateManager.CloseChannel(context.Background(), peer1.ID) err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", @@ -1457,7 +1463,7 @@ func testAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { } func TestAccountManager_DeletePeer(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -1538,7 +1544,7 @@ func getEvent(t *testing.T, accountID string, manager nbAccount.Manager, eventTy } func TestGetUsersFromAccount(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } @@ -1837,7 +1843,7 @@ func hasNilField(x interface{}) error { } func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") @@ -1852,7 +1858,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) { } func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") _, err = manager.GetAccountIDByUserID(context.Background(), userID, "") @@ -1908,7 +1914,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { } func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") @@ -1951,7 +1957,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. } func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") _, err = manager.GetAccountIDByUserID(context.Background(), userID, "") @@ -2013,7 +2019,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test } func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") @@ -2677,7 +2683,7 @@ func TestAccount_GetNextInactivePeerExpiration(t *testing.T) { func TestAccount_SetJWTGroups(t *testing.T) { t.Setenv("NETBIRD_STORE_ENGINE", "postgres") - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") // create a new account @@ -2919,18 +2925,18 @@ func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) { // Fatalf(format string, args ...interface{}) // } -func createManager(t testing.TB) (*DefaultAccountManager, error) { +func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersUpdateManager, error) { t.Helper() store, err := createStore(t) if err != nil { - return nil, err + return nil, nil, err } eventStore := &activity.InMemoryEventStore{} metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) if err != nil { - return nil, err + return nil, nil, err } ctrl := gomock.NewController(t) @@ -2948,12 +2954,17 @@ func createManager(t testing.TB) (*DefaultAccountManager, error) { permissionsManager := permissions.NewManager(store) - manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + ctx := context.Background() + + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := NewAccountRequestBuffer(ctx, store) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock()) + manager, err := BuildManager(ctx, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) if err != nil { - return nil, err + return nil, nil, err } - return manager, nil + return manager, updateManager, nil } func createStore(t testing.TB) (store.Store, error) { @@ -2982,10 +2993,10 @@ func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool { } } -func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *types.Account, *nbpeer.Peer, *nbpeer.Peer, *nbpeer.Peer) { +func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *update_channel.PeersUpdateManager, *types.Account, *nbpeer.Peer, *nbpeer.Peer, *nbpeer.Peer) { t.Helper() - manager, err := createManager(t) + manager, updateManager, err := createManager(t) if err != nil { t.Fatal(err) } @@ -3026,10 +3037,10 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *types.Account, peer2 := getPeer(manager, setupKey) peer3 := getPeer(manager, setupKey) - return manager, account, peer1, peer2, peer3 + return manager, updateManager, account, peer1, peer2, peer3 } -func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage) { +func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) { t.Helper() select { case msg := <-updateMessage: @@ -3039,7 +3050,7 @@ func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessag } } -func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage) { +func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) { t.Helper() select { @@ -3077,7 +3088,7 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) { defer log.SetOutput(os.Stderr) for _, bc := range benchCases { b.Run(bc.name, func(b *testing.B) { - manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) + manager, updateManager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) if err != nil { b.Fatalf("Failed to setup test account manager: %v", err) } @@ -3086,16 +3097,14 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) { if err != nil { b.Fatalf("Failed to get account: %v", err) } - peerChannels := make(map[string]chan *UpdateMessage) for peerID := range account.Peers { - peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize) + updateManager.CreateChannel(ctx, peerID) } - manager.peersUpdateManager.peerChannels = peerChannels b.ResetTimer() start := time.Now() for i := 0; i < b.N; i++ { - _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1}) + _, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1}) assert.NoError(b, err) } @@ -3140,7 +3149,7 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) { defer log.SetOutput(os.Stderr) for _, bc := range benchCases { b.Run(bc.name, func(b *testing.B) { - manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) + manager, updateManager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) if err != nil { b.Fatalf("Failed to setup test account manager: %v", err) } @@ -3149,11 +3158,10 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) { if err != nil { b.Fatalf("Failed to get account: %v", err) } - peerChannels := make(map[string]chan *UpdateMessage) + for peerID := range account.Peers { - peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize) + updateManager.CreateChannel(ctx, peerID) } - manager.peersUpdateManager.peerChannels = peerChannels b.ResetTimer() start := time.Now() @@ -3210,7 +3218,7 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) { defer log.SetOutput(os.Stderr) for _, bc := range benchCases { b.Run(bc.name, func(b *testing.B) { - manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) + manager, updateManager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) if err != nil { b.Fatalf("Failed to setup test account manager: %v", err) } @@ -3219,11 +3227,10 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) { if err != nil { b.Fatalf("Failed to get account: %v", err) } - peerChannels := make(map[string]chan *UpdateMessage) + for peerID := range account.Peers { - peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize) + updateManager.CreateChannel(ctx, peerID) } - manager.peersUpdateManager.peerChannels = peerChannels b.ResetTimer() start := time.Now() @@ -3282,7 +3289,7 @@ func TestMain(m *testing.M) { } func Test_GetCreateAccountByPrivateDomain(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -3328,7 +3335,7 @@ func Test_GetCreateAccountByPrivateDomain(t *testing.T) { } func Test_UpdateToPrimaryAccount(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -3358,7 +3365,7 @@ func Test_UpdateToPrimaryAccount(t *testing.T) { } func TestDefaultAccountManager_IsCacheCold(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err) t.Run("memory cache", func(t *testing.T) { @@ -3408,7 +3415,7 @@ func TestDefaultAccountManager_IsCacheCold(t *testing.T) { } func TestPropagateUserGroupMemberships(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err) ctx := context.Background() @@ -3525,7 +3532,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) { } func TestDefaultAccountManager_GetAccountOnboarding(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err) account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") @@ -3557,7 +3564,7 @@ func TestDefaultAccountManager_GetAccountOnboarding(t *testing.T) { } func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err) account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") @@ -3596,7 +3603,7 @@ func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) { } func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") @@ -3663,7 +3670,7 @@ func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) { } func TestAddNewUserToDomainAccountWithApproval(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } @@ -3709,7 +3716,7 @@ func TestAddNewUserToDomainAccountWithApproval(t *testing.T) { } func TestAddNewUserToDomainAccountWithoutApproval(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } diff --git a/management/server/dns.go b/management/server/dns.go index decc5175d..baf6debc3 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -3,54 +3,23 @@ package server import ( "context" "slices" - "sync" log "github.com/sirupsen/logrus" - "golang.org/x/mod/semver" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" - nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/util" - "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/shared/management/status" ) const ( dnsForwarderPort = nbdns.ForwarderServerPort - oldForwarderPort = nbdns.ForwarderClientPort ) -const dnsForwarderPortMinVersion = "v0.59.0" - -// DNSConfigCache is a thread-safe cache for DNS configuration components -type DNSConfigCache struct { - NameServerGroups sync.Map -} - -// GetNameServerGroup retrieves a cached name server group -func (c *DNSConfigCache) GetNameServerGroup(key string) (*proto.NameServerGroup, bool) { - if c == nil { - return nil, false - } - if value, ok := c.NameServerGroups.Load(key); ok { - return value.(*proto.NameServerGroup), true - } - return nil, false -} - -// SetNameServerGroup stores a name server group in the cache -func (c *DNSConfigCache) SetNameServerGroup(key string, value *proto.NameServerGroup) { - if c == nil { - return - } - c.NameServerGroups.Store(key, value) -} - // GetDNSSettings validates a user role and returns the DNS settings for the provided account ID func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) { allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read) @@ -117,9 +86,6 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID } if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } am.UpdateAccountPeers(ctx, accountID) } @@ -194,99 +160,3 @@ func validateDNSSettings(ctx context.Context, transaction store.Store, accountID return validateGroups(settings.DisabledManagementGroups, groups) } - -// computeForwarderPort checks if all peers in the account have updated to a specific version or newer. -// If all peers have the required version, it returns the new well-known port (22054), otherwise returns 0. -func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 { - if len(peers) == 0 { - return int64(oldForwarderPort) - } - - reqVer := semver.Canonical(requiredVersion) - - // Check if all peers have the required version or newer - for _, peer := range peers { - - // Development version is always supported - if peer.Meta.WtVersion == "development" { - continue - } - peerVersion := semver.Canonical("v" + peer.Meta.WtVersion) - if peerVersion == "" { - // If any peer doesn't have version info, return 0 - return int64(oldForwarderPort) - } - - // Compare versions - if semver.Compare(peerVersion, reqVer) < 0 { - return int64(oldForwarderPort) - } - } - - // All peers have the required version or newer - return int64(dnsForwarderPort) -} - -// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache -func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache, forwardPort int64) *proto.DNSConfig { - protoUpdate := &proto.DNSConfig{ - ServiceEnable: update.ServiceEnable, - CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)), - NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)), - ForwarderPort: forwardPort, - } - - for _, zone := range update.CustomZones { - protoZone := convertToProtoCustomZone(zone) - protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone) - } - - for _, nsGroup := range update.NameServerGroups { - cacheKey := nsGroup.ID - if cachedGroup, exists := cache.GetNameServerGroup(cacheKey); exists { - protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup) - } else { - protoGroup := convertToProtoNameServerGroup(nsGroup) - cache.SetNameServerGroup(cacheKey, protoGroup) - protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup) - } - } - - return protoUpdate -} - -// Helper function to convert nbdns.CustomZone to proto.CustomZone -func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone { - protoZone := &proto.CustomZone{ - Domain: zone.Domain, - Records: make([]*proto.SimpleRecord, 0, len(zone.Records)), - } - for _, record := range zone.Records { - protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{ - Name: record.Name, - Type: int64(record.Type), - Class: record.Class, - TTL: int64(record.TTL), - RData: record.RData, - }) - } - return protoZone -} - -// Helper function to convert nbdns.NameServerGroup to proto.NameServerGroup -func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup { - protoGroup := &proto.NameServerGroup{ - Primary: nsGroup.Primary, - Domains: nsGroup.Domains, - SearchDomainsEnabled: nsGroup.SearchDomainsEnabled, - NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)), - } - for _, ns := range nsGroup.NameServers { - protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{ - IP: ns.IP.String(), - Port: int64(ns.Port), - NSType: int64(ns.NSType), - }) - } - return protoGroup -} diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 96f73a390..356a2f640 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -2,9 +2,7 @@ package server import ( "context" - "fmt" "net/netip" - "reflect" "testing" "time" @@ -12,6 +10,8 @@ import ( "github.com/stretchr/testify/assert" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" @@ -218,7 +218,13 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { // return empty extra settings for expected calls to UpdateAccountPeers settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes() permissionsManager := permissions.NewManager(store) - return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + + ctx := context.Background() + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := NewAccountRequestBuffer(ctx, store) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock()) + + return BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) } func createDNSStore(t *testing.T) (store.Store, error) { @@ -344,247 +350,8 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account return am.Store.GetAccount(context.Background(), account.Id) } -func generateTestData(size int) nbdns.Config { - config := nbdns.Config{ - ServiceEnable: true, - CustomZones: make([]nbdns.CustomZone, size), - NameServerGroups: make([]*nbdns.NameServerGroup, size), - } - - for i := 0; i < size; i++ { - config.CustomZones[i] = nbdns.CustomZone{ - Domain: fmt.Sprintf("domain%d.com", i), - Records: []nbdns.SimpleRecord{ - { - Name: fmt.Sprintf("record%d", i), - Type: 1, - Class: "IN", - TTL: 3600, - RData: "192.168.1.1", - }, - }, - } - - config.NameServerGroups[i] = &nbdns.NameServerGroup{ - ID: fmt.Sprintf("group%d", i), - Primary: i == 0, - Domains: []string{fmt.Sprintf("domain%d.com", i)}, - SearchDomainsEnabled: true, - NameServers: []nbdns.NameServer{ - { - IP: netip.MustParseAddr("8.8.8.8"), - Port: 53, - NSType: 1, - }, - }, - } - } - - return config -} - -func BenchmarkToProtocolDNSConfig(b *testing.B) { - sizes := []int{10, 100, 1000} - - for _, size := range sizes { - testData := generateTestData(size) - - b.Run(fmt.Sprintf("WithCache-Size%d", size), func(b *testing.B) { - cache := &DNSConfigCache{} - - b.ResetTimer() - for i := 0; i < b.N; i++ { - toProtocolDNSConfig(testData, cache, int64(dnsForwarderPort)) - } - }) - - b.Run(fmt.Sprintf("WithoutCache-Size%d", size), func(b *testing.B) { - b.ResetTimer() - for i := 0; i < b.N; i++ { - cache := &DNSConfigCache{} - toProtocolDNSConfig(testData, cache, int64(dnsForwarderPort)) - } - }) - } -} - -func TestToProtocolDNSConfigWithCache(t *testing.T) { - var cache DNSConfigCache - - // Create two different configs - config1 := nbdns.Config{ - ServiceEnable: true, - CustomZones: []nbdns.CustomZone{ - { - Domain: "example.com", - Records: []nbdns.SimpleRecord{ - {Name: "www", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.1"}, - }, - }, - }, - NameServerGroups: []*nbdns.NameServerGroup{ - { - ID: "group1", - Name: "Group 1", - NameServers: []nbdns.NameServer{ - {IP: netip.MustParseAddr("8.8.8.8"), Port: 53}, - }, - }, - }, - } - - config2 := nbdns.Config{ - ServiceEnable: true, - CustomZones: []nbdns.CustomZone{ - { - Domain: "example.org", - Records: []nbdns.SimpleRecord{ - {Name: "mail", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.2"}, - }, - }, - }, - NameServerGroups: []*nbdns.NameServerGroup{ - { - ID: "group2", - Name: "Group 2", - NameServers: []nbdns.NameServer{ - {IP: netip.MustParseAddr("8.8.4.4"), Port: 53}, - }, - }, - }, - } - - // First run with config1 - result1 := toProtocolDNSConfig(config1, &cache, int64(dnsForwarderPort)) - - // Second run with config2 - result2 := toProtocolDNSConfig(config2, &cache, int64(dnsForwarderPort)) - - // Third run with config1 again - result3 := toProtocolDNSConfig(config1, &cache, int64(dnsForwarderPort)) - - // Verify that result1 and result3 are identical - if !reflect.DeepEqual(result1, result3) { - t.Errorf("Results are not identical when run with the same input. Expected %v, got %v", result1, result3) - } - - // Verify that result2 is different from result1 and result3 - if reflect.DeepEqual(result1, result2) || reflect.DeepEqual(result2, result3) { - t.Errorf("Results should be different for different inputs") - } - - if _, exists := cache.GetNameServerGroup("group1"); !exists { - t.Errorf("Cache should contain name server group 'group1'") - } - - if _, exists := cache.GetNameServerGroup("group2"); !exists { - t.Errorf("Cache should contain name server group 'group2'") - } -} - -func TestComputeForwarderPort(t *testing.T) { - // Test with empty peers list - peers := []*nbpeer.Peer{} - result := computeForwarderPort(peers, "v0.59.0") - if result != int64(oldForwarderPort) { - t.Errorf("Expected %d for empty peers list, got %d", oldForwarderPort, result) - } - - // Test with peers that have old versions - peers = []*nbpeer.Peer{ - { - Meta: nbpeer.PeerSystemMeta{ - WtVersion: "0.57.0", - }, - }, - { - Meta: nbpeer.PeerSystemMeta{ - WtVersion: "0.26.0", - }, - }, - } - result = computeForwarderPort(peers, "v0.59.0") - if result != int64(oldForwarderPort) { - t.Errorf("Expected %d for peers with old versions, got %d", oldForwarderPort, result) - } - - // Test with peers that have new versions - peers = []*nbpeer.Peer{ - { - Meta: nbpeer.PeerSystemMeta{ - WtVersion: "0.59.0", - }, - }, - { - Meta: nbpeer.PeerSystemMeta{ - WtVersion: "0.59.0", - }, - }, - } - result = computeForwarderPort(peers, "v0.59.0") - if result != int64(dnsForwarderPort) { - t.Errorf("Expected %d for peers with new versions, got %d", dnsForwarderPort, result) - } - - // Test with peers that have mixed versions - peers = []*nbpeer.Peer{ - { - Meta: nbpeer.PeerSystemMeta{ - WtVersion: "0.59.0", - }, - }, - { - Meta: nbpeer.PeerSystemMeta{ - WtVersion: "0.57.0", - }, - }, - } - result = computeForwarderPort(peers, "v0.59.0") - if result != int64(oldForwarderPort) { - t.Errorf("Expected %d for peers with mixed versions, got %d", oldForwarderPort, result) - } - - // Test with peers that have empty version - peers = []*nbpeer.Peer{ - { - Meta: nbpeer.PeerSystemMeta{ - WtVersion: "", - }, - }, - } - result = computeForwarderPort(peers, "v0.59.0") - if result != int64(oldForwarderPort) { - t.Errorf("Expected %d for peers with empty version, got %d", oldForwarderPort, result) - } - - peers = []*nbpeer.Peer{ - { - Meta: nbpeer.PeerSystemMeta{ - WtVersion: "development", - }, - }, - } - result = computeForwarderPort(peers, "v0.59.0") - if result == int64(oldForwarderPort) { - t.Errorf("Expected %d for peers with dev version, got %d", dnsForwarderPort, result) - } - - // Test with peers that have unknown version string - peers = []*nbpeer.Peer{ - { - Meta: nbpeer.PeerSystemMeta{ - WtVersion: "unknown", - }, - }, - } - result = computeForwarderPort(peers, "v0.59.0") - if result != int64(oldForwarderPort) { - t.Errorf("Expected %d for peers with unknown version, got %d", oldForwarderPort, result) - } -} - func TestDNSAccountPeersUpdate(t *testing.T) { - manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) err := manager.CreateGroups(context.Background(), account.Id, userID, []*types.Group{ { @@ -600,9 +367,9 @@ func TestDNSAccountPeersUpdate(t *testing.T) { }) assert.NoError(t, err) - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updateManager.CloseChannel(context.Background(), peer1.ID) }) // Saving DNS settings with groups that have no peers should not trigger updates to account peers or send peer updates diff --git a/management/server/event_test.go b/management/server/event_test.go index 8c56fd3f6..420e69866 100644 --- a/management/server/event_test.go +++ b/management/server/event_test.go @@ -28,7 +28,7 @@ func generateAndStoreEvents(t *testing.T, manager *DefaultAccountManager, typ ac } func TestDefaultAccountManager_GetEvents(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { return } diff --git a/management/server/group.go b/management/server/group.go index 3cf9290a2..84e641f26 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -114,9 +114,6 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use } if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } am.UpdateAccountPeers(ctx, accountID) } @@ -185,9 +182,6 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use } if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } am.UpdateAccountPeers(ctx, accountID) } @@ -256,9 +250,6 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us } if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } am.UpdateAccountPeers(ctx, accountID) } @@ -327,9 +318,6 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us } if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } am.UpdateAccountPeers(ctx, accountID) } @@ -376,7 +364,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac log.WithContext(ctx).Debugf("failed to get account settings for group events: %v", err) return nil } - dnsDomain := am.GetDNSDomain(settings) + dnsDomain := am.networkMapController.GetDNSDomain(settings) for _, peerID := range addedPeers { peer, ok := peers[peerID] @@ -493,9 +481,6 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr } if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } am.UpdateAccountPeers(ctx, accountID) } @@ -534,9 +519,6 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID } if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } am.UpdateAccountPeers(ctx, accountID) } @@ -565,9 +547,6 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, } if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } am.UpdateAccountPeers(ctx, accountID) } @@ -606,9 +585,6 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun } if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } am.UpdateAccountPeers(ctx, accountID) } diff --git a/management/server/group_test.go b/management/server/group_test.go index 31ff29cbc..4935dac5d 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -37,7 +37,7 @@ const ( ) func TestDefaultAccountManager_CreateGroup(t *testing.T) { - am, err := createManager(t) + am, _, err := createManager(t) if err != nil { t.Error("failed to create account manager") } @@ -74,7 +74,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) { } func TestDefaultAccountManager_DeleteGroup(t *testing.T) { - am, err := createManager(t) + am, _, err := createManager(t) if err != nil { t.Fatalf("failed to create account manager: %s", err) } @@ -156,7 +156,7 @@ func TestDefaultAccountManager_DeleteGroup(t *testing.T) { } func TestDefaultAccountManager_DeleteGroups(t *testing.T) { - am, err := createManager(t) + am, _, err := createManager(t) assert.NoError(t, err, "Failed to create account manager") manager, account, err := initTestGroupAccount(am) @@ -408,7 +408,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *t } func TestGroupAccountPeersUpdate(t *testing.T) { - manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) g := []*types.Group{ { @@ -442,9 +442,9 @@ func TestGroupAccountPeersUpdate(t *testing.T) { assert.NoError(t, err) } - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updateManager.CloseChannel(context.Background(), peer1.ID) }) // Saving a group that is not linked to any resource should not update account peers @@ -748,7 +748,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { } func Test_AddPeerToGroup(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -805,7 +805,7 @@ func Test_AddPeerToGroup(t *testing.T) { } func Test_AddPeerToAll(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -862,7 +862,7 @@ func Test_AddPeerToAll(t *testing.T) { } func Test_AddPeerAndAddToAll(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -942,7 +942,7 @@ func uint32ToIP(n uint32) net.IP { } func Test_IncrementNetworkSerial(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return diff --git a/management/server/holder.go b/management/server/holder.go deleted file mode 100644 index e8a26e1d0..000000000 --- a/management/server/holder.go +++ /dev/null @@ -1,39 +0,0 @@ -package server - -import ( - "github.com/netbirdio/netbird/management/server/types" -) - -func (am *DefaultAccountManager) enrichAccountFromHolder(account *types.Account) { - a := am.holder.GetAccount(account.Id) - if a == nil { - am.holder.AddAccount(account) - return - } - account.NetworkMapCache = a.NetworkMapCache - if account.NetworkMapCache == nil { - return - } - account.NetworkMapCache.UpdateAccountPointer(account) - am.holder.AddAccount(account) -} - -func (am *DefaultAccountManager) getAccountFromHolder(accountID string) *types.Account { - return am.holder.GetAccount(accountID) -} - -func (am *DefaultAccountManager) getAccountFromHolderOrInit(accountID string) *types.Account { - a := am.holder.GetAccount(accountID) - if a != nil { - return a - } - account, err := am.holder.LoadOrStoreFunc(accountID, am.requestBuffer.GetAccountWithBackpressure) - if err != nil { - return nil - } - return account -} - -func (am *DefaultAccountManager) updateAccountInHolder(account *types.Account) { - am.holder.AddAccount(account) -} diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 4d2c224b4..c1a8c5885 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -13,6 +13,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/settings" @@ -65,6 +66,7 @@ func NewAPIHandler( permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, + networkMapController network_map.Controller, ) (http.Handler, error) { var rateLimitingConfig *middleware.RateLimiterConfig @@ -120,7 +122,7 @@ func NewAPIHandler( } accounts.AddEndpoints(accountManager, settingsManager, router) - peers.AddEndpoints(accountManager, router) + peers.AddEndpoints(accountManager, router, networkMapController) users.AddEndpoints(accountManager, router) setup_keys.AddEndpoints(accountManager, router) policies.AddEndpoints(accountManager, LocationManager, router) diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index df89c616c..c4c5ae165 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -10,6 +10,7 @@ import ( "github.com/gorilla/mux" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" nbcontext "github.com/netbirdio/netbird/management/server/context" @@ -23,11 +24,12 @@ import ( // Handler is a handler that returns peers of the account type Handler struct { - accountManager account.Manager + accountManager account.Manager + networkMapController network_map.Controller } -func AddEndpoints(accountManager account.Manager, router *mux.Router) { - peersHandler := NewHandler(accountManager) +func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMapController network_map.Controller) { + peersHandler := NewHandler(accountManager, networkMapController) router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS") router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer). Methods("GET", "PUT", "DELETE", "OPTIONS") @@ -36,9 +38,10 @@ func AddEndpoints(accountManager account.Manager, router *mux.Router) { } // NewHandler creates a new peers Handler -func NewHandler(accountManager account.Manager) *Handler { +func NewHandler(accountManager account.Manager, networkMapController network_map.Controller) *Handler { return &Handler{ - accountManager: accountManager, + accountManager: accountManager, + networkMapController: networkMapController, } } @@ -47,7 +50,7 @@ func (h *Handler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) { if peer.Status.Connected { // Although we have online status in store we do not yet have an updated channel so have to show it as disconnected // This may happen after server restart when not all peers are yet connected - if !h.accountManager.HasConnectedChannel(peer.ID) { + if !h.networkMapController.IsConnected(peer.ID) { peerToReturn.Status.Connected = false } } @@ -73,7 +76,7 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, return } - dnsDomain := h.accountManager.GetDNSDomain(settings) + dnsDomain := h.networkMapController.GetDNSDomain(settings) grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID) grpsInfoMap := groups.ToGroupsInfoMap(grps, 0) @@ -139,7 +142,7 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri util.WriteError(ctx, err, w) return } - dnsDomain := h.accountManager.GetDNSDomain(settings) + dnsDomain := h.networkMapController.GetDNSDomain(settings) peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID) if err != nil { @@ -227,7 +230,7 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { util.WriteError(r.Context(), err, w) return } - dnsDomain := h.accountManager.GetDNSDomain(settings) + dnsDomain := h.networkMapController.GetDNSDomain(settings) grps, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID) @@ -317,7 +320,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { return } - dnsDomain := h.accountManager.GetDNSDomain(account.Settings) + dnsDomain := h.networkMapController.GetDNSDomain(account.Settings) customZone := account.GetPeersCustomZone(r.Context(), dnsDomain) netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) diff --git a/management/server/http/handlers/peers/peers_handler_test.go b/management/server/http/handlers/peers/peers_handler_test.go index 94564113f..7a5a6d911 100644 --- a/management/server/http/handlers/peers/peers_handler_test.go +++ b/management/server/http/handlers/peers/peers_handler_test.go @@ -14,12 +14,14 @@ import ( "time" "github.com/gorilla/mux" + "go.uber.org/mock/gomock" "golang.org/x/exp/maps" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" nbcontext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/shared/management/http/api" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/http/api" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -36,7 +38,7 @@ const ( serviceUser = "service_user" ) -func initTestMetaData(peers ...*nbpeer.Peer) *Handler { +func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler { peersMap := make(map[string]*nbpeer.Peer) for _, peer := range peers { @@ -99,6 +101,22 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler { }, } + ctrl := gomock.NewController(t) + + networkMapController := network_map.NewMockController(ctrl) + networkMapController.EXPECT(). + GetDNSDomain(gomock.Any()). + Return("domain"). + AnyTimes() + networkMapController.EXPECT(). + IsConnected(noUpdateChannelTestPeerID). + Return(false). + AnyTimes() + networkMapController.EXPECT(). + IsConnected(gomock.Any()). + Return(true). + AnyTimes() + return &Handler{ accountManager: &mock_server.MockAccountManager{ UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { @@ -187,6 +205,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler { return account.Settings, nil }, }, + networkMapController: networkMapController, } } @@ -270,7 +289,7 @@ func TestGetPeers(t *testing.T) { rr := httptest.NewRecorder() - p := initTestMetaData(peer, peer1) + p := initTestMetaData(t, peer, peer1) for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { @@ -374,7 +393,7 @@ func TestGetAccessiblePeers(t *testing.T) { UserID: regularUser, } - p := initTestMetaData(peer1, peer2, peer3) + p := initTestMetaData(t, peer1, peer2, peer3) tt := []struct { name string @@ -477,7 +496,7 @@ func TestPeersHandlerUpdatePeerIP(t *testing.T) { }, } - p := initTestMetaData(testPeer) + p := initTestMetaData(t, testPeer) tt := []struct { name string diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index bdf56db6e..ab3f5437a 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -10,6 +10,10 @@ import ( "github.com/stretchr/testify/assert" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/account" @@ -31,7 +35,7 @@ import ( "github.com/netbirdio/netbird/management/server/users" ) -func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPeerUpdate *server.UpdateMessage, validateUpdate bool) (http.Handler, account.Manager, chan struct{}) { +func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPeerUpdate *network_map.UpdateMessage, validateUpdate bool) (http.Handler, account.Manager, chan struct{}) { store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), sqlFile, t.TempDir()) if err != nil { t.Fatalf("Failed to create test store: %v", err) @@ -43,7 +47,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee t.Fatalf("Failed to create metrics: %v", err) } - peersUpdateManager := server.NewPeersUpdateManager(nil) + peersUpdateManager := update_channel.NewPeersUpdateManager(nil) updMsg := peersUpdateManager.CreateChannel(context.Background(), testing_tools.TestPeerId) done := make(chan struct{}) if validateUpdate { @@ -63,7 +67,11 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee userManager := users.NewManager(store) permissionsManager := permissions.NewManager(store) settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager) - am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false) + + ctx := context.Background() + requestBuffer := server.NewAccountRequestBuffer(ctx, store) + networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock()) + am, err := server.BuildManager(ctx, store, networkMapController, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false) if err != nil { t.Fatalf("Failed to create manager: %v", err) } @@ -83,7 +91,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee groupsManagerMock := groups.NewManagerMock() peersManager := peers.NewManager(store, permissionsManager) - apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager) + apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, networkMapController) if err != nil { t.Fatalf("Failed to create API handler: %v", err) } @@ -91,7 +99,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee return apiHandler, am, done } -func peerShouldNotReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server.UpdateMessage) { +func peerShouldNotReceiveUpdate(t testing_tools.TB, updateMessage <-chan *network_map.UpdateMessage) { t.Helper() select { case msg := <-updateMessage: @@ -101,7 +109,7 @@ func peerShouldNotReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server } } -func peerShouldReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server.UpdateMessage, expected *server.UpdateMessage) { +func peerShouldReceiveUpdate(t testing_tools.TB, updateMessage <-chan *network_map.UpdateMessage, expected *network_map.UpdateMessage) { t.Helper() select { diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index a34d2086b..fc67e01af 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -22,10 +22,14 @@ import ( "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/formatter/hook" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/internals/server/config" + nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" + nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" @@ -321,99 +325,6 @@ func loginPeerWithValidSetupKey(key wgtypes.Key, client mgmtProto.ManagementServ return loginResp, nil } -func TestServer_GetDeviceAuthorizationFlow(t *testing.T) { - testingServerKey, err := wgtypes.GeneratePrivateKey() - if err != nil { - t.Errorf("unable to generate server wg key for testing GetDeviceAuthorizationFlow, error: %v", err) - } - - testingClientKey, err := wgtypes.GeneratePrivateKey() - if err != nil { - t.Errorf("unable to generate client wg key for testing GetDeviceAuthorizationFlow, error: %v", err) - } - - testCases := []struct { - name string - inputFlow *config.DeviceAuthorizationFlow - expectedFlow *mgmtProto.DeviceAuthorizationFlow - expectedErrFunc require.ErrorAssertionFunc - expectedErrMSG string - expectedComparisonFunc require.ComparisonAssertionFunc - expectedComparisonMSG string - }{ - { - name: "Testing No Device Flow Config", - inputFlow: nil, - expectedErrFunc: require.Error, - expectedErrMSG: "should return error", - }, - { - name: "Testing Invalid Device Flow Provider Config", - inputFlow: &config.DeviceAuthorizationFlow{ - Provider: "NoNe", - ProviderConfig: config.ProviderConfig{ - ClientID: "test", - }, - }, - expectedErrFunc: require.Error, - expectedErrMSG: "should return error", - }, - { - name: "Testing Full Device Flow Config", - inputFlow: &config.DeviceAuthorizationFlow{ - Provider: "hosted", - ProviderConfig: config.ProviderConfig{ - ClientID: "test", - }, - }, - expectedFlow: &mgmtProto.DeviceAuthorizationFlow{ - Provider: 0, - ProviderConfig: &mgmtProto.ProviderConfig{ - ClientID: "test", - }, - }, - expectedErrFunc: require.NoError, - expectedErrMSG: "should not return error", - expectedComparisonFunc: require.Equal, - expectedComparisonMSG: "should match", - }, - } - - for _, testCase := range testCases { - t.Run(testCase.name, func(t *testing.T) { - mgmtServer := &GRPCServer{ - wgKey: testingServerKey, - config: &config.Config{ - DeviceAuthorizationFlow: testCase.inputFlow, - }, - } - - message := &mgmtProto.DeviceAuthorizationFlowRequest{} - - encryptedMSG, err := encryption.EncryptMessage(testingClientKey.PublicKey(), mgmtServer.wgKey, message) - require.NoError(t, err, "should be able to encrypt message") - - resp, err := mgmtServer.GetDeviceAuthorizationFlow( - context.TODO(), - &mgmtProto.EncryptedMessage{ - WgPubKey: testingClientKey.PublicKey().String(), - Body: encryptedMSG, - }, - ) - testCase.expectedErrFunc(t, err, testCase.expectedErrMSG) - if testCase.expectedComparisonFunc != nil { - flowInfoResp := &mgmtProto.DeviceAuthorizationFlow{} - - err = encryption.DecryptMessage(mgmtServer.wgKey.PublicKey(), testingClientKey, resp.Body, flowInfoResp) - require.NoError(t, err, "should be able to decrypt") - - testCase.expectedComparisonFunc(t, testCase.expectedFlow.Provider, flowInfoResp.Provider, testCase.expectedComparisonMSG) - testCase.expectedComparisonFunc(t, testCase.expectedFlow.ProviderConfig.ClientID, flowInfoResp.ProviderConfig.ClientID, testCase.expectedComparisonMSG) - } - }) - } -} - func startManagementForTest(t *testing.T, testFile string, config *config.Config) (*grpc.Server, *DefaultAccountManager, string, func(), error) { t.Helper() lis, err := net.Listen("tcp", "localhost:0") @@ -427,7 +338,6 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config t.Fatal(err) } - peersUpdateManager := NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} ctx := context.WithValue(context.Background(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck @@ -451,7 +361,10 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config permissionsManager := permissions.NewManager(store) groupsManager := groups.NewManagerMock() - accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted", + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := NewAccountRequestBuffer(ctx, store) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) + accountManager, err := BuildManager(ctx, store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) if err != nil { @@ -459,10 +372,10 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config return nil, nil, "", cleanup, err } - secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) ephemeralMgr := manager.NewEphemeralManager(store, accountManager) - mgmtServer, err := NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{}) + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{}, networkMapController) if err != nil { return nil, nil, "", cleanup, err } @@ -764,9 +677,38 @@ func Test_LoginPerformance(t *testing.T) { peerLogin := types.PeerLogin{ WireGuardPubKey: key.String(), SSHKey: "random", - Meta: extractPeerMeta(context.Background(), meta), - SetupKey: setupKey.Key, - ConnectionIP: net.IP{1, 1, 1, 1}, + Meta: nbpeer.PeerSystemMeta{ + Hostname: meta.GetHostname(), + GoOS: meta.GetGoOS(), + Kernel: meta.GetKernel(), + Platform: meta.GetPlatform(), + OS: meta.GetOS(), + OSVersion: meta.GetOSVersion(), + WtVersion: meta.GetNetbirdVersion(), + UIVersion: meta.GetUiVersion(), + KernelVersion: meta.GetKernelVersion(), + SystemSerialNumber: meta.GetSysSerialNumber(), + SystemProductName: meta.GetSysProductName(), + SystemManufacturer: meta.GetSysManufacturer(), + Environment: nbpeer.Environment{ + Cloud: meta.GetEnvironment().GetCloud(), + Platform: meta.GetEnvironment().GetPlatform(), + }, + Flags: nbpeer.Flags{ + RosenpassEnabled: meta.GetFlags().GetRosenpassEnabled(), + RosenpassPermissive: meta.GetFlags().GetRosenpassPermissive(), + ServerSSHAllowed: meta.GetFlags().GetServerSSHAllowed(), + DisableClientRoutes: meta.GetFlags().GetDisableClientRoutes(), + DisableServerRoutes: meta.GetFlags().GetDisableServerRoutes(), + DisableDNS: meta.GetFlags().GetDisableDNS(), + DisableFirewall: meta.GetFlags().GetDisableFirewall(), + BlockLANAccess: meta.GetFlags().GetBlockLANAccess(), + BlockInbound: meta.GetFlags().GetBlockInbound(), + LazyConnectionEnabled: meta.GetFlags().GetLazyConnectionEnabled(), + }, + }, + SetupKey: setupKey.Key, + ConnectionIP: net.IP{1, 1, 1, 1}, } login := func() error { diff --git a/management/server/management_test.go b/management/server/management_test.go index 1a5e47354..930ecfb5a 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -20,7 +20,10 @@ import ( "google.golang.org/grpc/keepalive" "github.com/netbirdio/netbird/encryption" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/internals/server/config" + nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/groups" @@ -176,7 +179,6 @@ func startServer( log.Fatalf("failed creating a store: %s: %v", config.Datadir, err) } - peersUpdateManager := server.NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) @@ -199,13 +201,18 @@ func startServer( AnyTimes() permissionsManager := permissions.NewManager(str) + + ctx := context.Background() + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := server.NewAccountRequestBuffer(ctx, str) + networkMapController := controller.NewController(ctx, str, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) + accountManager, err := server.BuildManager( context.Background(), str, - peersUpdateManager, + networkMapController, nil, "", - "netbird.selfhosted", eventStore, nil, false, @@ -220,18 +227,18 @@ func startServer( } groupsManager := groups.NewManager(str, permissionsManager, accountManager) - secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := server.NewServer( - context.Background(), + secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + mgmtServer, err := nbgrpc.NewServer( config, accountManager, settingsMockManager, - peersUpdateManager, + updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, server.MockIntegratedValidator{}, + networkMapController, ) if err != nil { t.Fatalf("failed creating management server: %v", err) diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 8baffa58b..781d84f5f 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -38,7 +38,7 @@ type MockAccountManager struct { ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error) GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error - SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error) GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error) @@ -94,7 +94,7 @@ type MockAccountManager struct { GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) - SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error ApproveUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error) RejectUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) error @@ -178,11 +178,11 @@ func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, use return status.Errorf(codes.Unimplemented, "method DeleteSetupKey is not implemented") } -func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { +func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { if am.SyncAndMarkPeerFunc != nil { return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP) } - return nil, nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") + return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") } func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string) error { @@ -747,11 +747,11 @@ func (am *MockAccountManager) LoginPeer(ctx context.Context, login types.PeerLog } // SyncPeer mocks SyncPeer of the AccountManager interface -func (am *MockAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { +func (am *MockAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { if am.SyncPeerFunc != nil { return am.SyncPeerFunc(ctx, sync, accountID) } - return nil, nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented") + return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented") } // GetAllConnectedPeers mocks GetAllConnectedPeers of the AccountManager interface diff --git a/management/server/nameserver.go b/management/server/nameserver.go index ee77a65bb..f278e1761 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -83,9 +83,6 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return nil, err - } am.UpdateAccountPeers(ctx, accountID) } @@ -137,9 +134,6 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } am.UpdateAccountPeers(ctx, accountID) } @@ -183,9 +177,6 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } am.UpdateAccountPeers(ctx, accountID) } diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 6c985410c..35291b30c 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -11,6 +11,8 @@ import ( "github.com/stretchr/testify/require" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -785,7 +787,13 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { AnyTimes() permissionsManager := permissions.NewManager(store) - return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + + ctx := context.Background() + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := NewAccountRequestBuffer(ctx, store) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) + + return BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) } func createNSStore(t *testing.T) (store.Store, error) { @@ -975,7 +983,7 @@ func TestValidateDomain(t *testing.T) { } func TestNameServerAccountPeersUpdate(t *testing.T) { - manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) var newNameServerGroupA *nbdns.NameServerGroup var newNameServerGroupB *nbdns.NameServerGroup @@ -994,9 +1002,9 @@ func TestNameServerAccountPeersUpdate(t *testing.T) { }) assert.NoError(t, err) - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updateManager.CloseChannel(context.Background(), peer1.ID) }) // Creating a nameserver group with a distribution group no peers should not update account peers diff --git a/management/server/networkmap.go b/management/server/networkmap.go deleted file mode 100644 index 2a0627643..000000000 --- a/management/server/networkmap.go +++ /dev/null @@ -1,80 +0,0 @@ -package server - -import ( - "context" - - log "github.com/sirupsen/logrus" - "golang.org/x/exp/maps" - - nbdns "github.com/netbirdio/netbird/dns" - nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/telemetry" - "github.com/netbirdio/netbird/management/server/types" -) - -func (am *DefaultAccountManager) initNetworkMapBuilderIfNeeded(account *types.Account, validatedPeers map[string]struct{}) { - am.enrichAccountFromHolder(account) - account.InitNetworkMapBuilderIfNeeded(validatedPeers) -} - -func (am *DefaultAccountManager) getPeerNetworkMapExp( - ctx context.Context, - accountId string, - peerId string, - validatedPeers map[string]struct{}, - customZone nbdns.CustomZone, - metrics *telemetry.AccountManagerMetrics, -) *types.NetworkMap { - account := am.getAccountFromHolderOrInit(accountId) - if account == nil { - log.WithContext(ctx).Warnf("account %s not found in holder when getting peer network map", accountId) - return &types.NetworkMap{ - Network: &types.Network{}, - } - } - return account.GetPeerNetworkMapExp(ctx, peerId, customZone, validatedPeers, metrics) -} - -func (am *DefaultAccountManager) onPeerAddedUpdNetworkMapCache(account *types.Account, peerId string) error { - am.enrichAccountFromHolder(account) - return account.OnPeerAddedUpdNetworkMapCache(peerId) -} - -func (am *DefaultAccountManager) onPeerDeletedUpdNetworkMapCache(account *types.Account, peerId string) error { - am.enrichAccountFromHolder(account) - return account.OnPeerDeletedUpdNetworkMapCache(peerId) -} - -func (am *DefaultAccountManager) updatePeerInNetworkMapCache(accountId string, peer *nbpeer.Peer) { - account := am.getAccountFromHolder(accountId) - if account == nil { - return - } - account.UpdatePeerInNetworkMapCache(peer) -} - -func (am *DefaultAccountManager) recalculateNetworkMapCache(account *types.Account, validatedPeers map[string]struct{}) { - account.RecalculateNetworkMapCache(validatedPeers) - am.updateAccountInHolder(account) -} - -func (am *DefaultAccountManager) RecalculateNetworkMapCache(ctx context.Context, accountId string) error { - if am.experimentalNetworkMap(accountId) { - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountId) - if err != nil { - return err - } - validatedPeers, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) - if err != nil { - log.WithContext(ctx).Errorf("failed to get validate peers: %v", err) - return err - } - am.recalculateNetworkMapCache(account, validatedPeers) - } - return nil -} - -func (am *DefaultAccountManager) experimentalNetworkMap(accountId string) bool { - _, ok := am.expNewNetworkMapAIDs[accountId] - return am.expNewNetworkMap || ok -} diff --git a/management/server/networks/manager.go b/management/server/networks/manager.go index 0e6d1631b..b6706ca45 100644 --- a/management/server/networks/manager.go +++ b/management/server/networks/manager.go @@ -177,9 +177,6 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw event() } - if err := m.accountManager.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } go m.accountManager.UpdateAccountPeers(ctx, accountID) return nil diff --git a/management/server/networks/resources/manager.go b/management/server/networks/resources/manager.go index b740610c2..66484d120 100644 --- a/management/server/networks/resources/manager.go +++ b/management/server/networks/resources/manager.go @@ -157,9 +157,6 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc event() } - if err := m.accountManager.RecalculateNetworkMapCache(ctx, resource.AccountID); err != nil { - return nil, err - } go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID) return resource, nil @@ -260,9 +257,6 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc event() } - if err := m.accountManager.RecalculateNetworkMapCache(ctx, resource.AccountID); err != nil { - return nil, err - } go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID) return resource, nil @@ -337,9 +331,6 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net event() } - if err := m.accountManager.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } go m.accountManager.UpdateAccountPeers(ctx, accountID) return nil diff --git a/management/server/networks/routers/manager.go b/management/server/networks/routers/manager.go index 89ac419fd..82cac424a 100644 --- a/management/server/networks/routers/manager.go +++ b/management/server/networks/routers/manager.go @@ -119,9 +119,6 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterCreated, router.EventMeta(network)) - if err := m.accountManager.RecalculateNetworkMapCache(ctx, router.AccountID); err != nil { - return nil, err - } go m.accountManager.UpdateAccountPeers(ctx, router.AccountID) return router, nil @@ -186,9 +183,6 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterUpdated, router.EventMeta(network)) - if err := m.accountManager.RecalculateNetworkMapCache(ctx, router.AccountID); err != nil { - return nil, err - } go m.accountManager.UpdateAccountPeers(ctx, router.AccountID) return router, nil @@ -223,9 +217,6 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo event() - if err := m.accountManager.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } go m.accountManager.UpdateAccountPeers(ctx, accountID) return nil diff --git a/management/server/peer.go b/management/server/peer.go index 4c605b5eb..cd9fbe4c8 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -8,8 +8,6 @@ import ( "net" "slices" "strings" - "sync" - "sync/atomic" "time" "github.com/rs/xid" @@ -23,7 +21,6 @@ import ( "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/shared/management/domain" - "github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" @@ -31,7 +28,6 @@ import ( "github.com/netbirdio/netbird/management/server/activity" nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/shared/management/status" ) @@ -140,12 +136,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK } if expired { - if am.experimentalNetworkMap(accountID) { - am.updatePeerInNetworkMapCache(peer.AccountID, peer) - } - // we need to update other peers because when peer login expires all other peers are notified to disconnect from - // the expired one. Here we notify them that connection is now allowed again. - am.BufferUpdateAccountPeers(ctx, accountID) + am.networkMapController.OnPeerUpdated(accountID, peer) } return nil @@ -201,7 +192,6 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user var peer *nbpeer.Peer var settings *types.Settings var peerGroupList []string - var requiresPeerUpdates bool var peerLabelChanged bool var sshChanged bool var loginExpirationChanged bool @@ -224,9 +214,9 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user return err } - dnsDomain = am.GetDNSDomain(settings) + dnsDomain = am.networkMapController.GetDNSDomain(settings) - update, requiresPeerUpdates, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, dnsDomain, peerGroupList, settings.Extra) + update, _, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, dnsDomain, peerGroupList, settings.Extra) if err != nil { return err } @@ -319,15 +309,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user } } - if am.experimentalNetworkMap(accountID) { - am.updatePeerInNetworkMapCache(peer.AccountID, peer) - } - - if peerLabelChanged || requiresPeerUpdates { - am.UpdateAccountPeers(ctx, accountID) - } else if sshChanged { - am.UpdateAccountPeer(ctx, accountID, peer.ID) - } + am.networkMapController.OnPeerUpdated(accountID, peer) return peer, nil } @@ -383,20 +365,13 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer storeEvent() } - if am.experimentalNetworkMap(accountID) { - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return err - } - - if err := am.onPeerDeletedUpdNetworkMapCache(account, peerID); err != nil { - log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", peerID, err) - } - + err = am.networkMapController.DeletePeer(ctx, accountID, peer.ID) + if err != nil { + log.WithContext(ctx).Errorf("failed to delete peer %s from network map: %v", peer.ID, err) } - if userID != activity.SystemInitiator { - am.BufferUpdateAccountPeers(ctx, accountID) + if err := am.networkMapController.OnPeerDeleted(ctx, accountID, peerID); err != nil { + log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", peerID, err) } return nil @@ -404,47 +379,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer // GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result) func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) { - account, err := am.Store.GetAccountByPeerID(ctx, peerID) - if err != nil { - return nil, err - } - - peer := account.GetPeer(peerID) - if peer == nil { - return nil, status.Errorf(status.NotFound, "peer with ID %s not found", peerID) - } - - groups := make(map[string][]string) - for groupID, group := range account.Groups { - groups[groupID] = group.Peers - } - - validatedPeers, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) - if err != nil { - return nil, err - } - customZone := account.GetPeersCustomZone(ctx, am.GetDNSDomain(account.Settings)) - - proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id, peerID, account.Peers) - if err != nil { - log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) - return nil, err - } - - var networkMap *types.NetworkMap - - if am.experimentalNetworkMap(peer.AccountID) { - networkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil) - } else { - networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) - } - - proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] - if ok { - networkMap.Merge(proxyNetworkMap) - } - - return networkMap, nil + return am.networkMapController.GetNetworkMap(ctx, peerID) } // GetPeerNetwork returns the Network for a given peer @@ -703,27 +638,19 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe } opEvent.TargetID = newPeer.ID - opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain(settings)) + opEvent.Meta = newPeer.EventMeta(am.networkMapController.GetDNSDomain(settings)) if !addedByUser { opEvent.Meta["setup_key_name"] = setupKeyName } am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) - if am.experimentalNetworkMap(accountID) { - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return nil, nil, nil, err - } - - if err := am.onPeerAddedUpdNetworkMapCache(account, newPeer.ID); err != nil { - log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err) - } + if err := am.networkMapController.OnPeerAdded(ctx, accountID, newPeer.ID); err != nil { + log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err) } - am.BufferUpdateAccountPeers(ctx, accountID) - - return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer) + p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, false, accountID, newPeer) + return p, nmap, pc, err } func getPeerIPDNSLabel(ip net.IP, peerHostName string) (string, error) { @@ -738,7 +665,7 @@ func getPeerIPDNSLabel(ip net.IP, peerHostName string) (string, error) { } // SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible -func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { +func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { var peer *nbpeer.Peer var peerNotValid bool var isStatusChanged bool @@ -748,7 +675,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, 0, err } err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { @@ -798,17 +725,14 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy return nil }) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, 0, err } if isStatusChanged || sync.UpdateAccountPeers || (updated && (len(postureChecks) > 0 || versionChanged)) { - if am.experimentalNetworkMap(accountID) { - am.updatePeerInNetworkMapCache(peer.AccountID, peer) - } - am.BufferUpdateAccountPeers(ctx, accountID) + am.networkMapController.OnPeerUpdated(accountID, peer) } - return am.getValidatedPeerWithMap(ctx, peerNotValid, accountID, peer) + return am.networkMapController.GetValidatedPeerWithMap(ctx, peerNotValid, accountID, peer) } func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, login types.PeerLogin, err error) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { @@ -933,15 +857,11 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer log.WithContext(ctx).Debugf("LoginPeer: transaction took %v", time.Since(startTransaction)) if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) { - if am.experimentalNetworkMap(accountID) { - am.updatePeerInNetworkMapCache(peer.AccountID, peer) - } - startBuffer := time.Now() - am.BufferUpdateAccountPeers(ctx, accountID) - log.WithContext(ctx).Debugf("LoginPeer: BufferUpdateAccountPeers took %v", time.Since(startBuffer)) + am.networkMapController.OnPeerUpdated(accountID, peer) } - return am.getValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer) + p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer) + return p, nmap, pc, err } // getPeerPostureChecks returns the posture checks for the peer. @@ -1033,68 +953,6 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co return nil } -func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { - if isRequiresApproval { - network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID) - if err != nil { - return nil, nil, nil, err - } - - emptyMap := &types.NetworkMap{ - Network: network.Copy(), - } - return peer, emptyMap, nil, nil - } - - var ( - account *types.Account - err error - ) - if am.experimentalNetworkMap(accountID) { - account = am.getAccountFromHolderOrInit(accountID) - } else { - account, err = am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return nil, nil, nil, err - } - } - - approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) - if err != nil { - return nil, nil, nil, err - } - - startPosture := time.Now() - postureChecks, err := am.getPeerPostureChecks(account, peer.ID) - if err != nil { - return nil, nil, nil, err - } - log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture)) - - customZone := account.GetPeersCustomZone(ctx, am.GetDNSDomain(account.Settings)) - - proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers) - if err != nil { - log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) - return nil, nil, nil, err - } - - var networkMap *types.NetworkMap - - if am.experimentalNetworkMap(accountID) { - networkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics()) - } else { - networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), am.metrics.AccountManagerMetrics()) - } - - proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] - if ok { - networkMap.Merge(proxyNetworkMap) - } - - return peer, networkMap, postureChecks, nil -} - func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transaction store.Store, user *types.User, peer *nbpeer.Peer) error { err := checkAuth(ctx, user.Id, peer) if err != nil { @@ -1118,7 +976,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transact return fmt.Errorf("failed to get account settings: %w", err) } - am.StoreEvent(ctx, user.Id, peer.ID, user.AccountID, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain(settings))) + am.StoreEvent(ctx, user.Id, peer.ID, user.AccountID, activity.UserLoggedInPeer, peer.EventMeta(am.networkMapController.GetDNSDomain(settings))) return nil } @@ -1214,232 +1072,17 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun // UpdateAccountPeers updates all peers that belong to an account. // Should be called when changes have to be synced to peers. func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) { - log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName()) - var ( - account *types.Account - err error - ) - if am.experimentalNetworkMap(accountID) { - account = am.getAccountFromHolderOrInit(accountID) - } else { - account, err = am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to send out updates to peers. failed to get account: %v", err) - return - } - } - - globalStart := time.Now() - - hasPeersConnected := false - for _, peer := range account.Peers { - if am.peersUpdateManager.HasChannel(peer.ID) { - hasPeersConnected = true - break - } - - } - - if !hasPeersConnected { - return - } - - approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) - if err != nil { - log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to get validate peers: %v", err) - return - } - - var wg sync.WaitGroup - semaphore := make(chan struct{}, 10) - - dnsCache := &DNSConfigCache{} - dnsDomain := am.GetDNSDomain(account.Settings) - customZone := account.GetPeersCustomZone(ctx, dnsDomain) - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - - if am.experimentalNetworkMap(accountID) { - am.initNetworkMapBuilderIfNeeded(account, approvedPeersMap) - } - - proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers) - if err != nil { - log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) - return - } - - extraSetting, err := am.settingsManager.GetExtraSettings(ctx, accountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to get flow enabled status: %v", err) - return - } - - dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), dnsForwarderPortMinVersion) - - for _, peer := range account.Peers { - if !am.peersUpdateManager.HasChannel(peer.ID) { - log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID) - continue - } - - wg.Add(1) - semaphore <- struct{}{} - go func(p *nbpeer.Peer) { - defer wg.Done() - defer func() { <-semaphore }() - - start := time.Now() - - postureChecks, err := am.getPeerPostureChecks(account, p.ID) - if err != nil { - log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", peer.ID, err) - return - } - - am.metrics.UpdateChannelMetrics().CountCalcPostureChecksDuration(time.Since(start)) - start = time.Now() - - var remotePeerNetworkMap *types.NetworkMap - - if am.experimentalNetworkMap(accountID) { - remotePeerNetworkMap = am.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics()) - } else { - remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics()) - } - - am.metrics.UpdateChannelMetrics().CountCalcPeerNetworkMapDuration(time.Since(start)) - start = time.Now() - - proxyNetworkMap, ok := proxyNetworkMaps[p.ID] - if ok { - remotePeerNetworkMap.Merge(proxyNetworkMap) - } - am.metrics.UpdateChannelMetrics().CountMergeNetworkMapDuration(time.Since(start)) - - peerGroups := account.GetPeerGroups(p.ID) - start = time.Now() - update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort) - am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start)) - - am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update}) - }(peer) - } - - // - - wg.Wait() - if am.metrics != nil { - am.metrics.AccountManagerMetrics().CountUpdateAccountPeersDuration(time.Since(globalStart)) - } -} - -type bufferUpdate struct { - mu sync.Mutex - next *time.Timer - update atomic.Bool + _ = am.networkMapController.UpdateAccountPeers(ctx, accountID) } func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) { - log.WithContext(ctx).Tracef("buffer updating peers for account %s from %s", accountID, util.GetCallerName()) - - bufUpd, _ := am.accountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{}) - b := bufUpd.(*bufferUpdate) - - if !b.mu.TryLock() { - b.update.Store(true) - return - } - - if b.next != nil { - b.next.Stop() - } - - go func() { - defer b.mu.Unlock() - am.UpdateAccountPeers(ctx, accountID) - if !b.update.Load() { - return - } - b.update.Store(false) - if b.next == nil { - b.next = time.AfterFunc(time.Duration(am.updateAccountPeersBufferInterval.Load()), func() { - am.UpdateAccountPeers(ctx, accountID) - }) - return - } - b.next.Reset(time.Duration(am.updateAccountPeersBufferInterval.Load())) - }() + _ = am.networkMapController.BufferUpdateAccountPeers(ctx, accountID) } // UpdateAccountPeer updates a single peer that belongs to an account. // Should be called when changes need to be synced to a specific peer only. func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) { - if !am.peersUpdateManager.HasChannel(peerId) { - log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peerId) - return - } - - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountId) - if err != nil { - log.WithContext(ctx).Errorf("failed to send out updates to peer %s. failed to get account: %v", peerId, err) - return - } - - peer := account.GetPeer(peerId) - if peer == nil { - log.WithContext(ctx).Tracef("peer %s doesn't exists in account %s", peerId, accountId) - return - } - - approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) - if err != nil { - log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to validate peers: %v", peerId, err) - return - } - - dnsCache := &DNSConfigCache{} - dnsDomain := am.GetDNSDomain(account.Settings) - customZone := account.GetPeersCustomZone(ctx, dnsDomain) - resourcePolicies := account.GetResourcePoliciesMap() - routers := account.GetResourceRoutersMap() - - postureChecks, err := am.getPeerPostureChecks(account, peerId) - if err != nil { - log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to get posture checks: %v", peerId, err) - return - } - - proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMaps(ctx, accountId, peerId, account.Peers) - if err != nil { - log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) - return - } - - var remotePeerNetworkMap *types.NetworkMap - - if am.experimentalNetworkMap(accountId) { - remotePeerNetworkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics()) - } else { - remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics()) - } - - proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] - if ok { - remotePeerNetworkMap.Merge(proxyNetworkMap) - } - - extraSettings, err := am.settingsManager.GetExtraSettings(ctx, peer.AccountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to get extra settings: %v", err) - return - } - - peerGroups := account.GetPeerGroups(peerId) - dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), dnsForwarderPortMinVersion) - - update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort) - am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update}) + _ = am.networkMapController.UpdateAccountPeer(ctx, accountId, peerId) } // getNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. @@ -1594,14 +1237,7 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto if err != nil { return nil, err } - dnsDomain := am.GetDNSDomain(settings) - - network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID) - if err != nil { - return nil, err - } - - dnsFwdPort := computeForwarderPort(peers, dnsForwarderPortMinVersion) + dnsDomain := am.networkMapController.GetDNSDomain(settings) for _, peer := range peers { if err := transaction.RemovePeerFromAllGroups(ctx, peer.ID); err != nil { @@ -1635,24 +1271,6 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto if err = transaction.DeletePeer(ctx, accountID, peer.ID); err != nil { return nil, err } - - am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{ - Update: &proto.SyncResponse{ - RemotePeers: []*proto.RemotePeerConfig{}, - RemotePeersIsEmpty: true, - NetworkMap: &proto.NetworkMap{ - Serial: network.CurrentSerial(), - RemotePeers: []*proto.RemotePeerConfig{}, - RemotePeersIsEmpty: true, - FirewallRules: []*proto.FirewallRule{}, - FirewallRulesIsEmpty: true, - DNSConfig: &proto.DNSConfig{ - ForwarderPort: dnsFwdPort, - }, - }, - }, - }) - am.peersUpdateManager.CloseChannel(ctx, peer.ID) peerDeletedEvents = append(peerDeletedEvents, func() { am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain)) }) @@ -1661,14 +1279,6 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto return peerDeletedEvents, nil } -func ConvertSliceToMap(existingLabels []string) map[string]struct{} { - labelMap := make(map[string]struct{}, len(existingLabels)) - for _, label := range existingLabels { - labelMap[label] = struct{}{} - } - return labelMap -} - // validatePeerDelete checks if the peer can be deleted. func (am *DefaultAccountManager) validatePeerDelete(ctx context.Context, transaction store.Store, accountId, peerId string) error { linkedInIngressPorts, err := am.proxyController.IsPeerInIngressPorts(ctx, accountId, peerId) diff --git a/management/server/peer_test.go b/management/server/peer_test.go index e151f5abb..95c609595 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -13,7 +13,6 @@ import ( "strconv" "strings" "sync" - "sync/atomic" "testing" "time" @@ -25,10 +24,14 @@ import ( "golang.org/x/exp/maps" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/management/internals/controllers/network_map" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/internals/server/config" + "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" - "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/shared/management/status" @@ -172,12 +175,12 @@ func TestAccountManager_GetNetworkMap(t *testing.T) { } func TestAccountManager_GetNetworkMap_Experimental(t *testing.T) { - t.Setenv(envNewNetworkMapBuilder, "true") + t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") testGetNetworkMapGeneral(t) } func testGetNetworkMapGeneral(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -249,7 +252,7 @@ func testGetNetworkMapGeneral(t *testing.T) { func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { // TODO: disable until we start use policy again t.Skip() - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -426,7 +429,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { } func TestAccountManager_GetPeerNetwork(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -487,7 +490,7 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) { } func TestDefaultAccountManager_GetPeer(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -674,7 +677,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) { } for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -742,12 +745,12 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) { } } -func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccountManager, string, string, error) { +func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccountManager, *update_channel.PeersUpdateManager, string, string, error) { b.Helper() - manager, err := createManager(b) + manager, updateManager, err := createManager(b) if err != nil { - return nil, "", "", err + return nil, nil, "", "", err } accountID := "test_account" @@ -798,7 +801,7 @@ func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccou ips := account.GetTakenIPs() peerIP, err := types.AllocatePeerIP(account.Network.Net, ips) if err != nil { - return nil, "", "", err + return nil, nil, "", "", err } peerKey, _ := wgtypes.GeneratePrivateKey() @@ -904,10 +907,10 @@ func setupTestAccountManager(b testing.TB, peers int, groups int) (*DefaultAccou err = manager.Store.SaveAccount(context.Background(), account) if err != nil { - return nil, "", "", err + return nil, nil, "", "", err } - return manager, accountID, regularUser, nil + return manager, updateManager, accountID, regularUser, nil } func BenchmarkGetPeers(b *testing.B) { @@ -928,7 +931,7 @@ func BenchmarkGetPeers(b *testing.B) { defer log.SetOutput(os.Stderr) for _, bc := range benchCases { b.Run(bc.name, func(b *testing.B) { - manager, accountID, userID, err := setupTestAccountManager(b, bc.peers, bc.groups) + manager, _, accountID, userID, err := setupTestAccountManager(b, bc.peers, bc.groups) if err != nil { b.Fatalf("Failed to setup test account manager: %v", err) } @@ -968,7 +971,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { for _, bc := range benchCases { b.Run(bc.name, func(b *testing.B) { - manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) + manager, updateManager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) if err != nil { b.Fatalf("Failed to setup test account manager: %v", err) } @@ -980,14 +983,10 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { b.Fatalf("Failed to get account: %v", err) } - peerChannels := make(map[string]chan *UpdateMessage) - for peerID := range account.Peers { - peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize) + updateManager.CreateChannel(ctx, peerID) } - manager.peersUpdateManager.peerChannels = peerChannels - b.ResetTimer() start := time.Now() @@ -1013,7 +1012,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { } func TestUpdateAccountPeers_Experimental(t *testing.T) { - t.Setenv(envNewNetworkMapBuilder, "true") + t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") testUpdateAccountPeers(t) } @@ -1037,7 +1036,7 @@ func testUpdateAccountPeers(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - manager, accountID, _, err := setupTestAccountManager(t, tc.peers, tc.groups) + manager, updateManager, accountID, _, err := setupTestAccountManager(t, tc.peers, tc.groups) if err != nil { t.Fatalf("Failed to setup test account manager: %v", err) } @@ -1049,13 +1048,12 @@ func testUpdateAccountPeers(t *testing.T) { t.Fatalf("Failed to get account: %v", err) } - peerChannels := make(map[string]chan *UpdateMessage) + peerChannels := make(map[string]chan *network_map.UpdateMessage) for peerID := range account.Peers { - peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize) + peerChannels[peerID] = updateManager.CreateChannel(ctx, peerID) } - manager.peersUpdateManager.peerChannels = peerChannels manager.UpdateAccountPeers(ctx, account.Id) for _, channel := range peerChannels { @@ -1097,7 +1095,7 @@ func TestToSyncResponse(t *testing.T) { DNSLabel: "peer1", SSHKey: "peer1-ssh-key", } - turnRelayToken := &Token{ + turnRelayToken := &grpc.Token{ Payload: "turn-user", Signature: "turn-pass", } @@ -1177,9 +1175,9 @@ func TestToSyncResponse(t *testing.T) { }, }, } - dnsCache := &DNSConfigCache{} + dnsCache := &cache.DNSConfigCache{} accountSettings := &types.Settings{RoutingPeerDNSResolutionEnabled: true} - response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, int64(dnsForwarderPort)) + response := grpc.ToSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, int64(dnsForwarderPort)) assert.NotNil(t, response) // assert peer config @@ -1289,7 +1287,12 @@ func Test_RegisterPeerByUser(t *testing.T) { settingsMockManager := settings.NewMockManager(ctrl) permissionsManager := permissions.NewManager(s) - am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + ctx := context.Background() + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := NewAccountRequestBuffer(ctx, s) + networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock()) + + am, err := BuildManager(context.Background(), s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -1369,7 +1372,12 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { AnyTimes() permissionsManager := permissions.NewManager(s) - am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + ctx := context.Background() + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := NewAccountRequestBuffer(ctx, s) + networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock()) + + am, err := BuildManager(context.Background(), s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -1517,7 +1525,12 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { permissionsManager := permissions.NewManager(s) - am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + ctx := context.Background() + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := NewAccountRequestBuffer(ctx, s) + networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock()) + + am, err := BuildManager(context.Background(), s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -1566,7 +1579,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { } func Test_LoginPeer(t *testing.T) { - t.Setenv(envNewNetworkMapBuilder, "true") + t.Setenv(network_map.EnvNewNetworkMapBuilder, "true") if runtime.GOOS == "windows" { t.Skip("The SQLite store is not properly supported by Windows yet") } @@ -1592,7 +1605,12 @@ func Test_LoginPeer(t *testing.T) { AnyTimes() permissionsManager := permissions.NewManager(s) - am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + ctx := context.Background() + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := NewAccountRequestBuffer(ctx, s) + networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock()) + + am, err := BuildManager(context.Background(), s, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -1725,7 +1743,7 @@ func Test_LoginPeer(t *testing.T) { } func TestPeerAccountPeersUpdate(t *testing.T) { - manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID) require.NoError(t, err) @@ -1782,13 +1800,14 @@ func TestPeerAccountPeersUpdate(t *testing.T) { var peer5 *nbpeer.Peer var peer6 *nbpeer.Peer - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updateManager.CloseChannel(context.Background(), peer1.ID) }) // Updating not expired peer and peer expiration is enabled should not update account peers and not send peer update t.Run("updating not expired peer and peer expiration is enabled", func(t *testing.T) { + t.Skip("Currently all updates will trigger a network map") done := make(chan struct{}) go func() { peerShouldNotReceiveUpdate(t, updMsg) @@ -1890,6 +1909,8 @@ func TestPeerAccountPeersUpdate(t *testing.T) { }) t.Run("validator requires no update", func(t *testing.T) { + t.Skip("Currently all updates will trigger a network map") + requireNoUpdateFunc := func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) { return update, false, nil } @@ -2091,7 +2112,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { } func Test_DeletePeer(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -2188,7 +2209,7 @@ func Test_IsUniqueConstraintError(t *testing.T) { } func Test_AddPeer(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -2276,136 +2297,8 @@ func Test_AddPeer(t *testing.T) { assert.Equal(t, uint64(totalPeers), account.Network.Serial) } -func TestBufferUpdateAccountPeers(t *testing.T) { - const ( - peersCount = 1000 - updateAccountInterval = 50 * time.Millisecond - ) - - var ( - deletedPeers, updatePeersDeleted, updatePeersRuns atomic.Int32 - uapLastRun, dpLastRun atomic.Int64 - - totalNewRuns, totalOldRuns int - ) - - uap := func(ctx context.Context, accountID string) { - updatePeersDeleted.Store(deletedPeers.Load()) - updatePeersRuns.Add(1) - uapLastRun.Store(time.Now().UnixMilli()) - time.Sleep(100 * time.Millisecond) - } - - t.Run("new approach", func(t *testing.T) { - updatePeersRuns.Store(0) - updatePeersDeleted.Store(0) - deletedPeers.Store(0) - - var mustore sync.Map - bufupd := func(ctx context.Context, accountID string) { - mu, _ := mustore.LoadOrStore(accountID, &bufferUpdate{}) - b := mu.(*bufferUpdate) - - if !b.mu.TryLock() { - b.update.Store(true) - return - } - - if b.next != nil { - b.next.Stop() - } - - go func() { - defer b.mu.Unlock() - uap(ctx, accountID) - if !b.update.Load() { - return - } - b.update.Store(false) - b.next = time.AfterFunc(updateAccountInterval, func() { - uap(ctx, accountID) - }) - }() - } - dp := func(ctx context.Context, accountID, peerID, userID string) error { - deletedPeers.Add(1) - dpLastRun.Store(time.Now().UnixMilli()) - time.Sleep(10 * time.Millisecond) - bufupd(ctx, accountID) - return nil - } - - am := mock_server.MockAccountManager{ - UpdateAccountPeersFunc: uap, - BufferUpdateAccountPeersFunc: bufupd, - DeletePeerFunc: dp, - } - empty := "" - for range peersCount { - //nolint - am.DeletePeer(context.Background(), empty, empty, empty) - } - time.Sleep(100 * time.Millisecond) - - assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted") - assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer") - assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer") - - totalNewRuns = int(updatePeersRuns.Load()) - }) - - t.Run("old approach", func(t *testing.T) { - updatePeersRuns.Store(0) - updatePeersDeleted.Store(0) - deletedPeers.Store(0) - - var mustore sync.Map - bufupd := func(ctx context.Context, accountID string) { - mu, _ := mustore.LoadOrStore(accountID, &sync.Mutex{}) - b := mu.(*sync.Mutex) - - if !b.TryLock() { - return - } - - go func() { - time.Sleep(updateAccountInterval) - b.Unlock() - uap(ctx, accountID) - }() - } - dp := func(ctx context.Context, accountID, peerID, userID string) error { - deletedPeers.Add(1) - dpLastRun.Store(time.Now().UnixMilli()) - time.Sleep(10 * time.Millisecond) - bufupd(ctx, accountID) - return nil - } - - am := mock_server.MockAccountManager{ - UpdateAccountPeersFunc: uap, - BufferUpdateAccountPeersFunc: bufupd, - DeletePeerFunc: dp, - } - empty := "" - for range peersCount { - //nolint - am.DeletePeer(context.Background(), empty, empty, empty) - } - time.Sleep(100 * time.Millisecond) - - assert.Equal(t, peersCount, int(deletedPeers.Load()), "Expected all peers to be deleted") - assert.Equal(t, peersCount, int(updatePeersDeleted.Load()), "Expected all peers to be updated in the buffer") - assert.GreaterOrEqual(t, uapLastRun.Load(), dpLastRun.Load(), "Expected update account peers to run after delete peer") - - totalOldRuns = int(updatePeersRuns.Load()) - }) - assert.Less(t, totalNewRuns, totalOldRuns, "Expected new approach to run less than old approach. New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns) - t.Logf("New runs: %d, Old runs: %d", totalNewRuns, totalOldRuns) -} - func TestAddPeer_UserPendingApprovalBlocked(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } @@ -2442,7 +2335,7 @@ func TestAddPeer_UserPendingApprovalBlocked(t *testing.T) { } func TestAddPeer_ApprovedUserCanAddPeers(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } @@ -2476,7 +2369,7 @@ func TestAddPeer_ApprovedUserCanAddPeers(t *testing.T) { } func TestLoginPeer_UserPendingApprovalBlocked(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } @@ -2541,7 +2434,7 @@ func TestLoginPeer_UserPendingApprovalBlocked(t *testing.T) { } func TestLoginPeer_ApprovedUserCanLogin(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } diff --git a/management/server/policy.go b/management/server/policy.go index ff02d46aa..3e84c3d10 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -10,7 +10,6 @@ import ( "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" - "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/posture" @@ -77,9 +76,6 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return nil, err - } am.UpdateAccountPeers(ctx, accountID) } @@ -123,9 +119,6 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta()) if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } am.UpdateAccountPeers(ctx, accountID) } @@ -258,31 +251,3 @@ func getValidGroupIDs(groups map[string]*types.Group, groupIDs []string) []strin return validIDs } - -// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules. -func toProtocolFirewallRules(rules []*types.FirewallRule) []*proto.FirewallRule { - result := make([]*proto.FirewallRule, len(rules)) - for i := range rules { - rule := rules[i] - - fwRule := &proto.FirewallRule{ - PolicyID: []byte(rule.PolicyID), - PeerIP: rule.PeerIP, - Direction: getProtoDirection(rule.Direction), - Action: getProtoAction(rule.Action), - Protocol: getProtoProtocol(rule.Protocol), - Port: rule.Port, - } - - if shouldUsePortRange(fwRule) { - fwRule.PortInfo = rule.PortRange.ToProto() - } - - result[i] = fwRule - } - return result -} - -func shouldUsePortRange(rule *proto.FirewallRule) bool { - return rule.Port == "" && (rule.Protocol == proto.RuleProtocol_UDP || rule.Protocol == proto.RuleProtocol_TCP) -} diff --git a/management/server/policy_test.go b/management/server/policy_test.go index 97ebbcf5a..90fe8f036 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -1135,7 +1135,7 @@ func sortFunc() func(a *types.FirewallRule, b *types.FirewallRule) int { } func TestPolicyAccountPeersUpdate(t *testing.T) { - manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) g := []*types.Group{ { @@ -1164,9 +1164,9 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { assert.NoError(t, err) } - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updateManager.CloseChannel(context.Background(), peer1.ID) }) var policyWithGroupRulesNoPeers *types.Policy diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index f457b994b..ac8ea35de 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -2,19 +2,15 @@ package server import ( "context" - "errors" - "fmt" "slices" "github.com/rs/xid" - "golang.org/x/exp/maps" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" - "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/status" ) @@ -80,9 +76,6 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta()) if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return nil, err - } am.UpdateAccountPeers(ctx, accountID) } @@ -139,27 +132,6 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthNone, accountID) } -// getPeerPostureChecks returns the posture checks applied for a given peer. -func (am *DefaultAccountManager) getPeerPostureChecks(account *types.Account, peerID string) ([]*posture.Checks, error) { - peerPostureChecks := make(map[string]*posture.Checks) - - if len(account.PostureChecks) == 0 { - return nil, nil - } - - for _, policy := range account.Policies { - if !policy.Enabled || len(policy.SourcePostureChecks) == 0 { - continue - } - - if err := addPolicyPostureChecks(account, peerID, policy, peerPostureChecks); err != nil { - return nil, err - } - } - - return maps.Values(peerPostureChecks), nil -} - // arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers. func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (bool, error) { policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID) @@ -214,50 +186,6 @@ func validatePostureChecks(ctx context.Context, transaction store.Store, account return nil } -// addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups. -func addPolicyPostureChecks(account *types.Account, peerID string, policy *types.Policy, peerPostureChecks map[string]*posture.Checks) error { - isInGroup, err := isPeerInPolicySourceGroups(account, peerID, policy) - if err != nil { - return err - } - - if !isInGroup { - return nil - } - - for _, sourcePostureCheckID := range policy.SourcePostureChecks { - postureCheck := account.GetPostureChecks(sourcePostureCheckID) - if postureCheck == nil { - return errors.New("failed to add policy posture checks: posture checks not found") - } - peerPostureChecks[sourcePostureCheckID] = postureCheck - } - - return nil -} - -// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups. -func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *types.Policy) (bool, error) { - for _, rule := range policy.Rules { - if !rule.Enabled { - continue - } - - for _, sourceGroup := range rule.Sources { - group := account.GetGroup(sourceGroup) - if group == nil { - return false, fmt.Errorf("failed to check peer in policy source group: group not found") - } - - if slices.Contains(group.Peers, peerID) { - return true, nil - } - } - } - - return false, nil -} - // isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy. func isPostureCheckLinkedToPolicy(ctx context.Context, transaction store.Store, postureChecksID, accountID string) error { policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID) diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go index 67760d55a..13152ed12 100644 --- a/management/server/posture_checks_test.go +++ b/management/server/posture_checks_test.go @@ -21,7 +21,7 @@ const ( ) func TestDefaultAccountManager_PostureCheck(t *testing.T) { - am, err := createManager(t) + am, _, err := createManager(t) if err != nil { t.Error("failed to create account manager") } @@ -123,7 +123,7 @@ func initTestPostureChecksAccount(am *DefaultAccountManager) (*types.Account, er } func TestPostureCheckAccountPeersUpdate(t *testing.T) { - manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) g := []*types.Group{ { @@ -147,9 +147,9 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { assert.NoError(t, err) } - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updateManager.CloseChannel(context.Background(), peer1.ID) }) postureCheckA := &posture.Checks{ @@ -359,9 +359,9 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { // Updating linked posture check to policy where destination has peers but source does not // should trigger account peers update and send peer update t.Run("updating linked posture check to policy where destination has peers but source does not", func(t *testing.T) { - updMsg1 := manager.peersUpdateManager.CreateChannel(context.Background(), peer2.ID) + updMsg1 := updateManager.CreateChannel(context.Background(), peer2.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID) + updateManager.CloseChannel(context.Background(), peer2.ID) }) _, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ @@ -445,7 +445,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { } func TestArePostureCheckChangesAffectPeers(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) require.NoError(t, err, "failed to create account manager") account, err := initTestPostureChecksAccount(manager) diff --git a/management/server/route.go b/management/server/route.go index 05f7acf9e..2b4f11d05 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -16,7 +16,6 @@ import ( "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/domain" - "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/shared/management/status" ) @@ -192,9 +191,6 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return nil, err - } am.UpdateAccountPeers(ctx, accountID) } @@ -249,9 +245,6 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) if oldRouteAffectsPeers || newRouteAffectsPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } am.UpdateAccountPeers(ctx, accountID) } @@ -295,9 +288,6 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta()) if updateAccountPeers { - if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { - return err - } am.UpdateAccountPeers(ctx, accountID) } @@ -381,103 +371,12 @@ func validateRouteGroups(ctx context.Context, transaction store.Store, accountID return groupsMap, nil } -func toProtocolRoute(route *route.Route) *proto.Route { - return &proto.Route{ - ID: string(route.ID), - NetID: string(route.NetID), - Network: route.Network.String(), - Domains: route.Domains.ToPunycodeList(), - NetworkType: int64(route.NetworkType), - Peer: route.Peer, - Metric: int64(route.Metric), - Masquerade: route.Masquerade, - KeepRoute: route.KeepRoute, - SkipAutoApply: route.SkipAutoApply, - } -} - -func toProtocolRoutes(routes []*route.Route) []*proto.Route { - protoRoutes := make([]*proto.Route, 0, len(routes)) - for _, r := range routes { - protoRoutes = append(protoRoutes, toProtocolRoute(r)) - } - return protoRoutes -} - // getPlaceholderIP returns a placeholder IP address for the route if domains are used func getPlaceholderIP() netip.Prefix { // Using an IP from the documentation range to minimize impact in case older clients try to set a route return netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32) } -func toProtocolRoutesFirewallRules(rules []*types.RouteFirewallRule) []*proto.RouteFirewallRule { - result := make([]*proto.RouteFirewallRule, len(rules)) - for i := range rules { - rule := rules[i] - result[i] = &proto.RouteFirewallRule{ - SourceRanges: rule.SourceRanges, - Action: getProtoAction(rule.Action), - Destination: rule.Destination, - Protocol: getProtoProtocol(rule.Protocol), - PortInfo: getProtoPortInfo(rule), - IsDynamic: rule.IsDynamic, - Domains: rule.Domains.ToPunycodeList(), - PolicyID: []byte(rule.PolicyID), - RouteID: string(rule.RouteID), - } - } - - return result -} - -// getProtoDirection converts the direction to proto.RuleDirection. -func getProtoDirection(direction int) proto.RuleDirection { - if direction == types.FirewallRuleDirectionOUT { - return proto.RuleDirection_OUT - } - return proto.RuleDirection_IN -} - -// getProtoAction converts the action to proto.RuleAction. -func getProtoAction(action string) proto.RuleAction { - if action == string(types.PolicyTrafficActionDrop) { - return proto.RuleAction_DROP - } - return proto.RuleAction_ACCEPT -} - -// getProtoProtocol converts the protocol to proto.RuleProtocol. -func getProtoProtocol(protocol string) proto.RuleProtocol { - switch types.PolicyRuleProtocolType(protocol) { - case types.PolicyRuleProtocolALL: - return proto.RuleProtocol_ALL - case types.PolicyRuleProtocolTCP: - return proto.RuleProtocol_TCP - case types.PolicyRuleProtocolUDP: - return proto.RuleProtocol_UDP - case types.PolicyRuleProtocolICMP: - return proto.RuleProtocol_ICMP - default: - return proto.RuleProtocol_UNKNOWN - } -} - -// getProtoPortInfo converts the port info to proto.PortInfo. -func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo { - var portInfo proto.PortInfo - if rule.Port != 0 { - portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)} - } else if portRange := rule.PortRange; portRange.Start != 0 && portRange.End != 0 { - portInfo.PortSelection = &proto.PortInfo_Range_{ - Range: &proto.PortInfo_Range{ - Start: uint32(portRange.Start), - End: uint32(portRange.End), - }, - } - } - return &portInfo -} - // areRouteChangesAffectPeers checks if a given route affects peers by determining // if it has a routing peer, distribution, or peer groups that include peers. func areRouteChangesAffectPeers(ctx context.Context, transaction store.Store, route *route.Route) (bool, error) { diff --git a/management/server/route_test.go b/management/server/route_test.go index 388db140c..27fe033c8 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -14,6 +14,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" @@ -432,7 +434,7 @@ func TestCreateRoute(t *testing.T) { } for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - am, err := createRouterManager(t) + am, _, err := createRouterManager(t) if err != nil { t.Error("failed to create account manager") } @@ -922,7 +924,7 @@ func TestSaveRoute(t *testing.T) { } for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - am, err := createRouterManager(t) + am, _, err := createRouterManager(t) if err != nil { t.Error("failed to create account manager") } @@ -1024,7 +1026,7 @@ func TestDeleteRoute(t *testing.T) { Enabled: true, } - am, err := createRouterManager(t) + am, _, err := createRouterManager(t) if err != nil { t.Error("failed to create account manager") } @@ -1071,7 +1073,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { AccessControlGroups: []string{routeGroup1}, } - am, err := createRouterManager(t) + am, _, err := createRouterManager(t) if err != nil { t.Error("failed to create account manager") } @@ -1163,7 +1165,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { AccessControlGroups: []string{routeGroup1}, } - am, err := createRouterManager(t) + am, _, err := createRouterManager(t) if err != nil { t.Error("failed to create account manager") } @@ -1250,11 +1252,11 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { require.Len(t, peer1DeletedRoute.Routes, 0, "we should receive one route for peer1") } -func createRouterManager(t *testing.T) (*DefaultAccountManager, error) { +func createRouterManager(t *testing.T) (*DefaultAccountManager, *update_channel.PeersUpdateManager, error) { t.Helper() store, err := createRouterStore(t) if err != nil { - return nil, err + return nil, nil, err } eventStore := &activity.InMemoryEventStore{} @@ -1285,7 +1287,16 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) { permissionsManager := permissions.NewManager(store) - return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + ctx := context.Background() + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := NewAccountRequestBuffer(ctx, store) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) + + am, err := BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + if err != nil { + return nil, nil, err + } + return am, updateManager, nil } func createRouterStore(t *testing.T) (store.Store, error) { @@ -1948,7 +1959,7 @@ func orderRuleSourceRanges(ruleList []*types.RouteFirewallRule) []*types.RouteFi } func TestRouteAccountPeersUpdate(t *testing.T) { - manager, err := createRouterManager(t) + manager, updateManager, err := createRouterManager(t) require.NoError(t, err, "failed to create account manager") account, err := initTestRouteAccount(t, manager) @@ -1976,9 +1987,9 @@ func TestRouteAccountPeersUpdate(t *testing.T) { require.NoError(t, err, "failed to create group %s", group.Name) } - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1ID) + updateManager.CloseChannel(context.Background(), peer1ID) }) // Creating a route with no routing peer and no peers in PeerGroups or Groups should not update account peers and not send peer update diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index e55b33c94..bc361bbd7 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -18,7 +18,7 @@ import ( ) func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } @@ -93,7 +93,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { } func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } @@ -198,7 +198,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { } func TestGetSetupKeys(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } @@ -396,7 +396,7 @@ func TestSetupKey_Copy(t *testing.T) { } func TestSetupKeyAccountPeersUpdate(t *testing.T) { - manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", @@ -420,9 +420,9 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) { _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true) require.NoError(t, err) - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updateManager.CloseChannel(context.Background(), peer1.ID) }) var setupKey *types.SetupKey @@ -465,7 +465,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) { } func TestDefaultAccountManager_CreateSetupKey_ShouldNotAllowToUpdateRevokedKey(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } diff --git a/management/server/user.go b/management/server/user.go index 66bea314f..be4e491a8 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -965,7 +965,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou if err != nil { return err } - dnsDomain := am.GetDNSDomain(settings) + dnsDomain := am.networkMapController.GetDNSDomain(settings) var peerIDs []string for _, peer := range peers { @@ -992,16 +992,13 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou activity.PeerLoginExpired, peer.EventMeta(dnsDomain), ) - if am.experimentalNetworkMap(accountID) { - am.updatePeerInNetworkMapCache(peer.AccountID, peer) - } + am.networkMapController.OnPeerUpdated(accountID, peer) } if len(peerIDs) != 0 { // this will trigger peer disconnect from the management service log.Debugf("Expiring %d peers for account %s", len(peerIDs), accountID) - am.peersUpdateManager.CloseChannels(ctx, peerIDs) - am.BufferUpdateAccountPeers(ctx, accountID) + am.networkMapController.DisconnectPeers(ctx, peerIDs) } return nil } @@ -1115,6 +1112,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI var addPeerRemovedEvents []func() var updateAccountPeers bool + var userPeers []*nbpeer.Peer var targetUser *types.User var err error @@ -1124,7 +1122,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI return fmt.Errorf("failed to get user to delete: %w", err) } - userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, targetUserInfo.ID) + userPeers, err = transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, targetUserInfo.ID) if err != nil { return fmt.Errorf("failed to get user peers: %w", err) } @@ -1147,6 +1145,17 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI return false, err } + for _, peer := range userPeers { + err = am.networkMapController.DeletePeer(ctx, accountID, peer.ID) + if err != nil { + log.WithContext(ctx).Errorf("failed to delete peer %s from network map: %v", peer.ID, err) + } + + if err := am.networkMapController.OnPeerDeleted(ctx, accountID, peer.ID); err != nil { + log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", peer.ID, err) + } + } + for _, addPeerRemovedEvent := range addPeerRemovedEvents { addPeerRemovedEvent() } diff --git a/management/server/user_test.go b/management/server/user_test.go index 5920a2a33..69b8c85ee 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -1161,7 +1161,7 @@ func TestUser_GetUsersFromAccount_ForUser(t *testing.T) { } func TestDefaultAccountManager_SaveUser(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) return @@ -1333,7 +1333,7 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { func TestUserAccountPeersUpdate(t *testing.T) { // account groups propagation is enabled - manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + manager, updateManager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) err := manager.CreateGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", @@ -1357,9 +1357,9 @@ func TestUserAccountPeersUpdate(t *testing.T) { _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true) require.NoError(t, err) - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + updMsg := updateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + updateManager.CloseChannel(context.Background(), peer1.ID) }) // Creating a new regular user should not update account peers and not send peer update @@ -1468,9 +1468,9 @@ func TestUserAccountPeersUpdate(t *testing.T) { } }) - peer4UpdMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer4.ID) + peer4UpdMsg := updateManager.CreateChannel(context.Background(), peer4.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer4.ID) + updateManager.CloseChannel(context.Background(), peer4.ID) }) // deleting user with linked peers should update account peers and send peer update @@ -1748,7 +1748,7 @@ func mergeRolePermissions(role roles.RolePermissions) roles.Permissions { } func TestApproveUser(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } @@ -1807,7 +1807,7 @@ func TestApproveUser(t *testing.T) { } func TestRejectUser(t *testing.T) { - manager, err := createManager(t) + manager, _, err := createManager(t) if err != nil { t.Fatal(err) } diff --git a/shared/management/client/client_test.go b/shared/management/client/client_test.go index d4a9f1823..d3f341529 100644 --- a/shared/management/client/client_test.go +++ b/shared/management/client/client_test.go @@ -18,6 +18,9 @@ import ( "google.golang.org/grpc/status" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" + nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/encryption" @@ -68,7 +71,6 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { } t.Cleanup(cleanUp) - peersUpdateManager := mgmt.NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} ctrl := gomock.NewController(t) @@ -111,15 +113,19 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { Return(&types.ExtraSettings{}, nil). AnyTimes() - accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) + ctx := context.Background() + updateManager := update_channel.NewPeersUpdateManager(metrics) + requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store) + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock()) + accountManager, err := mgmt.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) if err != nil { t.Fatal(err) } groupsManager := groups.NewManagerMock() - secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) - mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, mgmt.MockIntegratedValidator{}) + secretsManager := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) + mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, updateManager, secretsManager, nil, &manager.EphemeralManager{}, nil, mgmt.MockIntegratedValidator{}, networkMapController) if err != nil { t.Fatal(err) }